AIGC专栏3——Stable Diffusion结构剖析-以图像天生图像(图生图,img2img)为例

[复制链接]
发表于 2026-2-3 19:23:53 | 显示全部楼层 |阅读模式
学习前言

用了好久的Stable Diffusion,但从来没有好好剖析过它内部的结构,写个博客记载一下,嘿嘿。

源码下载所在

https://github.com/bubbliiiing/stable-diffusion
喜欢的可以点个star噢。
网络构建

一、什么是Stable Diffusion(SD)

Stable Diffusion是比力新的一个扩散模子,翻译过来是稳固扩散,固然名字叫稳固扩散,但实际上换个seed天生的结果就完全不一样,非常不稳固哈。
Stable Diffusion最开始的应用应该是文本天生图像,即文生图,随着技能的发展Stable Diffusion不光支持image2image图生图的天生,还支持ControlNet等各种控制方法来定制天生的图像。
Stable Diffusion基于扩散模子,以是难免包罗不停去噪的过程,假如是图生图的话,另有不停加噪的过程,此时离不开DDPM那张老图,如下:

Stable Diffusion相比于DDPM,利用了DDIM采样器,利用了隐空间的扩散,别的利用了非常大的LAION-5B数据集举行预训练。
直接Finetune Stable Diffusion大多数同砚应该是无法cover住本钱的,不外Stable Diffusion有许多轻量Finetune的方案,好比Lora、Textual Inversion等,但这是后话。
本文重要是剖析一下整个SD模子的结构构成,一次扩散,多次扩散的流程。
大模子、AIGC是当前行业的趋势,不会的话容易被镌汰,hh。
txt2img的原理如博文
Diffusion扩散模子学习2——Stable Diffusion结构剖析-以文本天生图像(txt2img)为例
所示。
二、Stable Diffusion的构成

Stable Diffusion由四大部门构成。
1、Sampler采样器。
2、Variational Autoencoder (VAE) 变分自编码器。
3、UNet 主网络,噪声推测器。
4、CLIPEmbedder文本编码器。
每一部门都很紧张,我们以图像天生图像为例举行剖析。既然是图像天生图像,那么我们的输入有两个,一个是文本,别的一个是图片。
三、img2img天生流程


天生流程分为四个部门:
1、对图片举行VAE编码,根据denoise数值举行加噪声。
2、Prompt文本编码。
3、根据denoise数值举行多少次采样。
4、利用VAE举行解码。
相比于文生图,图生图的输入发生了变革,不再以Gaussian noise作为初始化,而是以加噪后的图像特性为初始化如许便以图像的方式为模子注入了信息。
详细来讲,如上图所示:


  • 第一步为对输入的图像利用VAE编码,得到输入图像的Latent特性;然后利用该Latent特性基于DDIM Sampler举行加噪,此时得到输入图片加噪后的特性。假设我们设置denoise数值为0.8,总步数为20步,那么第一步中,我们会对输入图片举行0.8x20次的加噪声,剩下4步不加,可明白为打乱了80%的特性,保存20%的特性。
  • 第二步是对输入的文本举行编码,得到文本特性;
  • 第三步是根据denoise数值对 第一步中得到的 加噪后的特性 举行多少次采样。还是以第一步中denoise数值为0.8为例,我们只加了0.8x20次噪声那么我们也只必要举行0.8x20次采样就可以规复出图片了
  • 第四步是将采样后的图片利用VAE的Decoder举行规复。
  1. with torch.no_grad():
  2.     if seed == -1:
  3.         seed = random.randint(0, 65535)
  4.     seed_everything(seed)
  5.     # ----------------------- #
  6.     #   对输入图片进行编码并加噪
  7.     # ----------------------- #
  8.     if image_path is not None:
  9.         img = HWC3(np.array(img, np.uint8))
  10.         img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
  11.         img = torch.stack([img for _ in range(num_samples)], dim=0)
  12.         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
  13.         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
  14.         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
  15.         z = model.get_first_stage_encoding(model.encode_first_stage(img))
  16.         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
  17.     # ----------------------- #
  18.     #   获得编码后的prompt
  19.     # ----------------------- #
  20.     cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
  21.     un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
  22.     H, W    = input_shape
  23.     shape   = (4, H // 8, W // 8)
  24.     if image_path is not None:
  25.         samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
  26.     else:
  27.         # ----------------------- #
  28.         #   进行采样
  29.         # ----------------------- #
  30.         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
  31.                                                         shape, cond, verbose=False, eta=eta,
  32.                                                         unconditional_guidance_scale=scale,
  33.                                                         unconditional_conditioning=un_cond)
  34.     # ----------------------- #
  35.     #   进行解码
  36.     # ----------------------- #
  37.     x_samples = model.decode_first_stage(samples)
  38.     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
