ToB企服应用市场:ToB评测及商务社交产业平台
标题:
Stable Diffusion 原理先容与源码分析(一)
[打印本页]
作者:
悠扬随风
时间:
2024-7-29 00:32
标题:
Stable Diffusion 原理先容与源码分析(一)
Stable Diffusion 原理先容与源码分析(一)
前言(与正文无关,可以忽略)
Stable Diffusion 是 Stability AI 公司开源的 AI 文生图扩散模型。之前在文章 扩散模型 (Diffusion Model) 扼要先容与源码分析 中先容了扩散模型的原理与部分算法代码,满足根本的好奇心后便将其束之高阁,没成想近期 AIGC 的发展速度之快大大出乎我的料想,尤其是亲手跑出下面这张 AI 天生的图像, Stable Diffusion 终又重新回到我的视野:
作为一名算法工程师,需要有一双能看破事物本质的眼睛,这张图片最先吸引我的不是内容,而是其天生质量:图像高清、细节丰富,非之前看到的一些粗陋 Toy 可比,红框中标注出来的不和谐之处,也是瑕不掩瑜。因此,进一步分析 Stable Diffusion 整个工程框架的原理,着实是迫在眉睫,期待日后能修复红框中的不和谐之处,为 AIGC 的进一步发展做出一个技能人员应有的贡献。
总览
Stable Diffusion 整个框架的源码有上万行,没有必要全部分析。本文以 “文本天生图像(text to image)” 为主线,观察 Stable Diffusion 的运行流程以及各个告急的构成模块,在先容时采用 “总-分” 的形式,先概括整体框架,再分析各个组件(如 DDPM、DDIM 等),另外针对代码中的部分非主流逻辑,比如 predict_cids、return_ids 这些小细节谈谈我的看法。文章内容较长,准备拆分成多个部分。
源码地点:Stable Diffusion
说明
之前我写过很多代码分析文章,但在我遇到问题重新去翻阅时,发现要快速定位到目标位置并准确理解代码意图,仍然存在很大困难,密密麻麻的整块代码,每一次阅读都仿若初见,不易理解,原因在于摘录时引入过多的实现细节,降低了信息的传播效率。
经过一番思考,我不再图省事,决定采用伪代码的方式记录核心原理。平时我深度分析代码时会采用这种方式,对代码进行额外的抽象,相对会耗些时间,但私以为这是有益处的。举个例子,比如 DDPM 模型前向 Diffusion 的代码,如果我用伪代码的方式去写,将是如下的效果:
可以看到,刨除掉无关的实现细节之后,DDPM 的实现是云云的简便,倘若再配合一定的注释,可方便快速理解,让人获得一种整体而全面的掌控感。此外还应该在文中多增长框图、模型图等来对代码的实现细节进行更直观的展示。
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技能文章更新.
另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.
Stable Diffusion 整体框架
首先看下 Stable Diffusion 文本天生图像整体框架(文章绘图吐血…希望有一天 AI 能进行辅助):
上图框架内的模块较多,从上到下分为 3 块,我在图中使用 Part 1、2、3 进行了标注。框架包罗练习 + 采样两个阶段,其中:
练习阶段 (查看图中 Part 1 和 Part 2),重要包罗:
使用 AutoEncoderKL 自编码器将图像 Image 从 pixel space 映射到 latent space,学习图像的隐式表达,留意 AutoEncoderKL 编码器已提前练习好,参数是固定的。此时 Image 的巨细将从 [B, C, H, W] 转换为 [B, Z, H/8, W/8],其中 Z 表现 latent space 下图像的 Channel 数。这一过程在 Stable Diffusion 代码中被称为 encode_first_stage;
使用 FrozenCLIPEmbedder 文本编码器对 Prompt 提示词进行编码,天生巨细为 [B, K, E] 的 embedding 表现(即 context),其中 K 表现文本最大编码长度 max length, E 表现 embedding 的巨细。这一过程在 Stable Diffusion 代码中被称为 get_learned_conditioning;
进行前向扩散过程(Diffusion Process),对图像的隐式表达进行不断加噪,该过程调用 UNetModel 完成;UNetModel 同时接收图像的隐式表达 latent image 以及文本 embedding context,在练习时以 context 作为 condition,使用 Attention 机制来更好的学习文本与图像的匹配关系;
扩散模型输出噪声 ϵ θ \epsilon_{\theta} ϵθ,计算和真实噪声之间的偏差作为 Loss,通过反向传播算法更新 UNetModel 模型的参数,留意这个过程中 AutoEncoderKL 和 FrozenCLIPEmbedder 中的参数不会被更新。
采样阶段(查看图中 Part 2 和 Part 3),也就是我们加载模型参数后,输入提示词就能产出图像的阶段。重要包罗:
使用 FrozenCLIPEmbedder 文本编码器对 Prompt 提示词进行编码,天生巨细为 [B, K, E] 的 embedding 表现(即 context);
随机产出巨细为 [B, Z, H/8, W/8] 的噪声 Noise,使用练习好的 UNetModel 模型,按照 DDPM/DDIM/PLMS 等算法迭代 T 次,将噪声不断去除,恢复出图像的 latent 表现;
使用 AutoEncoderKL 对图像的 latent 表现(巨细为 [B, Z, H/8, W/8])进行 decode(解码),终极恢复出 pixel space 的图像,图像巨细为 [B, C, H, W]; 这一过程在 Stable Diffusion 中被称为 decode_first_stage。
经过上面的先容,对 Stable Diffusion 整了解有个较清晰的认识,下面就可以按图索骥,将各个重点模块努力去弄明确。限于个人精力与有限的空闲时间,如今除了 FrozenCLIPEmbedder 和 DPM 算法 (图中没写),Stable Diffusion 的其他模块都大抵看了看,包括:
UNetModel
AutoEncoderKL & VQModelInterface (也是一种变分自动编码器,图上没画)
DDPM、DDIM、PLMS 算法
后面会简单先容一下,记录学习过程。
告急论文
在阅读代码的过程中,发现有些重量级的论文必须得阅读一下。扩散模型的理论推导还是有些复杂的,有时候公式推导和代码实现相互结合看,可以加深对知识的理解。这里列一下对我阅读代码有很大帮助的论文:
Denoising Diffusion Probabilistic Models : DDPM,这个是必看的,推推公式
Denoising Diffusion Implicit Models :DDIM,对 DDPM 的改进
Pseudo Numerical Methods for Diffusion Models on Manifolds :PNMD/PLMS,对 DDPM 的改进
High-Resolution Image Synthesis with Latent Diffusion Models :Latent-Diffusion,必看
Neural Discrete Representation Learning : VQVAE,简单翻了翻,示意图非常形象,很容易相识其做法
告急构成模块分析
下面临 Stable Diffusion 中的告急构成模块进行扼要分析。重要包罗:
UNetModel
DDPM、DDIM、PLMS 算法
AutoEncoderKL
对部分非主流的逻辑,如 predict_cids、return_ids 等谈谈看法
首先先容一下 UNetModel 布局,方便后续的文章直接进行引用。
UNetModel 先容
画了一下 Stable Diffusion 中使用的 UNetModel,就不分析代码了,看图很容易将代码写出来。Stable Diffusion 采用 UNetModel 这种 Encoder-Decoder 布局来实现扩散的过程,对噪声进行预估, 网络布局如下:
模型的输入包罗三个部分:
巨细为 [B, C, H, W] 的图像 image;
留意不用在意表现巨细时所用的符号,应将它们视作接口,比如 UNetModel 接收巨细为 [B, Z, H/8, W/8] 的 noise latent image 作为输入时,这里的 C 就即是 Z, H 就即是 H/8, W 就即是 W/8
;
巨细为 [B,] 的 timesteps
巨细为 [B, K, E] 的文本 embedding 表现 context, 其中 K 表现最大编码长度,E 表现 embedding 巨细
模型使用 DownSample 和 UpSample 来对样本进行下采样和上采样,此外出现最多的模块是 ResBlock 以及 SpatialTransformer,其中图中每一个 ResBlock 接收来自上一个模块的输入以及 timesteps 对应的 embedding timestep_emb (巨细为 [B, 4*M],M 是可设置的参数);而图中每一个 SpatialTransformer接收来自上一个模块的输入以及 context (Prompt 文本的 embedding 表现),使用 Cross Attention,以 context 为 condition,学习 Prompt 和图像的匹配关系。但图上只在虚线框中显示了两个模块有多个输入,其他模块没有画出来)
可以看到,最后模型的输出巨细为 [B, C, H, W], 和输入巨细雷同,也就是说 UNetModel 不改变输入输出的巨细。
下面再分别看看 ResBlock、timestep_embedding、context 以及 SpatialTransformer 的实现。
ResBlock 的实现
ResBlock 网络布局图如下,它接受两个输入,图像 x 以及 timestep 对应的 embedding:
timestep_embedding 实现
timestep_embedding 的天生方式如下,用的是 Tranformer(Attention is All you Need)这篇 paper 中的方法:
Prompt 文本 embedding 的实现
即 context 的实现。Prompt 使用 CLIP 模型进行编码,我没有对 CLIP 模型详细学习,暂时也没有深入看的筹划,后续有机会再补充;代码中使用预练习好的 CLIP 天生 context:
SpatialTransformer 的实现
最后再看下 SpatialTransformer 的实现,其模块比力多,在接收图像作为输入时,还使用 context 文本作为 condition 信息,二者使用 Cross Attention 进行建模。进一步展开 SpatialTransformer, 发现包罗 BasicTransformerBlock ,它实际调用 Cross Attention 模块,而在 Cross Attention 模块中,图像信息作为 Query,文本信息作为 Key & Value,模型会关注图像和文本各部分内容的相关性:
我觉得可以用一种朴素的想法来理解这里 Cross Attention 的作用,比如练习时给定一张马吃草的图,以及文本提示词:“一匹白色的马在沙漠吃草”,在做 Attention 时,文本中的 “马” 这个关键词和图像中的动物(也是 “马”)的关联性更强,因为权重也更大,而 “一匹”、 “白色”、“沙漠”、 “草” 等权重更低;此时,当模型被练习的很好后,模型不仅将可以学习到图像和文本之间的匹配关系,通过 Attention 还可以学习到文本中的各个关键词想突出图像中哪些主体。
而当我们输入提示词用模型来天生图像时,比如输入 “一匹马在吃草”,由于模型此时已经能捕捉图像和文本的相关性以及文本中的重点信息,当它看到文本 “马”,在黑盒魔法的运作下,会重点突出图像 “马” 的天生;当它看到 “草” 时,便重点突出图像 “草” 的天生,从而尽大概天生和文本进行匹配的图像。
至此,UNetModel 各个告急组件根本先容完毕。
小结
由于 UNetModel 模型布局并不复杂,看图根本就能写出代码,一图胜千言啊。另外我
标注了每个模块输出结果的巨细
,很方便在大脑中运行模型,哈哈哈。
本文大抵先容了一下 Stable Diffusion 文生图代码的整体框架,列出了扩散模型部分核心论文,扼要分析了 UNetModel。后续再分析其他核心组件。
发现 AIGC 的发展着实太快了,学不外来啊… 愈发觉得庄子所言甚是:以有涯随无涯,殆矣!
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/)
Powered by Discuz! X3.4