复制代码
1、输入图片编码


在图生图中,我们起主要指定一张参考的图像,然后在这个参考图像上开始工作:
1、利用VAE编码器对这张参考图像举行编码,使其进入隐空间,只有进入了隐空间,网络才知道这个图像是什么
2、然后利用该Latent特性基于DDIM Sampler举行加噪,此时得到输入图片加噪后的特性。加噪的逻辑如下:


  • denoise可以为是重修的比例,1代表全部重修,0代表不重修;
  • 假设我们设置denoise数值为0.8,总步数为20步;我们会对输入图片举行0.8x20次的加噪声,剩下4步不加,可明白为80%的特性,保存20%的特性;不外就算加完20步噪声原始输入图片的信息还是有一点保存的,不是完全不保存。
此时我们便得到在隐空间加噪后的图像,后续会在这个 隐空间加噪后的图像 的底子上举行采样。
  1. with torch.no_grad():
  2.     if seed == -1:
  3.         seed = random.randint(0, 65535)
  4.     seed_everything(seed)
  5.     # ----------------------- #
  6.     #   对输入图片进行编码并加噪
  7.     # ----------------------- #
  8.     if image_path is not None:
  9.         img = HWC3(np.array(img, np.uint8))
  10.         img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
  11.         img = torch.stack([img for _ in range(num_samples)], dim=0)
  12.         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
  13.         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
  14.         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
  15.         z = model.get_first_stage_encoding(model.encode_first_stage(img))
  16.         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
复制代码
2、文本编码


文本编码的思绪比力简朴,直接利用CLIP的文本编码器举行编码就可以了,在代码中界说了一个FrozenCLIPEmbedder种别,利用了transformers库的CLIPTokenizer和CLIPTextModel。
在前传过程中,我们对输入进来的文本起首利用CLIPTokenizer举行编码,然后利用CLIPTextModel举行特性提取,通过FrozenCLIPEmbedder,我们可以得到一个[batch_size, 77, 768]的特性向量。
  1. class FrozenCLIPEmbedder(AbstractEncoder):
  2.     """Uses the CLIP transformer encoder for text (from huggingface)"""
  3.     LAYERS = [
  4.         "last",
  5.         "pooled",
  6.         "hidden"
  7.     ]
  8.     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
  9.                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
  10.         super().__init__()
  11.         assert layer in self.LAYERS
  12.         # 定义文本的tokenizer和transformer
  13.         self.tokenizer      = CLIPTokenizer.from_pretrained(version)
  14.         self.transformer    = CLIPTextModel.from_pretrained(version)
  15.         self.device         = device
  16.         self.max_length     = max_length
  17.         # 冻结模型参数
  18.         if freeze:
  19.             self.freeze()
  20.         self.layer = layer
  21.         self.layer_idx = layer_idx
  22.         if layer == "hidden":
  23.             assert layer_idx is not None
  24.             assert 0 <= abs(layer_idx) <= 12
  25.     def freeze(self):
  26.         self.transformer = self.transformer.eval()
  27.         # self.train = disabled_train
  28.         for param in self.parameters():
  29.             param.requires_grad = False
  30.     def forward(self, text):
  31.         # 对输入的图片进行分词并编码,padding直接padding到77的长度。
  32.         batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
  33.                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
  34.         # 拿出input_ids然后传入transformer进行特征提取。
  35.         tokens          = batch_encoding["input_ids"].to(self.device)
  36.         outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
  37.         # 取出所有的token
  38.         if self.layer == "last":
  39.             z = outputs.last_hidden_state
  40.         elif self.layer == "pooled":
  41.             z = outputs.pooler_output[:, None, :]
  42.         else:
  43.             z = outputs.hidden_states[self.layer_idx]
  44.         return z
  45.     def encode(self, text):
  46.         return self(text)
复制代码
3、采样流程


a、天生初始噪声

在图生图中,我们的初始噪声获取于参考图片,以是参考第一步就可以得到图生图的噪声
b、对噪声举行N次采样

既然Stable Diffusion是一个不停扩散的过程,那么少不了不停的去噪声,那么怎么去噪声便是一个题目。
在上一步中,我们已经得到了一个图生图的噪声,它是一个符合正态分布的向量,我们便从它开始去噪声。
我们会对ddim_timesteps的时间步取反,由于我们现在是去噪声而非加噪声,然后对其举行一个循环,由于我们此时不再是txt2img中的采样流程,我们利用sampler的别的一个方法decode,循环的代码如下:
  1. @torch.no_grad()
  2. def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
  3.            use_original_steps=False):
  4.     # 使用ddim的时间步
  5.     # 这里内容看起来很多,但是其实很少,本质上就是取了self.ddim_timesteps,然后把它reversed一下
  6.     timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
  7.     timesteps = timesteps[:t_start]
  8.     time_range = np.flip(timesteps)
  9.     total_steps = timesteps.shape[0]
  10.     print(f"Running DDIM Sampling with {total_steps} timesteps")
  11.     iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
  12.     x_dec = x_latent
  13.     for i, step in enumerate(iterator):
  14.         index = total_steps - i - 1
  15.         ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
  16.         # 进行单次采样
  17.         x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
  18.                                       unconditional_guidance_scale=unconditional_guidance_scale,
  19.                                       unconditional_conditioning=unconditional_conditioning)
  20.     return x_dec
复制代码
c、单次采样剖析

I、推测噪声

在举行单词采样前,必要起首判断是否有neg prompt,假如有,我们必要同时处置处罚neg prompt,否则仅仅必要处置处罚pos prompt。实际利用的时间一样平常都有neg prompt(结果会好一些),以是默认进入对应的处置处罚过程。
在处置处罚neg prompt时,我们对输入进来的隐向量和步数举行复制,一个属于pos prompt,一个属于neg prompt。torch.cat默认堆叠维度为0,以是是在batch_size维度举行堆叠,二者不会相互影响。然后我们将pos prompt和neg prompt堆叠到一个batch中,也是在batch_size维度堆叠。
  1. # 首先判断是否由neg prompt,unconditional_conditioning是由neg prompt获得的
  2. if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
  3.     e_t = self.model.apply_model(x, t, c)
  4. else:
  5.     # 一般都是有neg prompt的,所以进入到这里
  6.     # 在这里我们对隐向量和步数进行复制,一个属于pos prompt,一个属于neg prompt
  7.     # torch.cat默认堆叠维度为0,所以是在bs维度进行堆叠,二者不会互相影响
  8.     x_in = torch.cat([x] * 2)
  9.     t_in = torch.cat([t] * 2)
  10.     # 然后我们将pos prompt和neg prompt堆叠到一个batch中
  11.     if isinstance(c, dict):
  12.         assert isinstance(unconditional_conditioning, dict)
  13.         c_in = dict()
  14.         for k in c:
  15.             if isinstance(c[k], list):
  16.                 c_in[k] = [
  17.                     torch.cat([unconditional_conditioning[k][i], c[k][i]])
  18.                     for i in range(len(c[k]))
  19.                 ]
  20.             else:
  21.                 c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
  22.     else:
  23.         c_in = torch.cat([unconditional_conditioning, c])
复制代码

堆叠完后,我们将隐向量、步数和prompt条件一起传入网络中,将结果在bs维度举行利用chunk举行分割。
由于我们在堆叠时,neg prompt放在了前面。因此分割好后,前半部门e_t_uncond属于利用neg prompt得到的,后半部门e_t属于利用pos prompt得到的,我们本质上应该扩大pos prompt的影响,阔别neg prompt的影响。因此,我们利用e_t-e_t_uncond盘算二者的间隔,利用scale扩大二者的间隔。在e_t_uncond底子上,得到末了的隐向量。
  1. # 堆叠完后,隐向量、步数和prompt条件一起传入网络中,将结果在bs维度进行使用chunk进行分割
  2. e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
  3. e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
复制代码

此时得到的e_t就是通过隐向量和prompt共同得到的推测噪声啦。
II、施加噪声

得到噪声就OK了吗?显然不是的,我们还要将得到的新噪声,按照肯定的比例添加到原来的原始噪声上。
这个地方我们最好团结ddim中的公式来看,我们必要得到                                                 α                            ˉ                                  t                                  \bar{\alpha}_t               αˉt​、                                                 α                            ˉ                                            t                            −                            1                                           \bar{\alpha}_{t-1}               αˉt−1​、                                       σ                         t                                  \sigma_t               σt​、                                                 1                            −                                                   α                                  ˉ                                          t                                                     \sqrt{1-\bar{\alpha}_t}               1−αˉt​           ​。


代码中,我们实在已经预先盘算好了这些参数。我们只必要直接取出即可,下方的a_t也就是公式中括号外的                                                 α                            ˉ                                  t                                  \bar{\alpha}_t               αˉt​,a_prev 就是公式中的                                                 α                            ˉ                                            t                            −                            1                                           \bar{\alpha}_{t-1}               αˉt−1​,sigma_t就是公式中的                                       σ                         t                                  \sigma_t               σt​,sqrt_one_minus_at就是公式中的                                                 1                            −                                                   α                                  ˉ                                          t                                                     \sqrt{1-\bar{\alpha}_t}               1−αˉt​           ​。
  1. # 根据采样器选择参数
  2. alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
  3. alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
  4. sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
  5. sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
  6. # 根据步数选择参数,
  7. # 这里的index就是上面循环中的total_steps - i - 1
  8. a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
  9. a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
  10. sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
  11. sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
复制代码
实在这一步我们只是把公式必要用到的系数全都拿了出来,方便背面的加减乘除。然后我们便在代码中实现上述的公式。
  1. # current prediction for x_0
  2. # 公式中的最左边
  3. pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  4. if quantize_denoised:
  5.     pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
  6. # direction pointing to x_t
  7. # 公式的中间
  8. dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
  9. # 公式最右边
  10. noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
  11. if noise_dropout > 0.:
  12.     noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
  13. x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
  14. # 输出添加完公式的结果
  15. return x_prev, pred_x0
复制代码

d、推测噪声过程中的网络结构剖析

I、apply_model方法剖析

在3.a的推测噪声过程中,我们利用了model.apply_model方法举行噪声的推测,这个方法详细做了什么被隐掉了,我们看看详细做的工作。
apply_model方法在ldm.models.diffusion.ddpm.py文件中。在apply_model中,我们将x_noisy传入self.model中推测噪声。
  1. x_recon = self.model(x_noisy, t, **cond)
复制代码

self.model是一个预先构建好的类,界说在ldm.models.diffusion.ddpm.py文件的1416行,内部包罗Stable Diffusion的Unet网络,self.model的功能有点雷同于包装器,根据模子选择的特性融合方式,举行文本与上文天生的噪声的融合。
c_concat代表利用堆叠的方式举行融合,c_crossattn代表利用attention的方式融合。
  1. class DiffusionWrapper(pl.LightningModule):
  2.     def __init__(self, diff_model_config, conditioning_key):
  3.         super().__init__()
  4.         self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
  5.         # stable diffusion的unet网络
  6.         self.diffusion_model = instantiate_from_config(diff_model_config)
  7.         self.conditioning_key = conditioning_key
  8.         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
  9.     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
  10.         if self.conditioning_key is None:
  11.             out = self.diffusion_model(x, t)
  12.         elif self.conditioning_key == 'concat':
  13.             xc = torch.cat([x] + c_concat, dim=1)
  14.             out = self.diffusion_model(xc, t)
  15.         elif self.conditioning_key == 'crossattn':
  16.             if not self.sequential_cross_attn:
  17.                 cc = torch.cat(c_crossattn, 1)
  18.             else:
  19.                 cc = c_crossattn
  20.             out = self.diffusion_model(x, t, context=cc)
  21.         elif self.conditioning_key == 'hybrid':
  22.             xc = torch.cat([x] + c_concat, dim=1)
  23.             cc = torch.cat(c_crossattn, 1)
  24.             out = self.diffusion_model(xc, t, context=cc)
  25.         elif self.conditioning_key == 'hybrid-adm':
  26.             assert c_adm is not None
  27.             xc = torch.cat([x] + c_concat, dim=1)
  28.             cc = torch.cat(c_crossattn, 1)
  29.             out = self.diffusion_model(xc, t, context=cc, y=c_adm)
  30.         elif self.conditioning_key == 'crossattn-adm':
  31.             assert c_adm is not None
  32.             cc = torch.cat(c_crossattn, 1)
  33.             out = self.diffusion_model(x, t, context=cc, y=c_adm)
  34.         elif self.conditioning_key == 'adm':
  35.             cc = c_crossattn[0]
  36.             out = self.diffusion_model(x, t, y=cc)
  37.         else:
  38.             raise NotImplementedError()
  39.         return out
复制代码

代码中的self.diffusion_model便是Stable Diffusion的Unet网络,网络结构位于ldm.modules.diffusionmodules.openaimodel.py文件中的UNetModel类。
II、UNetModel模子剖析

UNetModel重要做的工作是结适时间步t和文本Embedding盘算这一时间的噪声。只管UNet的思绪非常简朴,但是在StableDiffusion中,UNetModel由ResBlock和Transformer模块构成,团体来讲相比于平凡的UNet复杂一些。
Prompt通过Frozen CLIP Text Encoder得到Text Embedding,Timesteps通过全毗连(MLP)得到Timesteps Embedding;
ResBlock用于结适时间步Timesteps Embedding,Transformer模块用于团结文本Text Embedding。
我在这里放一张大图,同砚们可以看到内部shape的变革。

Unet代码如下所示:
  1. class UNetModel(nn.Module):
  2.     """
  3.     The full UNet model with attention and timestep embedding.
  4.     :param in_channels: channels in the input Tensor.
  5.     :param model_channels: base channel count for the model.
  6.     :param out_channels: channels in the output Tensor.
  7.     :param num_res_blocks: number of residual blocks per downsample.
  8.     :param attention_resolutions: a collection of downsample rates at which
  9.         attention will take place. May be a set, list, or tuple.
  10.         For example, if this contains 4, then at 4x downsampling, attention
  11.         will be used.
  12.     :param dropout: the dropout probability.
  13.     :param channel_mult: channel multiplier for each level of the UNet.
  14.     :param conv_resample: if True, use learned convolutions for upsampling and
  15.         downsampling.
  16.     :param dims: determines if the signal is 1D, 2D, or 3D.
  17.     :param num_classes: if specified (as an int), then this model will be
  18.         class-conditional with `num_classes` classes.
  19.     :param use_checkpoint: use gradient checkpointing to reduce memory usage.
  20.     :param num_heads: the number of attention heads in each attention layer.
  21.     :param num_heads_channels: if specified, ignore num_heads and instead use
  22.                                a fixed channel width per attention head.
  23.     :param num_heads_upsample: works with num_heads to set a different number
  24.                                of heads for upsampling. Deprecated.
  25.     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
  26.     :param resblock_updown: use residual blocks for up/downsampling.
  27.     :param use_new_attention_order: use a different attention pattern for potentially
  28.                                     increased efficiency.
  29.     """
  30.     def __init__(
  31.         self,
  32.         image_size,
  33.         in_channels,
  34.         model_channels,
  35.         out_channels,
  36.         num_res_blocks,
  37.         attention_resolutions,
  38.         dropout=0,
  39.         channel_mult=(1, 2, 4, 8),
  40.         conv_resample=True,
  41.         dims=2,
  42.         num_classes=None,
  43.         use_checkpoint=False,
  44.         use_fp16=False,
  45.         num_heads=-1,
  46.         num_head_channels=-1,
  47.         num_heads_upsample=-1,
  48.         use_scale_shift_norm=False,
  49.         resblock_updown=False,
  50.         use_new_attention_order=False,
  51.         use_spatial_transformer=False,    # custom transformer support
  52.         transformer_depth=1,              # custom transformer support
  53.         context_dim=None,                 # custom transformer support
  54.         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
  55.         legacy=True,
  56.     ):
  57.         super().__init__()
  58.         if use_spatial_transformer:
  59.             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
  60.         if context_dim is not None:
  61.             assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
  62.             from omegaconf.listconfig import ListConfig
  63.             if type(context_dim) == ListConfig:
  64.                 context_dim = list(context_dim)
  65.         if num_heads_upsample == -1:
  66.             num_heads_upsample = num_heads
  67.         if num_heads == -1:
  68.             assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
  69.         if num_head_channels == -1:
  70.             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
  71.         self.image_size = image_size
  72.         self.in_channels = in_channels
  73.         self.model_channels = model_channels
  74.         self.out_channels = out_channels
  75.         self.num_res_blocks = num_res_blocks
  76.         self.attention_resolutions = attention_resolutions
  77.         self.dropout = dropout
  78.         self.channel_mult = channel_mult
  79.         self.conv_resample = conv_resample
  80.         self.num_classes = num_classes
  81.         self.use_checkpoint = use_checkpoint
  82.         self.dtype = th.float16 if use_fp16 else th.float32
  83.         self.num_heads = num_heads
  84.         self.num_head_channels = num_head_channels
  85.         self.num_heads_upsample = num_heads_upsample
  86.         self.predict_codebook_ids = n_embed is not None
  87.         # 用于计算当前采样时间t的embedding
  88.         time_embed_dim  = model_channels * 4
  89.         self.time_embed = nn.Sequential(
  90.             linear(model_channels, time_embed_dim),
  91.             nn.SiLU(),
  92.             linear(time_embed_dim, time_embed_dim),
  93.         )
  94.         if self.num_classes is not None:
  95.             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
  96.         
  97.         # 定义输入模块的第一个卷积
  98.         # TimestepEmbedSequential也可以看作一个包装器,根据层的种类进行时间或者文本的融合。
  99.         self.input_blocks = nn.ModuleList(
  100.             [
  101.                 TimestepEmbedSequential(
  102.                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
  103.                 )
  104.             ]
  105.         )
  106.         self._feature_size  = model_channels
  107.         input_block_chans   = [model_channels]
  108.         ch                  = model_channels
  109.         ds                  = 1
  110.         # 对channel_mult进行循环,channel_mult一共有四个值,代表unet四个部分通道的扩张比例
  111.         # [1, 2, 4, 4]
  112.         for level, mult in enumerate(channel_mult):
  113.             # 每个部分循环两次
  114.             # 添加一个ResBlock和一个AttentionBlock
  115.             for _ in range(num_res_blocks):
  116.                 # 先添加一个ResBlock
  117.                 # 用于对输入的噪声进行通道数的调整,并且融合t的特征
  118.                 layers = [
  119.                     ResBlock(
  120.                         ch,
  121.                         time_embed_dim,
  122.                         dropout,
  123.                         out_channels=mult * model_channels,
  124.                         dims=dims,
  125.                         use_checkpoint=use_checkpoint,
  126.                         use_scale_shift_norm=use_scale_shift_norm,
  127.                     )
  128.                 ]
  129.                 # ch便是上述ResBlock的输出通道数
  130.                 ch = mult * model_channels
  131.                 if ds in attention_resolutions:
  132.                     # num_heads=8
  133.                     if num_head_channels == -1:
  134.                         dim_head = ch // num_heads
  135.                     else:
  136.                         num_heads = ch // num_head_channels
  137.                         dim_head = num_head_channels
  138.                     if legacy:
  139.                         #num_heads = 1
  140.                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
  141.                     # 使用了SpatialTransformer自注意力,加强全局特征,融合文本的特征
  142.                     layers.append(
  143.                         AttentionBlock(
  144.                             ch,
  145.                             use_checkpoint=use_checkpoint,
  146.                             num_heads=num_heads,
  147.                             num_head_channels=dim_head,
  148.                             use_new_attention_order=use_new_attention_order,
  149.                         ) if not use_spatial_transformer else SpatialTransformer(
  150.                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
  151.                         )
  152.                     )
  153.                 self.input_blocks.append(TimestepEmbedSequential(*layers))
  154.                 self._feature_size += ch
  155.                 input_block_chans.append(ch)
  156.             # 如果不是四个部分中的最后一个部分,那么都要进行下采样。
  157.             if level != len(channel_mult) - 1:
  158.                 out_ch = ch
  159.                 # 在此处进行下采样
  160.                 # 一般直接使用Downsample模块
  161.                 self.input_blocks.append(
  162.                     TimestepEmbedSequential(
  163.                         ResBlock(
  164.                             ch,
  165.                             time_embed_dim,
  166.                             dropout,
  167.                             out_channels=out_ch,
  168.                             dims=dims,
  169.                             use_checkpoint=use_checkpoint,
  170.                             use_scale_shift_norm=use_scale_shift_norm,
  171.                             down=True,
  172.                         )
  173.                         if resblock_updown
  174.                         else Downsample(
  175.                             ch, conv_resample, dims=dims, out_channels=out_ch
  176.                         )
  177.                     )
  178.                 )
  179.                 # 为下一阶段定义参数。
  180.                 ch = out_ch
  181.                 input_block_chans.append(ch)
  182.                 ds *= 2
  183.                 self._feature_size += ch
  184.         if num_head_channels == -1:
  185.             dim_head = ch // num_heads
  186.         else:
  187.             num_heads = ch // num_head_channels
  188.             dim_head = num_head_channels
  189.         if legacy:
  190.             #num_heads = 1
  191.             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
  192.         # 定义中间层
  193.         # ResBlock + SpatialTransformer + ResBlock
  194.         self.middle_block = TimestepEmbedSequential(
  195.             ResBlock(
  196.                 ch,
  197.                 time_embed_dim,
  198.                 dropout,
  199.                 dims=dims,
  200.                 use_checkpoint=use_checkpoint,
  201.                 use_scale_shift_norm=use_scale_shift_norm,
  202.             ),
  203.             AttentionBlock(
  204.                 ch,
  205.                 use_checkpoint=use_checkpoint,
  206.                 num_heads=num_heads,
  207.                 num_head_channels=dim_head,
  208.                 use_new_attention_order=use_new_attention_order,
  209.             ) if not use_spatial_transformer else SpatialTransformer(
  210.                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
  211.                         ),
  212.             ResBlock(
  213.                 ch,
  214.                 time_embed_dim,
  215.                 dropout,
  216.                 dims=dims,
  217.                 use_checkpoint=use_checkpoint,
  218.                 use_scale_shift_norm=use_scale_shift_norm,
  219.             ),
  220.         )
  221.         self._feature_size += ch
  222.         # 定义Unet上采样过程
  223.         self.output_blocks = nn.ModuleList([])
  224.         # 循环把channel_mult反了过来
  225.         for level, mult in list(enumerate(channel_mult))[::-1]:
  226.             # 上采样时每个部分循环三次
  227.             for i in range(num_res_blocks + 1):
  228.                 ich = input_block_chans.pop()
  229.                 # 首先添加ResBlock层
  230.                 layers = [
  231.                     ResBlock(
  232.                         ch + ich,
  233.                         time_embed_dim,
  234.                         dropout,
  235.                         out_channels=model_channels * mult,
  236.                         dims=dims,
  237.                         use_checkpoint=use_checkpoint,
  238.                         use_scale_shift_norm=use_scale_shift_norm,
  239.                     )
  240.                 ]
  241.                 ch = model_channels * mult
  242.                 # 然后进行SpatialTransformer自注意力
  243.                 if ds in attention_resolutions:
  244.                     if num_head_channels == -1:
  245.                         dim_head = ch // num_heads
  246.                     else:
  247.                         num_heads = ch // num_head_channels
  248.                         dim_head = num_head_channels
  249.                     if legacy:
  250.                         #num_heads = 1
  251.                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
  252.                     layers.append(
  253.                         AttentionBlock(
  254.                             ch,
  255.                             use_checkpoint=use_checkpoint,
  256.                             num_heads=num_heads_upsample,
  257.                             num_head_channels=dim_head,
  258.                             use_new_attention_order=use_new_attention_order,
  259.                         ) if not use_spatial_transformer else SpatialTransformer(
  260.                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
  261.                         )
  262.                     )
  263.                 # 如果不是channel_mult循环的第一个
  264.                 # 且
  265.                 # 是num_res_blocks循环的最后一次,则进行上采样
  266.                 if level and i == num_res_blocks:
  267.                     out_ch = ch
  268.                     layers.append(
  269.                         ResBlock(
  270.                             ch,
  271.                             time_embed_dim,
  272.                             dropout,
  273.                             out_channels=out_ch,
  274.                             dims=dims,
  275.                             use_checkpoint=use_checkpoint,
  276.                             use_scale_shift_norm=use_scale_shift_norm,
  277.                             up=True,
  278.                         )
  279.                         if resblock_updown
  280.                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
  281.                     )
  282.                     ds //= 2
  283.                 self.output_blocks.append(TimestepEmbedSequential(*layers))
  284.                 self._feature_size += ch
  285.         # 最后在输出部分进行一次卷积
  286.         self.out = nn.Sequential(
  287.             normalization(ch),
  288.             nn.SiLU(),
  289.             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
  290.         )
  291.         if self.predict_codebook_ids:
  292.             self.id_predictor = nn.Sequential(
  293.             normalization(ch),
  294.             conv_nd(dims, model_channels, n_embed, 1),
  295.             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
  296.         )
  297.     def convert_to_fp16(self):
  298.         """
  299.         Convert the torso of the model to float16.
  300.         """
  301.         self.input_blocks.apply(convert_module_to_f16)
  302.         self.middle_block.apply(convert_module_to_f16)
  303.         self.output_blocks.apply(convert_module_to_f16)
  304.     def convert_to_fp32(self):
  305.         """
  306.         Convert the torso of the model to float32.
  307.         """
  308.         self.input_blocks.apply(convert_module_to_f32)
  309.         self.middle_block.apply(convert_module_to_f32)
  310.         self.output_blocks.apply(convert_module_to_f32)
  311.     def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
  312.         """
  313.         Apply the model to an input batch.
  314.         :param x: an [N x C x ...] Tensor of inputs.
  315.         :param timesteps: a 1-D batch of timesteps.
  316.         :param context: conditioning plugged in via crossattn
  317.         :param y: an [N] Tensor of labels, if class-conditional.
  318.         :return: an [N x C x ...] Tensor of outputs.
  319.         """
  320.         assert (y is not None) == (
  321.             self.num_classes is not None
  322.         ), "must specify y if and only if the model is class-conditional"
  323.         hs      = []
  324.         # 用于计算当前采样时间t的embedding
  325.         t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
  326.         emb     = self.time_embed(t_emb)
  327.         if self.num_classes is not None:
  328.             assert y.shape == (x.shape[0],)
  329.             emb = emb + self.label_emb(y)
  330.         # 对输入模块进行循环,进行下采样并且融合时间特征与文本特征。
  331.         h = x.type(self.dtype)
  332.         for module in self.input_blocks:
  333.             h = module(h, emb, context)
  334.             hs.append(h)
  335.         # 中间模块的特征提取
  336.         h = self.middle_block(h, emb, context)
  337.         # 上采样模块的特征提取
  338.         for module in self.output_blocks:
  339.             h = th.cat([h, hs.pop()], dim=1)
  340.             h = module(h, emb, context)
  341.         h = h.type(x.dtype)
  342.         # 输出模块
  343.         if self.predict_codebook_ids:
  344.             return self.id_predictor(h)
  345.         else:
  346.             return self.out(h)
复制代码
4、隐空间解码天生图片


通过上述步调,已经可以多次采样得到结果,然后我们便可以通过隐空间解码天生图片。
隐空间解码天生图片的过程非常简朴,将上文多次采样后的结果,利用decode_first_stage方法即可天生图片。
在decode_first_stage方法中,网络调用VAE对获取到的64x64x3的隐向量举行解码,得到512x512x3的图片。
  1. @torch.no_grad()
  2. def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
  3.     if predict_cids:
  4.         if z.dim() == 4:
  5.             z = torch.argmax(z.exp(), dim=1).long()
  6.         z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
  7.         z = rearrange(z, 'b h w c -> b c h w').contiguous()
  8.     z = 1. / self.scale_factor * z
  9.         # 一般无需分割输入,所以直接将x_noisy传入self.model中,在下面else进行
  10.     if hasattr(self, "split_input_params"):
  11.             ......
  12.     else:
  13.         if isinstance(self.first_stage_model, VQModelInterface):
  14.             return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
  15.         else:
  16.             return self.first_stage_model.decode(z)
复制代码
图像到图像推测过程代码

团体推测代码如下:
  1. import os
  2. import random
  3. import cv2
  4. import einops
  5. import numpy as np
  6. import torch
  7. from PIL import Image
  8. from pytorch_lightning import seed_everything
  9. from ldm_hacked import *
  10. # ----------------------- #
  11. #   使用的参数
  12. # ----------------------- #
  13. # config的地址
  14. config_path = "model_data/sd_v15.yaml"
  15. # 模型的地址
  16. model_path  = "model_data/v1-5-pruned-emaonly.safetensors"
  17. # fp16,可以加速与节省显存
  18. sd_fp16     = True
  19. vae_fp16    = True
  20. # ----------------------- #
  21. #   生成图片的参数
  22. # ----------------------- #
  23. # 生成的图像大小为input_shape,对于img2img会进行Centter Crop
  24. input_shape = [512, 512]
  25. # 一次生成几张图像
  26. num_samples = 1
  27. # 采样的步数
  28. ddim_steps  = 20
  29. # 采样的种子,为-1的话则随机。
  30. seed        = 12345
  31. # eta
  32. eta         = 0
  33. # denoise强度,for img2img
  34. denoise_strength = 1.0
  35. # ----------------------- #
  36. #   提示词相关参数
  37. # ----------------------- #
  38. # 提示词
  39. prompt      = "a cute cat, with yellow leaf, trees"
  40. # 正面提示词
  41. a_prompt    = "best quality, extremely detailed"
  42. # 负面提示词
  43. n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
  44. # 正负扩大倍数
  45. scale       = 9
  46. # img2img使用,如果不想img2img这设置为None。
  47. image_path  = None
  48. # ----------------------- #
  49. #   保存路径
  50. # ----------------------- #
  51. save_path   = "imgs/outputs_imgs"
  52. # ----------------------- #
  53. #   创建模型
  54. # ----------------------- #
  55. model   = create_model(config_path).cpu()
  56. model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
  57. model   = model.cuda()
  58. ddim_sampler = DDIMSampler(model)
  59. if sd_fp16:
  60.     model = model.half()
  61. if image_path is not None:
  62.     img = Image.open(image_path)
  63.     img = crop_and_resize(img, input_shape[0], input_shape[1])
  64. with torch.no_grad():
  65.     if seed == -1:
  66.         seed = random.randint(0, 65535)
  67.     seed_everything(seed)
  68.    
  69.     # ----------------------- #
  70.     #   对输入图片进行编码并加噪
  71.     # ----------------------- #
  72.     if image_path is not None:
  73.         img = HWC3(np.array(img, np.uint8))
  74.         img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
  75.         img = torch.stack([img for _ in range(num_samples)], dim=0)
  76.         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
  77.         if vae_fp16:
  78.             img = img.half()
  79.             model.first_stage_model = model.first_stage_model.half()
  80.         else:
  81.             model.first_stage_model = model.first_stage_model.float()
  82.         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
  83.         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
  84.         z = model.get_first_stage_encoding(model.encode_first_stage(img))
  85.         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
  86.         z_enc = z_enc.half() if sd_fp16 else z_enc.float()
  87.     # ----------------------- #
  88.     #   获得编码后的prompt
  89.     # ----------------------- #
  90.     cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
  91.     un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
  92.     H, W    = input_shape
  93.     shape   = (4, H // 8, W // 8)
  94.     if image_path is not None:
  95.         samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
  96.     else:
  97.         # ----------------------- #
  98.         #   进行采样
  99.         # ----------------------- #
  100.         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
  101.                                                         shape, cond, verbose=False, eta=eta,
  102.                                                         unconditional_guidance_scale=scale,
  103.                                                         unconditional_conditioning=un_cond)
  104.     # ----------------------- #
  105.     #   进行解码
  106.     # ----------------------- #
  107.     x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())
  108.     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
  109. # ----------------------- #
  110. #   保存图片
  111. # ----------------------- #
  112. if not os.path.exists(save_path):
  113.     os.makedirs(save_path)
  114. for index, image in enumerate(x_samples):
  115.     cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!qidao123.com:ToB企服之家,中国第一个企服评测及软件市场,开放入驻,技术点评得现金

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表