
Author: tsx81428    Time: 2025-4-3 08:15
Title: A 70,000-Character Full Breakdown of the Stable Diffusion Mechanism, with Code-Level Analysis of the Engineering Details
About this article:

        The goal of this article is to follow the pipeline and explain, in order, the principle and purpose of each module and, more importantly, all the details of the engineering implementation. Through this step-by-step walkthrough, I hope to improve the understanding of how the individual modules of SD work internally and how they cooperate. The article therefore does not spend much time deriving the formulas from the papers; instead it focuses on how those formulas are actually turned into working code.
        Note that the code used throughout this article comes from https://github.com/hkproj/pytorch-stable-diffusion.git . Since narrating the code flow inevitably involves complicated call relationships, it is best to download the code and read along so you do not get lost.
        I also recommend hkproj (Umar Jamil) and his video Coding Stable Diffusion from scratch in PyTorch; although it runs about five hours, the overall line of reasoning is very clear.
Overall framework:

        Generally speaking, starting from the overall framework and then working through the modules one by one keeps the train of thought clear. Below, the text-to-image and image-to-image cases are used as illustrations (the two are essentially the same).
        Text-to-image:


        Image-to-image:



Theory overview:

        This part gives a brief description of what each module in the framework does.
CLIP:

        We start from the part in the red box in the figure below, i.e. the CLIP module.

        CLIP (Contrastive Language-Image Pre-training) is a model released by OpenAI in 2021. Its role here is to encode the text prompt into a prompt embedding.
        The motivation behind it: in image classification, for example, a convolutional network needs a large labeled dataset, which is often hard to obtain. CLIP's dataset is instead collected directly from the internet, so a huge number of text-image pairs can be gathered at very low cost. This kind of labeling is called natural language supervision.
        During training, the image encoder encodes a batch of images into the embedding space to obtain the I vectors, while the text encoder encodes the corresponding captions into the embedding space to obtain the T vectors. Cosine similarity is then computed between the T and I vectors. We want the latent representation of each image to be most similar to that of its own caption and less similar to the captions of the other images; as the figure shows, the values on the diagonal should be as large as possible and the off-diagonal values as small as possible.
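To make the "diagonal large, off-diagonal small" objective concrete, here is a minimal sketch of the symmetric contrastive loss, assuming two encoders that output L2-normalized embeddings. This is illustrative only and is not code from this repository:

import torch
import torch.nn.functional as F

# Minimal sketch of CLIP's symmetric contrastive loss (illustrative only).
# `image_emb` and `text_emb` are assumed to be L2-normalized embeddings of
# shape (N, D) produced by the image encoder and text encoder for N pairs.
def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    logits = image_emb @ text_emb.t() / temperature   # (N, N) cosine-similarity matrix
    targets = torch.arange(logits.shape[0])           # matching pairs sit on the diagonal
    loss_img = F.cross_entropy(logits, targets)       # image -> text direction
    loss_txt = F.cross_entropy(logits.t(), targets)   # text -> image direction
    return (loss_img + loss_txt) / 2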

        Since how the prompt embedding is obtained is not a core concern of SD, it will not be discussed at length here; you can get a glimpse of it in the code section.
VAE_Encoder:


        This step embodies SD's core improvement. When the image resolution is large, running every UNET computation at the original size is very expensive, so the image has to be compressed somehow. SD uses a VAE for this, ensuring that what the UNET operates on is a small latent representation.
        For image-to-image, the VAE encoder first encodes the image into a latent representation, and then noise is added to that latent. The strength of this added noise matters a great deal: since generation is a step-by-step denoising process, adding only a little noise means the model has only a little noise to remove, i.e. fewer degrees of freedom during denoising. In other words, the noise strength here acts as the flexibility of SD's generation: weak noise means we want the model to follow the original image closely, while strong noise gives the model more room to improvise.
        For text-to-image, there is no original image, so randomly sampled noise is used directly as the latent representation.

Diffusion:

        The core SD part in the middle, which can be further split into the UNET in the lower half and the scheduler in the upper half. The UNET takes the latent representation, the prompt embedding, and the current timestep, and outputs its prediction of the noise that should be removed at this timestep in order to move the latent toward the prompt. The scheduler then removes this predicted noise from the latent, producing a less noisy version. The two cooperate while iterating over the timesteps; once all timesteps are done, we obtain what the model considers a noise-free latent representation.

VAE_Decoder:

        It maps the latent that is finally predicted to be noise-free from the latent space back to the original pixel space, producing the image corresponding to that latent.

Pipeline code walkthrough:

        Below I take the construction of the pipeline as the main thread and walk through each step one by one. Here is the complete pipeline code:
  1. import torch
  2. import numpy as np
  3. from tqdm import tqdm
  4. from ddpm import DDPMSampler
  5. WIDTH = 512
  6. HEIGHT = 512
  7. LATENTS_WIDTH = WIDTH // 8
  8. LATENTS_HEIGHT = HEIGHT // 8
  9. def generate(
  10.     prompt,
  11.     uncond_prompt=None,
  12.     input_image=None,
  13.     strength=0.8,
  14.     do_cfg=True,
  15.     cfg_scale=7.5,
  16.     sampler_name="ddpm",
  17.     n_inference_steps=50,
  18.     models={},
  19.     seed=None,
  20.     device=None,
  21.     idle_device=None,
  22.     tokenizer=None,
  23. ):
  24.     with torch.no_grad():
  25.         if not 0 < strength <= 1:
  26.             raise ValueError("strength must be between 0 and 1")
  27.         if idle_device:
  28.             to_idle = lambda x: x.to(idle_device)
  29.         else:
  30.             to_idle = lambda x: x
  31.         # Initialize random number generator according to the seed specified
  32.         generator = torch.Generator(device=device)
  33.         if seed is None:
  34.             generator.seed()
  35.         else:
  36.             generator.manual_seed(seed)
  37.         clip = models["clip"]
  38.         clip.to(device)
  39.         
  40.         if do_cfg:
  41.             # Convert into a list of length Seq_Len=77
  42.             cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  43.             # (Batch_Size, Seq_Len)
  44.             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
  45.             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  46.             cond_context = clip(cond_tokens)
  47.             # Convert into a list of length Seq_Len=77
  48.             uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
  49.             # (Batch_Size, Seq_Len)
  50.             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
  51.             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  52.             uncond_context = clip(uncond_tokens)
  53.             # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
  54.             context = torch.cat([cond_context, uncond_context])
  55.         else:
  56.             # Convert into a list of length Seq_Len=77
  57.             tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  58.             # (Batch_Size, Seq_Len)
  59.             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
  60.             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  61.             context = clip(tokens)
  62.         to_idle(clip)
  63.         if sampler_name == "ddpm":
  64.             sampler = DDPMSampler(generator)
  65.             sampler.set_inference_timesteps(n_inference_steps)
  66.         else:
  67.             raise ValueError("Unknown sampler value %s. ")
  68.         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
  69.         if input_image:
  70.             encoder = models["encoder"]
  71.             encoder.to(device)
  72.             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
  73.             input_image_tensor = np.array(input_image_tensor)
  74.             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
  75.             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
  76.             # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
  77.             input_image_tensor = input_image_tensor.unsqueeze(0)
  78.             # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
  79.             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
  80.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  81.             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
  82.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  83.             latents = encoder(input_image_tensor, encoder_noise)
  84.             # Add noise to the latents (the encoded input image)
  85.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  86.             sampler.set_strength(strength=strength)
  87.             latents = sampler.add_noise(latents, sampler.timesteps[0])
  88.             to_idle(encoder)
  89.         else:
  90.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  91.             latents = torch.randn(latents_shape, generator=generator, device=device)
  92.         diffusion = models["diffusion"]
  93.         diffusion.to(device)
  94.         timesteps = tqdm(sampler.timesteps)
  95.         for i, timestep in enumerate(timesteps):
  96.             # (1, 320)
  97.             time_embedding = get_time_embedding(timestep).to(device)
  98.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  99.             model_input = latents
  100.             if do_cfg:
  101.                 # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
  102.                 model_input = model_input.repeat(2, 1, 1, 1)
  103.             # model_output is the predicted noise
  104.             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
  105.             model_output = diffusion(model_input, context, time_embedding)
  106.             if do_cfg:
  107.                 output_cond, output_uncond = model_output.chunk(2)
  108.                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
  109.             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
  110.             latents = sampler.step(timestep, latents, model_output)
  111.         to_idle(diffusion)
  112.         decoder = models["decoder"]
  113.         decoder.to(device)
  114.         # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
  115.         images = decoder(latents)
  116.         to_idle(decoder)
  117.         images = rescale(images, (-1, 1), (0, 255), clamp=True)
  118.         # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
  119.         images = images.permute(0, 2, 3, 1)
  120.         images = images.to("cpu", torch.uint8).numpy()
  121.         return images[0]
  122.    
  123. def rescale(x, old_range, new_range, clamp=False):
  124.     old_min, old_max = old_range
  125.     new_min, new_max = new_range
  126.     x -= old_min
  127.     x *= (new_max - new_min) / (old_max - old_min)
  128.     x += new_min
  129.     if clamp:
  130.         x = x.clamp(new_min, new_max)
  131.     return x
  132. def get_time_embedding(timestep):
  133.     # Shape: (160,)
  134.     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
  135.     # Shape: (1, 160)
  136.     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
  137.     # Shape: (1, 160 * 2)
  138.     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
Let us take it apart piece by piece. First, the original images we handle are 512*512, and the VAE maps them into the latent space, where the size is compressed to 64*64.
  1. WIDTH = 512
  2. HEIGHT = 512
  3. LATENTS_WIDTH = WIDTH // 8
  4. LATENTS_HEIGHT = HEIGHT // 8
The generate function then takes the following input parameters; each will be explained in detail when we encounter it:
  1. def generate(
  2.     prompt,
  3.     uncond_prompt=None,
  4.     input_image=None,
  5.     strength=0.8,
  6.     do_cfg=True,
  7.     cfg_scale=7.5,
  8.     sampler_name="ddpm",
  9.     n_inference_steps=50,
  10.     models={},
  11.     seed=None,
  12.     device=None,
  13.     idle_device=None,
  14.     tokenizer=None,
  15. ):
Since we are in generation mode, no gradient computation is needed, hence with torch.no_grad():. And because the modules work one after another during generation, we provide an idle_device to avoid overloading the GPU: once a module has finished its job, it is moved to the idle_device.
Next comes the code related to the random seed.
The strength argument is, in the image-to-image case, the parameter that controls how much noise is added to the latent; it will be analyzed in more detail in the relevant functions later.
  1.     with torch.no_grad():
  2.         if not 0 < strength <= 1:
  3.             raise ValueError("strength must be between 0 and 1")
  4.         if idle_device:
  5.             to_idle = lambda x: x.to(idle_device)
  6.         else:
  7.             to_idle = lambda x: x
  8.         # Initialize random number generator according to the seed specified
  9.         generator = torch.Generator(device=device)
  10.         if seed is None:
  11.             generator.seed()
  12.         else:
  13.             generator.manual_seed(seed)
After that, we formally enter the first module.
CLIP:

  1.         clip = models["clip"]
  2.         clip.to(device)
  3.         
  4.         if do_cfg:
  5.             # Convert into a list of length Seq_Len=77
  6.             cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  7.             # (Batch_Size, Seq_Len)
  8.             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
  9.             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  10.             cond_context = clip(cond_tokens)
  11.             uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
  12.             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
  13.             uncond_context = clip(uncond_tokens)
  14.             # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
  15.             context = torch.cat([cond_context, uncond_context])
  16.         else:
  17.             tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  18.             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
  19.             context = clip(tokens)
  20.         to_idle(clip)
First, the clip model is taken out of models. The models dictionary is built as follows:
The checkpoint v1-5-pruned-emaonly.ckpt can be downloaded from: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main
  1. model_file = "../data/v1-5-pruned-emaonly.ckpt"
  2. models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
What preload_models_from_standard_weights in the model_loader file does can be summarized as: split the parameters of each module out of the full checkpoint and keep them.
  1. def preload_models_from_standard_weights(ckpt_path, device):
  2.     state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
  3.     encoder = VAE_Encoder().to(device)
  4.     encoder.load_state_dict(state_dict['encoder'], strict=True)
  5.     decoder = VAE_Decoder().to(device)
  6.     decoder.load_state_dict(state_dict['decoder'], strict=True)
  7.     diffusion = Diffusion().to(device)
  8.     diffusion.load_state_dict(state_dict['diffusion'], strict=True)
  9.     clip = CLIP().to(device)
  10.     clip.load_state_dict(state_dict['clip'], strict=True)
  11.     return {
  12.         'clip': clip,
  13.         'encoder': encoder,
  14.         'decoder': decoder,
  15.         'diffusion': diffusion,
  16.     }
The purpose of load_from_standard_weights in the model_converter file is to remap parameter names. The parameter names in the pretrained checkpoint are not very intuitive, and in this repository they have all been renamed to clearer, more descriptive ones. Consequently, if the pretrained parameter names were not remapped to the new names, loading the weights would fail.
The full code of the model_converter file can be found at: pytorch-stable-diffusion/sd/model_converter.py at main · hkproj/pytorch-stable-diffusion
Next, do_cfg is checked, i.e. whether to do classifier-free guidance.
Classifier-free guidance can be understood here as follows: every final output image is a linear combination of two generated images. One is generated starting from the positive prompt (the prompt argument), OUTPUT_conditioned; the other is generated starting from the negative prompt (the uncond_prompt argument), OUTPUT_unconditioned. The final output is:
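In terms of the quantities in the pipeline code (cfg_scale plays the role of the weight w; compare the later line model_output = cfg_scale * (output_cond - output_uncond) + output_uncond):

output = w \cdot (\text{OUTPUT}_{conditioned} - \text{OUTPUT}_{unconditioned}) + \text{OUTPUT}_{unconditioned}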

Here prompt describes what we want to generate, while uncond_prompt can be thought of as telling the model what we do not want, or is often just an empty string, which gives the model some extra freedom. For example, if prompt is "a lazy cat" but we do not want it lying on a sofa, we pass "sofa" as uncond_prompt; if there is no extra requirement, we pass the empty string "".
Take the empty-string case to understand the weight w (the cfg_scale argument): a high cfg_scale means we want the model to follow our prompt strictly, with little room for improvisation.
So one generation needs two prompts; each is processed by CLIP into its token embedding representation, and the two are concatenated at the end so they can be processed together.
  1.         if do_cfg:
  2.             # Convert into a list of length Seq_Len=77
  3.             cond_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  4.             # (Batch_Size, Seq_Len)
  5.             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
  6.             # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  7.             cond_context = clip(cond_tokens)
  8.             uncond_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
  9.             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
  10.             uncond_context = clip(uncond_tokens)
  11.             # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
  12.             context = torch.cat([cond_context, uncond_context])
The details of how the tokenizer turns prompt into cond_tokens and uncond_prompt into uncond_tokens will not be elaborated here; you can simply think of it as splitting the prompt sentence into individual words according to some specific scheme. The ready-made tokenizer is obtained from the line of code below. The files can be downloaded from: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/tokenizer
  1. tokenizer = CLIPTokenizer("../data/tokenizer_vocab.json", merges_file="../data/tokenizer_merges.txt")
The key point now is what clip actually does; you can refer to the encoder part of the Transformer, as shown in the figure.


First, the CLIP class is defined as:
  1. class CLIP(nn.Module):
  2.     def __init__(self):
  3.         super().__init__()
  4.         self.embedding = CLIPEmbedding(49408, 768, 77)
  5.         self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
  6.         self.layernorm = nn.LayerNorm(768)
  7.     def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
  8.         tokens = tokens.type(torch.long)
  9.         # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  10.         state = self.embedding(tokens)
  11.         # Apply encoder layers similar to the Transformer's encoder.
  12.         for layer in self.layers:
  13.             # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
  14.             state = layer(state)
  15.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
  16.         output = self.layernorm(state)
  17.         return output
For understanding a custom class like this, I prefer to start from forward, i.e. the part where it is actually used.
First there is a basic type cast: the incoming tokens are converted to torch.long, which is what the embedding layer that follows expects.
Here self.embedding is the custom CLIPEmbedding, shown below:
  1. class CLIPEmbedding(nn.Module):
  2.     def __init__(self, n_vocab: int, n_embd: int, n_token: int):
  3.         super().__init__()
  4.         self.token_embedding = nn.Embedding(n_vocab, n_embd)
  5.         # A learnable weight matrix encodes the position information for each token
  6.         self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
  7.     def forward(self, tokens):
  8.         # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  9.         x = self.token_embedding(tokens)
  10.         # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
  11.         x += self.position_embedding
  12.         return x
For each token we need two embeddings: an input embedding and a position embedding.
The input embedding is essentially a vocabulary lookup: each word corresponds to an index, and nn.Embedding maps that index to the word's input embedding, i.e. each word is represented by a row of a learnable matrix.
The position embedding encodes where the token sits in the sequence. The overall CLIPEmbedding is the sum of the input embedding and the position embedding, so it captures both what the word means and where it is.
The sizes here are: vocabulary size n_vocab=49408, maximum sentence length n_token=77, and embedding dimension n_embd=768.
Unlike the position embedding in the original Transformer, which is fixed, as in the typical code below:
  1. class PositionalEncoding(nn.Module):
  2.     def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
  3.         super().__init__()
  4.         self.d_model = d_model
  5.         self.seq_len = seq_len
  6.         self.dropout = nn.Dropout(dropout)
  7.         # Create a matrix of shape (seq_len, d_model)
  8.         pe = torch.zeros(seq_len, d_model)
  9.         # Create a vector of shape (seq_len)
  10.         position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
  11.         # Create a vector of shape (d_model)
  12.         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
  13.         # Apply sine to even indices
  14.         pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model))
  15.         # Apply cosine to odd indices
  16.         pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model))
  17.         # Add a batch dimension to the positional encoding
  18.         pe = pe.unsqueeze(0) # (1, seq_len, d_model)
  19.         # Register the positional encoding as a buffer
  20.         self.register_buffer('pe', pe)
  21.     def forward(self, x):
  22.         x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
  23.         return self.dropout(x)
It is not learnable; it is fixed at initialization and uses the sine/cosine encoding scheme, whose exact form is shown in the figure below:

In CLIP, by contrast, the positional encoding is a learnable parameter initialized to zero. Why? I think a learned embedding gives the model more freedom: it can discover whatever positional relationships between tokens are useful for this task directly from the data, instead of being committed up front to the fixed sinusoidal pattern.
Next, all of self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)]) are traversed. CLIPLayer is defined as follows:
  1. class CLIPLayer(nn.Module):
  2.     def __init__(self, n_head: int, n_embd: int):
  3.         super().__init__()
  4.         # Pre-attention norm
  5.         self.layernorm_1 = nn.LayerNorm(n_embd)
  6.         # Self attention
  7.         self.attention = SelfAttention(n_head, n_embd)
  8.         # Pre-FNN norm
  9.         self.layernorm_2 = nn.LayerNorm(n_embd)
  10.         # Feedforward layer
  11.         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
  12.         self.linear_2 = nn.Linear(4 * n_embd, n_embd)
  13.     def forward(self, x):
  14.         # (Batch_Size, Seq_Len, Dim)
  15.         ### SELF-ATTENTION ###
  16.         residue = x
  17.         x = self.layernorm_1(x)
  18.         x = self.attention(x, causal_mask=True)
  19.         x += residue
  20.         ### FEEDFORWARD LAYER ###
  21.         residue = x
  22.         x = self.layernorm_2(x)
  23.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
  24.         x = self.linear_1(x)
  25.         x = x * torch.sigmoid(1.702 * x)   # QuickGELU activation function
  26.         # (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, Dim)
  27.         x = self.linear_2(x)
  28.         x += residue
  29.         return x
Each CLIPLayer consists of two residual blocks: one around multi-head self-attention and one around the feed-forward layer. Let's go through them step by step. Please bear with the level of detail here: attention and the other sub-modules are the building blocks of essentially every big module that follows, so describing them carefully now means they can simply be reused later.
Again we follow the flow chart of the Transformer encoder:

First we implement the
residual block around multi-head self-attention:

  1.     def forward(self, x):
  2.         # (Batch_Size, Seq_Len, Dim)
  3.         ### SELF-ATTENTION ###
  4.         residue = x
  5.         x = self.layernorm_1(x)
  6.         x = self.attention(x, causal_mask=True)
  7.         x += residue
The initial input is kept as the residual, residue. The input is first layer-normalized with self.layernorm_1 = nn.LayerNorm(n_embd). Layer Normalization normalizes across the features of each sample: for each token's feature vector it computes the mean and variance of that vector and normalizes it, so that each token's embedding has zero mean and unit variance across its dimensions, as illustrated in the figure.
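As a small numerical check (hypothetical tensors, not code from the repository), LayerNorm over the embedding dimension can be written out by hand and compared against nn.LayerNorm:

import torch
import torch.nn as nn

# Each token's 768 features are normalized to zero mean / unit variance.
x = torch.randn(2, 77, 768)                       # (Batch_Size, Seq_Len, Dim)
ln = nn.LayerNorm(768)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias
print(torch.allclose(manual, ln(x), atol=1e-4))   # True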

A comparison of several common normalization schemes is shown in the figure below:

The layer-normalized input is then passed through self-attention, self.attention = SelfAttention(n_head, n_embd). The SelfAttention class is as follows:
  1. class SelfAttention(nn.Module):
  2.     def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
  3.         super().__init__()
  4.         # This combines the Wq, Wk and Wv matrices into one matrix
  5.         self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
  6.         # This one represents the Wo matrix
  7.         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
  8.         self.n_heads = n_heads
  9.         self.d_head = d_embed // n_heads
  10.     def forward(self, x, causal_mask=False):
  11.         # x: # (Batch_Size, Seq_Len, Dim)
  12.         # (Batch_Size, Seq_Len, Dim)
  13.         input_shape = x.shape
  14.         batch_size, sequence_length, d_embed = input_shape
  15.         # (Batch_Size, Seq_Len, H, Dim / H)
  16.         interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
  17.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensor of shape (Batch_Size, Seq_Len, Dim)
  18.         q, k, v = self.in_proj(x).chunk(3, dim=-1)
  19.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
  20.         q = q.view(interim_shape).transpose(1, 2)
  21.         k = k.view(interim_shape).transpose(1, 2)
  22.         v = v.view(interim_shape).transpose(1, 2)
  23.         # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
  24.         weight = q @ k.transpose(-1, -2)
  25.         if causal_mask:
  26.             # Mask where the upper triangle (above the principal diagonal) is 1
  27.             mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
  28.             # Fill the upper triangle with -inf
  29.             weight.masked_fill_(mask, -torch.inf)
  30.         # Divide by d_k (Dim / H).
  31.         # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
  32.         weight /= math.sqrt(self.d_head)
  33.         # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
  34.         weight = F.softmax(weight, dim=-1)
  35.         # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
  36.         output = weight @ v
  37.         # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
  38.         output = output.transpose(1, 2)
  39.         # (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
  40.         output = output.reshape(input_shape)
  41.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
  42.         output = self.out_proj(output)
  43.         return output
Its flow is as follows:

First, this is self-attention (self-attention allows the model to relate words to each other), i.e. Q (query), K (key), and V (value) all come from the same matrix, the input, which here is the normalized x. So x is passed into self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias), which projects it into three matrices that are then split apart with q, k, v = self.in_proj(x).chunk(3, dim=-1). These q, k, v are the Q', K', V' obtained after the projections W_Q, W_K, W_V.
Then the n_heads argument sets the number of heads, and self.d_head = d_embed // n_heads the per-head dimension, giving interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head).
Next, view is used to split q, k, v into multiple heads (their storage is contiguous, so they can be reshaped directly):
Note that transpose is also needed to reorder the dimensions, because we want each head to see the entire sequence, i.e. a different representation of every token.
  1.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
  2.         q = q.view(interim_shape).transpose(1, 2)
  3.         k = k.view(interim_shape).transpose(1, 2)
  4.         v = v.view(interim_shape).transpose(1, 2)
Then comes the basic attention formula:

Attention(Q, K, V) = softmax(QK^T / \sqrt{d_k}) V

First compute QK^T:
  1.         # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
  2.         weight = q @ k.transpose(-1, -2)
Then divide by the square root of the per-head dimension:
  1.         # Divide by d_k (Dim / H).
  2.         # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
  3.         weight /= math.sqrt(self.d_head)
Next, before applying softmax, a causal mask is used. Its intent is to set the upper-triangular part of the weight matrix to negative infinity.
  1.         if causal_mask:
  2.             # Mask where the upper triangle (above the principal diagonal) is 1
  3.             mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
  4.             # Fill the upper triangle with -inf
  5.             weight.masked_fill_(mask, -torch.inf)
This is easier to understand from the more general masking case: we can manually choose which positions of the mask are 1 or 0, and thereby decide which positions of weight should be set to negative infinity (or simply to something like -1e9), for example:
  1.     if mask is not None:
  2.         # Write a very low value (indicating -inf) to the positions where mask == 0
  3.         attention_scores.masked_fill_(mask == 0, -1e9)
This can be read as a way of preventing two tokens from attending to each other. For example, if we do not want "long hair" to be associated with "man", we simply set the attention score computed between those two tokens to negative infinity before the softmax.
Because the softmax formula is:
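\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}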

If an entry x is negative infinity, that term is pushed to (essentially) zero after the softmax, i.e. the two tokens are forced to be unrelated.
The causal mask here sets the upper-triangular attention scores to negative infinity (so they become ~0 after the softmax), meaning the model cannot see correlations between the current token and future tokens: it can only attend to the current token and the tokens it has already seen. [During training, if the model could access information from future timesteps, it might simply memorize it instead of learning to predict properly. Masking out the future prevents this kind of shortcut and forces the model to learn genuine reasoning ability.]
After the mask is applied, softmax is used and the result is multiplied by the V matrix; then, likewise, the dimension order is swapped back, i.e. the heads are concatenated back together:
  1.         weight = F.softmax(weight, dim=-1)
  2.         # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
  3.         output = weight @ v
  4.         # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
  5.         output = output.transpose(1, 2)
  6.         # (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
  7.         output = output.reshape(input_shape)
Finally the W_O matrix re-projects the output once more,
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias):
  1.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
  2.         output = self.out_proj(output)
  3.         return output
At this point we have obtained x = self.attention(x, causal_mask=True). Returning to the multi-head self-attention residual block from before, adding the residual residue to the attention output completes this residual block:
  1.         ### SELF-ATTENTION ###
  2.         residue = x
  3.         x = self.layernorm_1(x)
  4.         x = self.attention(x, causal_mask=True)
  5.         x += residue
The second residual block is the
feed-forward residual block:

This one is simpler. Its job is to apply a further non-linear transformation to the features at each position: the self-attention mechanism captures global information by computing the relations between all positions in the sequence, but it does not itself have much non-linear modeling power. This feed-forward residual block expands the input dimension by 4x, applies the QuickGELU activation, and projects back down to the original dimension, strengthening the model's expressive power.
  1.         ### FEEDFORWARD LAYER ###
  2.         residue = x
  3.         x = self.layernorm_2(x)
  4.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
  5.         x = self.linear_1(x)
  6.         x = x * torch.sigmoid(1.702 * x)   # QuickGELU activation function
  7.         # (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, Dim)
  8.         x = self.linear_2(x)
  9.         x += residue
where:
  1.         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
  2.         self.linear_2 = nn.Linear(4 * n_embd, n_embd)
So the two residual blocks used together make up a CLIPLayer, and 12 CLIPLayers in series make up self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)]), with 12 heads per layer.
Returning to the definition of the CLIP class, we have now reached this point: all CLIPLayers are applied in series, i.e. 12 encoder layers are used one after another, each made of a self-attention residual block plus a feed-forward residual block. Finally, a layer normalization is applied to the output.
  1.         # Apply encoder layers similar to the Transformer's encoder.
  2.         for layer in self.layers:
  3.             # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
  4.             state = layer(state)
  5.         # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
  6.         output = self.layernorm(state)
That is everything in clip's forward: it converts the tokens into their corresponding embedding representation.
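A quick shape check (a hypothetical sketch that reuses the tokenizer, clip, and device objects set up earlier in the pipeline; it only illustrates the shapes involved):

# Hypothetical usage sketch, assuming `tokenizer`, `clip`, and `device` exist as above.
tokens = tokenizer.batch_encode_plus(["a lazy cat"], padding="max_length", max_length=77).input_ids
tokens = torch.tensor(tokens, dtype=torch.long, device=device)   # (1, 77)
context = clip(tokens)                                           # (1, 77, 768)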
In the mode without classifier-free guidance, context is simply the value obtained by tokenizing prompt and feeding it into clip:
With that, the use of clip comes to an end, and it can be moved to the idle_device.
  1.         else:
  2.             tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
  3.             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
  4.             context = clip(tokens)
  5.         to_idle(clip)
The module after CLIP is the
VAE_Encoder:

A more thorough treatment of the VAE involves far too much material and far too much math; in keeping with this article's goal of using as few formulas as possible, a separate article on the VAE will follow later.
  1.         if sampler_name == "ddpm":
  2.             sampler = DDPMSampler(generator)
  3.             sampler.set_inference_timesteps(n_inference_steps)
  4.         else:
  5.             raise ValueError("Unknown sampler value %s. ")
  6.         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
  7.         if input_image:
  8.             encoder = models["encoder"]
  9.             encoder.to(device)
  10.             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
  11.             input_image_tensor = np.array(input_image_tensor)
  12.             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
  13.             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
  14.             # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
  15.             input_image_tensor = input_image_tensor.unsqueeze(0)
  16.             # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
  17.             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
  18.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  19.             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
  20.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  21.             latents = encoder(input_image_tensor, encoder_noise)
  22.             # Add noise to the latents (the encoded input image)
  23.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  24.             sampler.set_strength(strength=strength)
  25.             latents = sampler.add_noise(latents, sampler.timesteps[0])
  26.             to_idle(encoder)
  27.         else:
  28.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  29.             latents = torch.randn(latents_shape, generator=generator, device=device)
If we are doing text-to-image, i.e. there is no original image, the code above simplifies to the following: Gaussian noise is sampled directly and used as the latent representation:
  1.         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
  2.         # (Batch_Size, 4, Latents_Height, Latents_Width)
  3.         latents = torch.randn(latents_shape, generator=generator, device=device)
The VAE_Encoder is only used for image-to-image (the ddpm part at the top of this code is covered later). Let's look at the VAE_Encoder part first; the encoder model is loaded:
  1.         if input_image:
  2.             encoder = models["encoder"]
  3.             encoder.to(device)
First the original image is preprocessed: it is resized to (WIDTH, HEIGHT) = (512, 512), its values are rescaled to (-1, 1), unsqueeze adds a batch dimension for broadcasting, and the dimensions are permuted so that the channel dimension comes right after the batch dimension:
  1.             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
  2.             input_image_tensor = np.array(input_image_tensor)
  3.             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
  4.             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
  5.             # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
  6.             input_image_tensor = input_image_tensor.unsqueeze(0)
  7.             # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
  8.             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
To understand the operations around the encoder, we need to understand the
encoder:

and what happens inside it:
  1.             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
  2.             latents = encoder(input_image_tensor, encoder_noise)
The VAE_Encoder code is:
  1. class VAE_Encoder(nn.Sequential):
  2.     def __init__(self):
  3.         super().__init__(
  4.             # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
  5.             nn.Conv2d(3, 128, kernel_size=3, padding=1),
  6.             VAE_ResidualBlock(128, 128),
  7.             VAE_ResidualBlock(128, 128),
  8.             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height / 2, Width / 2)
  9.             nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
  10.             # (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
  11.             VAE_ResidualBlock(128, 256),
  12.             VAE_ResidualBlock(256, 256),
  13.             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 4, Width / 4)
  14.             nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
  15.             # (Batch_Size, 256, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
  16.             VAE_ResidualBlock(256, 512),
  17.             VAE_ResidualBlock(512, 512),
  18.             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 8, Width / 8)
  19.             nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
  20.             VAE_ResidualBlock(512, 512),
  21.             VAE_ResidualBlock(512, 512),
  22.             VAE_ResidualBlock(512, 512),
  23.             VAE_AttentionBlock(512),
  24.             VAE_ResidualBlock(512, 512),
  25.             nn.GroupNorm(32, 512),
  26.             nn.SiLU(),
  27.             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8).
  28.             nn.Conv2d(512, 8, kernel_size=3, padding=1),
  29.             # (Batch_Size, 8, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
  30.             nn.Conv2d(8, 8, kernel_size=1, padding=0),
  31.         )
  32.     def forward(self, x, noise):
  33.         # x: (Batch_Size, Channel, Height, Width)
  34.         # noise: (Batch_Size, 4, Height / 8, Width / 8)
  35.         for module in self:
  36.             if getattr(module, 'stride', None) == (2, 2):  # Padding at downsampling should be asymmetric (see #8)
  37.                 x = F.pad(x, (0, 1, 0, 1))
  38.             x = module(x)
  39.         # (Batch_Size, 8, Height / 8, Width / 8) -> two tensors of shape (Batch_Size, 4, Height / 8, Width / 8)
  40.         mean, log_variance = torch.chunk(x, 2, dim=1)
  41.         # Clamp the log variance between -30 and 20, so that the variance is between (circa) 1e-14 and 1e8.
  42.         log_variance = torch.clamp(log_variance, -30, 20)
  43.         variance = log_variance.exp()
  44.         stdev = variance.sqrt()
  45.         # Transform N(0, 1) -> N(mean, stdev)
  46.         x = mean + stdev * noise
  47.         # Scale by a constant
  48.         x *= 0.18215
  49.         return x
First, be clear about one thing: the essential goal of an image generation task is to learn the distribution of the image data. Since SD introduces a latent representation and models the latent variable as a multivariate Gaussian, what the VAE_Encoder essentially learns are two quantities: the mean of that Gaussian and its variance.
With that, if we want to sample a latent from the latent space, we can first sample a standard latent from the standard Gaussian and then transform it with the mean and stdev learned by the VAE_Encoder, which gives a latent sampled from the VAE's Gaussian space. This works because a standard Gaussian Z ~ N(0, 1) can be turned into a new Gaussian X ~ N(μ, σ²) by the following linear transformation:
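Written out (the original figure is not reproduced here):

X = \mu + \sigma \cdot Z, \quad Z \sim \mathcal{N}(0, 1) \;\Rightarrow\; X \sim \mathcal{N}(\mu, \sigma^2)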
So here, the noise sampled from the standard Gaussian is encoder_noise = torch.randn(latents_shape, generator=generator, device=device), passed in as the encoder_noise of latents = encoder(input_image_tensor, encoder_noise), and it is then transformed like this:
  1.         # Transform N(0, 1) -> N(mean, stdev)
  2.         x = mean + stdev * noise
The final x *= 0.18215 is a practical engineering choice that keeps training stable.
Now let's see how this mean and stdev are obtained.
The main path simply runs through all the defined modules in sequence; the final output is split into the mean and log_variance of the latent space. After a clamp to force the value into a reasonable range, exponentiating gives the variance, and taking the square root gives the stdev needed above.
  1.         mean, log_variance = torch.chunk(x, 2, dim=1)
  2.         # Clamp the log variance between -30 and 20, so that the variance is between (circa) 1e-14 and 1e8.
  3.         log_variance = torch.clamp(log_variance, -30, 20)
  4.         variance = log_variance.exp()
  5.         stdev = variance.sqrt()
Why does the VAE_Encoder learn log_variance instead of the variance directly? Because the variance must be non-negative, whereas log_variance can take any real value; this gives the model a wider parameter space and avoids the numerical constraints and instabilities that can arise when optimizing the variance directly. It is a common trick that adds flexibility.
Next comes the module sequence of the VAE_Encoder:
  1. class VAE_Encoder(nn.Sequential):
  2.     def __init__(self):
  3.         super().__init__(
  4.             # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
  5.             nn.Conv2d(3, 128, kernel_size=3, padding=1),
  6.             VAE_ResidualBlock(128, 128),
  7.             VAE_ResidualBlock(128, 128),
  8.             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height / 2, Width / 2)
  9.             nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
  10.             # (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
  11.             VAE_ResidualBlock(128, 256),
  12.             VAE_ResidualBlock(256, 256),
  13.             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 4, Width / 4)
  14.             nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
  15.             # (Batch_Size, 256, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
  16.             VAE_ResidualBlock(256, 512),
  17.             VAE_ResidualBlock(512, 512),
  18.             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 8, Width / 8)
  19.             nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
  20.             VAE_ResidualBlock(512, 512),
  21.             VAE_ResidualBlock(512, 512),
  22.             VAE_ResidualBlock(512, 512),
  23.             VAE_AttentionBlock(512),
  24.             VAE_ResidualBlock(512, 512),
  25.             nn.GroupNorm(32, 512),
  26.             nn.SiLU(),
  27.             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8).
  28.             nn.Conv2d(512, 8, kernel_size=3, padding=1),
  29.             # (Batch_Size, 8, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
  30.             nn.Conv2d(8, 8, kernel_size=1, padding=0),
  31.         )
In my view, to learn this sequence it is enough to grasp the essentials and the big picture:
First, the goal is to make the image smaller. Along the way the channels are first increased and later reduced. The basic pattern is: use a Conv2d to shrink the spatial size, then follow it with two residual convolution blocks that add depth while keeping the size.
Concretely, the first Conv2d changes the channels but keeps the size, followed by two residual blocks that do not change the size.
The next Conv2d keeps the channels and halves the size, followed by two residual blocks, one increasing the channels and one keeping them, until the channels reach 512; the following Conv2d then brings the size down to Height / 8, Width / 8, after which three residual blocks keep the size.
Then an attention block is introduced, followed by a residual block, then a GroupNorm [group normalization is similar to layer normalization, except that not all features share one mean and variance: the features are split into groups, and each group is normalized with its own mean and variance] and a SiLU activation.
Finally two convolutions: one reduces the channels to 8 while keeping the size, the other keeps both the 8 channels and the size, forming the final output layer.
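Since GroupNorm appears here for the first time, a small numerical sketch of how it differs from LayerNorm (hypothetical tensor, 32 groups over 512 channels as used by the encoder; not repository code):

import torch
import torch.nn as nn

# GroupNorm: 512 channels are split into 32 groups of 16, and each group is
# normalized with its own mean and variance, per sample.
x = torch.randn(1, 512, 64, 64)                    # (Batch_Size, Features, Height, Width)
gn = nn.GroupNorm(32, 512)
xg = x.view(1, 32, 16 * 64 * 64)                   # group consecutive channels together
mean = xg.mean(dim=-1, keepdim=True)               # one mean/variance per group
var = xg.var(dim=-1, keepdim=True, unbiased=False)
manual = ((xg - mean) / torch.sqrt(var + gn.eps)).view(1, 512, 64, 64)
print(torch.allclose(manual, gn(x), atol=1e-4))    # True (affine weight/bias are 1/0 at init)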
A few implementation details:
First, the code that iterates over the modules is:
  1.     def forward(self, x, noise):
  2.         # x: (Batch_Size, Channel, Height, Width)
  3.         # noise: (Batch_Size, 4, Height / 8, Width / 8)
  4.         for module in self:
  5.             if getattr(module, 'stride', None) == (2, 2):  # Padding at downsampling should be asymmetric (see #8)
  6.                 x = F.pad(x, (0, 1, 0, 1))
  7.             x = module(x)
As you can see, before any Conv2d with stride 2, the feature map is padded asymmetrically so that the size is halved exactly.
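A quick size check (a sketch, not repository code): with kernel_size=3, stride=2, padding=0, an asymmetric pad of one pixel on the right and bottom turns 512 into exactly 256:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 128, 512, 512)
conv = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0)
x = F.pad(x, (0, 1, 0, 1))      # (pad_left, pad_right, pad_top, pad_bottom)
print(conv(x).shape)            # torch.Size([1, 128, 256, 256])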
Second, the attention block inside the sequence was described earlier; here it is simply a single-head self-attention:
  1. class VAE_AttentionBlock(nn.Module):
  2.     def __init__(self, channels):
  3.         super().__init__()
  4.         self.groupnorm = nn.GroupNorm(32, channels)
  5.         self.attention = SelfAttention(1, channels)
  6.     def forward(self, x):
  7.         # x: (Batch_Size, Features, Height, Width)
  8.         residue = x
  9.         x = self.groupnorm(x)
  10.         n, c, h, w = x.shape
  11.         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
  12.         x = x.view((n, c, h * w))
  13.         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features).
  14.         # Each pixel becomes a feature of size "Features", the sequence length is "Height * Width".
  15.         x = x.transpose(-1, -2)
  16.         x = self.attention(x)
  17.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
  18.         x = x.transpose(-1, -2)
  19.         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
  20.         x = x.view((n, c, h, w))
  21.         x += residue
  22.         return x
The residual block here is also fairly simple: it consists of two GroupNorm → SiLU → convolution stages, plus the usual residual_layer that keeps the skip connection valid (i.e. makes the channel dimension of residue match the final x).
  1. class VAE_ResidualBlock(nn.Module):
  2.     def __init__(self, in_channels, out_channels):
  3.         super().__init__()
  4.         self.groupnorm_1 = nn.GroupNorm(32, in_channels)
  5.         self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
  6.         self.groupnorm_2 = nn.GroupNorm(32, out_channels)
  7.         self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
  8.         if in_channels == out_channels:
  9.             self.residual_layer = nn.Identity()
  10.         else:
  11.             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
  12.     def forward(self, x):
  13.         residue = x
  14.         x = self.groupnorm_1(x)
  15.         x = F.silu(x)
  16.         x = self.conv_1(x)
  17.         x = self.groupnorm_2(x)
  18.         x = F.silu(x)
  19.         x = self.conv_2(x)
  20.         return x + self.residual_layer(residue)
At this point we have reached this step of the pipeline:
  1.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  2.             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
  3.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  4.             latents = encoder(input_image_tensor, encoder_noise)
In the image-to-image case, the original image has been converted into latents, and now we need to add noise to them.

  1.             # Add noise to the latents (the encoded input image)
  2.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  3.             sampler.set_strength(strength=strength)
  4.             latents = sampler.add_noise(latents, sampler.timesteps[0])
The sampler_name here is also an argument of generate, and:
  1.         if sampler_name == "ddpm":
  2.             sampler = DDPMSampler(generator)
  3.             sampler.set_inference_timesteps(n_inference_steps)
  4.         else:
  5.             raise ValueError("Unknown sampler value %s. ")
So first we have to look at the
DDPMSampler:

The basic principle of a diffusion model is to gradually add noise to a real image until it becomes pure noise, and then to remove the noise step by step through the reverse process (reverse diffusion) to recover the original image. The DDPMSampler class implements this process.
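In the notation of the DDPM paper referenced in the code, each forward step adds Gaussian noise with variance β_t:

q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right)

and the reverse (denoising) distribution p_\theta(x_{t-1} \mid x_t) is what the sampler, together with the UNET's noise prediction, approximates.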
Here is the complete code of the DDPMSampler class first:
  1. class DDPMSampler:
  2.     def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
  3.         # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
  4.         # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
  5.         self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
  6.         self.alphas = 1.0 - self.betas
  7.         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
  8.         self.one = torch.tensor(1.0)
  9.         self.generator = generator
  10.         self.num_train_timesteps = num_training_steps
  11.         self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
  12.         # Could this work just as well? self.timesteps = torch.arange(num_training_steps - 1, -1, -1)  # build the reversed tensor
  13.     def set_inference_timesteps(self, num_inference_steps=50):
  14.         self.num_inference_steps = num_inference_steps
  15.         step_ratio = self.num_train_timesteps // self.num_inference_steps
  16.         self.timesteps = (torch.arange(num_inference_steps - 1, -1, -1) * step_ratio).long()
  17.     def _get_previous_timestep(self, timestep: int) -> int:
  18.         prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
  19.         return prev_t
  20.    
  21.     def _get_variance(self, timestep: int) -> torch.Tensor:
  22.         prev_t = self._get_previous_timestep(timestep)
  23.         alpha_prod_t = self.alphas_cumprod[timestep]
  24.         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
  25.         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
  26.         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
  27.         # and sample from it to get previous sample
  28.         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
  29.         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
  30.         # we always take the log of variance, so clamp it to ensure it's not 0
  31.         variance = torch.clamp(variance, min=1e-20)
  32.         return variance
  33.    
  34.     def set_strength(self, strength=1):
  35.         """
  36.             Set how much noise to add to the input image.
  37.             More noise (strength ~ 1) means that the output will be further from the input image.
  38.             Less noise (strength ~ 0) means that the output will be closer to the input image.
  39.         """
  40.         # start_step is the number of noise levels to skip
  41.         start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
  42.         self.timesteps = self.timesteps[start_step:]
  43.         self.start_step = start_step
  44.     def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
  45.         t = timestep
  46.         prev_t = self._get_previous_timestep(t)
  47.         # 1. compute alphas, betas
  48.         alpha_prod_t = self.alphas_cumprod[t]
  49.         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
  50.         beta_prod_t = 1 - alpha_prod_t
  51.         beta_prod_t_prev = 1 - alpha_prod_t_prev
  52.         current_alpha_t = alpha_prod_t / alpha_prod_t_prev
  53.         current_beta_t = 1 - current_alpha_t
  54.         # 2. compute predicted original sample from predicted noise also called
  55.         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
  56.         pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
  57.         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
  58.         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
  59.         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
  60.         current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
  61.         # 5. Compute predicted previous sample µ_t
  62.         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
  63.         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
  64.         # 6. Add noise
  65.         variance = 0
  66.         if t > 0:
  67.             noise = torch.randn(model_output.shape, generator=self.generator, device=model_output.device, dtype=model_output.dtype)
  68.             variance = (self._get_variance(t) ** 0.5) * noise
  69.         pred_prev_sample = pred_prev_sample + variance
  70.         return pred_prev_sample
  71.    
  72.     def add_noise(self,original_samples: torch.FloatTensor,timesteps: torch.IntTensor,) -> torch.FloatTensor:
  73.         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
  74.         timesteps = timesteps.to(original_samples.device)
  75.         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
  76.         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
  77.         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
  78.             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
  79.         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
  80.         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
  81.         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
  82.             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
  83.         # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
  84.         # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
  85.         # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
  86.         noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
  87.         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
  88.         return noisy_samples
Instead of looking at the other methods right away, let's first see what __init__ sets up:
  1.     def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
  2.         # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
  3.         # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
  4.         self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
  5.         self.alphas = 1.0 - self.betas
  6.         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
  7.         self.one = torch.tensor(1.0)
  8.         self.generator = generator
  9.         self.num_train_timesteps = num_training_steps
  10.         self.timesteps = torch.arange(num_training_steps - 1, -1, -1)  # build the timesteps as a reversed tensor
Here β is the variance of the noise added at each step of the DDPM forward process, as described in the DDPM paper:

Once the starting and ending β are fixed, there are different schedules to choose from for the β sequence, such as cosine or linear; a linear schedule is used here. Given the total number of steps plus the starting and ending β, the β schedule is determined.
Here num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120 are used.
In practice, the schedule is actually linear in the standard deviation rather than in β itself:
  1. self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
Why is less noise added early on and more noise later? With small noise at the beginning, the model can more easily learn the structure of the data; adding more noise later speeds up the diffusion process, so that the model can recover clean data from noise more quickly.
In actual use we introduce α (the cumulative product here is easily computed with torch.cumprod):

Once α is introduced, we no longer need to process the image step by step with β in sequence; for example, to add noise we can directly use:
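With α_t = 1 - β_t and \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s (the alphas_cumprod above), the closed form implemented by add_noise below is:

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})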

Next, we again look only at the method we currently need, sampler.set_inference_timesteps(n_inference_steps):
  1.     def set_inference_timesteps(self, num_inference_steps=50):
  2.         self.num_inference_steps = num_inference_steps
  3.         step_ratio = self.num_train_timesteps // self.num_inference_steps
  4.         self.timesteps = (torch.arange(num_inference_steps - 1, -1, -1) * step_ratio).long()
The intent of this method is to rebuild timesteps. Its main purpose is to set the number of diffusion timesteps used at inference (generation) time, which controls the granularity of the reverse diffusion process and the quality of the generated image. In effect, during training the model learned how to denoise over 1000 steps to reach a noise-free image, but at inference we do not need that many steps to get a good result. Taking 50 inference steps as an example, we only need the model to denoise at timesteps 980, 960, 940, ... down to 0, rather than at every single step 999, 998, 997, and so on.
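A quick sketch of the schedule this produces with the default settings (1000 training steps, 50 inference steps):

import torch

num_train_timesteps, num_inference_steps = 1000, 50
step_ratio = num_train_timesteps // num_inference_steps                     # 20
timesteps = (torch.arange(num_inference_steps - 1, -1, -1) * step_ratio).long()
print(timesteps[:5].tolist(), "...", timesteps[-3:].tolist())               # [980, 960, 940, 920, 900] ... [40, 20, 0]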
Next we encounter sampler.set_strength and sampler.add_noise.
  1.             # Add noise to the latents (the encoded input image)
  2.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  3.             sampler.set_strength(strength=strength)
  4.             latents = sampler.add_noise(latents, sampler.timesteps[0])
The meaning of sampler.set_strength is: if I do not want the model's output to stray too far from the original image I provided, then I should add only a little noise to the original image's latent, reducing the flexibility of the denoising process.
Without set_strength, the timesteps would be the index list [49, 48, ..., 0] mapped to [980, 960, ..., 0], i.e. the noise added at the very start would correspond to timestep 980, which is a lot of noise.
If set_strength is applied instead:
  1.     def set_strength(self, strength=1):
  2.         """
  3.             Set how much noise to add to the input image.
  4.             More noise (strength ~ 1) means that the output will be further from the input image.
  5.             Less noise (strength ~ 0) means that the output will be closer to the input image.
  6.         """
  7.         # start_step is the number of noise levels to skip
  8.         start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
  9.         self.timesteps = self.timesteps[start_step:]
  10.         self.start_step = start_step
With strength=0.8, the schedule [980, 960, ..., 0] is cut down to its last 40 entries
(start_step = 50 - int(50 * 0.8) = 10 noise levels are skipped), leaving [780, 760, 740, ..., 0],
so the noise added in this step is the noise of timestep 780, which is less noise.
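A self-contained sketch of that arithmetic (not repository code):

import torch

num_inference_steps, step_ratio, strength = 50, 20, 0.8
timesteps = (torch.arange(num_inference_steps - 1, -1, -1) * step_ratio).long()  # [980, 960, ..., 0]
start_step = num_inference_steps - int(num_inference_steps * strength)           # 10
timesteps = timesteps[start_step:]
print(timesteps[0].item(), len(timesteps))                                        # 780 40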
Note that latents = sampler.add_noise(latents, sampler.timesteps[0]): the noise added to the latent corresponds to timesteps[0]. So let's now look at the add_noise method:
  1.     def add_noise(self,original_samples: torch.FloatTensor,timesteps: torch.IntTensor,) -> torch.FloatTensor:
  2.         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
  3.         timesteps = timesteps.to(original_samples.device)
  4.         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
  5.         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
  6.         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
  7.             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
  8.         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
  9.         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
  10.         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
  11.             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
  12.         # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
  13.         # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
  14.         # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
  15.         noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
  16.         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
  17.         return noisy_samples
From the closed-form expression for adding noise shown earlier, you can see that a noise term sampled from the corresponding Gaussian is added to the original image. [In practice it is not sampled directly from the target Gaussian; a standard Gaussian is sampled and then transformed into the target Gaussian using the mean and standard deviation.]
First, the devices are made consistent:
  1.         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
  2.         timesteps = timesteps.to(original_samples.device)
Then the mean and standard-deviation coefficients are computed, and while/unsqueeze add dimensions so that they have the same number of dimensions as the original_samples they will be combined with [note that the closed form is used here: the noise is added directly to the original image, not to the previous step's output as in a step-by-step Markov chain], which ensures broadcasting works correctly:
  1.         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
  2.         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
  3.         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
  4.             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
  5.         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
  6.         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
  7.         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
  8.             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
  9.         # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
  10.         # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
  11.         # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
  12.         noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
  13.         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
  14.         return noisy_samples
With this, the VAE_Encoder part is complete: (in the image-to-image case) it maps the original image to a latent and adds noise of controllable magnitude to it, thereby controlling the flexibility of the diffusion.
Next, as usual, the encoder is moved to the idle_device; and for text-to-image only, Gaussian noise is simply sampled as the latent. Let's review the encoder part of the code once more:
  1.         if sampler_name == "ddpm":
  2.             sampler = DDPMSampler(generator)
  3.             sampler.set_inference_timesteps(n_inference_steps)
  4.         else:
  5.             raise ValueError("Unknown sampler value %s. ")
  6.         latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
  7.         if input_image:
  8.             encoder = models["encoder"]
  9.             encoder.to(device)
  10.             input_image_tensor = input_image.resize((WIDTH, HEIGHT))
  11.             input_image_tensor = np.array(input_image_tensor)
  12.             input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
  13.             input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
  14.             # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
  15.             input_image_tensor = input_image_tensor.unsqueeze(0)
  16.             # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
  17.             input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
  18.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  19.             encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
  20.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  21.             latents = encoder(input_image_tensor, encoder_noise)
  22.             # Add noise to the latents (the encoded input image)
  23.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  24.             sampler.set_strength(strength=strength)
  25.             latents = sampler.add_noise(latents, sampler.timesteps[0])
  26.             to_idle(encoder)
  27.         else:
  28.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  29.             latents = torch.randn(latents_shape, generator=generator, device=device)
复制代码
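The strength parameter controls how much of the denoising schedule is actually run: set_strength skips the first (noisiest) part of the inference timesteps, so when strength is small the encoded image receives less noise and the output stays closer to it. A simplified sketch of what DDPMSampler.set_strength does in the reference repository (check the repo for the exact implementation):

    def set_strength(self, strength=1.0):
        # Skip the earliest (noisiest) part of the inference schedule when strength < 1,
        # so the encoded image receives less noise and fewer denoising steps are run on it.
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step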
Next we come to
Diffusion:

The complete denoising loop is:
  1.         diffusion = models["diffusion"]
  2.         diffusion.to(device)
  3.         timesteps = tqdm(sampler.timesteps)
  4.         for i, timestep in enumerate(timesteps):
  5.             # (1, 320)
  6.             time_embedding = get_time_embedding(timestep).to(device)
  7.             # (Batch_Size, 4, Latents_Height, Latents_Width)
  8.             model_input = latents
  9.             if do_cfg:
  10.                 # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
  11.                 model_input = model_input.repeat(2, 1, 1, 1)
  12.             # model_output is the predicted noise
  13.             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
  14.             model_output = diffusion(model_input, context, time_embedding)
  15.             if do_cfg:
  16.                 output_cond, output_uncond = model_output.chunk(2)
  17.                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
  18.             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
  19.             latents = sampler.step(timestep, latents, model_output)
  20.         to_idle(diffusion)
复制代码
As usual, we first load the diffusion model and move it to the device, then wrap sampler.timesteps in tqdm so a progress bar shows how far the denoising has progressed.
  1.         diffusion = models["diffusion"]
  2.         diffusion.to(device)
  3.         timesteps = tqdm(sampler.timesteps)
复制代码
We then iterate over the timesteps and denoise the input latent step by step.
Each denoising step consists of a few sub-steps. First, the model needs to know which step it is at, so the integer timestep is turned into a time embedding:

   The time embedding uses the same sine/cosine scheme as the positional embedding in the Transformer, so we will not elaborate on it here:
  1. def get_time_embedding(timestep):
  2.     # Shape: (160,)
  3.     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
  4.     # Shape: (1, 160)
  5.     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
  6.     # Shape: (1, 160 * 2)
  7.     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
复制代码
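Written out, get_time_embedding computes 160 frequencies and returns a 320-dimensional vector:

$$\omega_i = 10000^{-i/160},\ i = 0,\dots,159, \qquad t_{\mathrm{emb}}(t) = \big[\cos(t\,\omega_0),\dots,\cos(t\,\omega_{159}),\ \sin(t\,\omega_0),\dots,\sin(t\,\omega_{159})\big]$$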
Once we have the time embedding: if classifier-free guidance (do_cfg) is enabled, the context was built earlier by concatenating the conditional (with prompt) and unconditional (without prompt) embeddings along the batch dimension. To match that batch size, so everything can be computed in a single pass, the latent (the VAE_Encoder output) must be repeated:
  1.             model_input = latents
  2.             if do_cfg:
  3.                 # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
  4.                 model_input = model_input.repeat(2, 1, 1, 1)
复制代码
Then model_input, context and time_embedding are fed into the diffusion model, which predicts what noise should be removed from the current latent at the current timestep.
  1.             # model_output is the predicted noise
  2.             # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
  3.             model_output = diffusion(model_input, context, time_embedding)
复制代码
Now let's look inside
  Diffusion (UNET):

  Here is the definition of the Diffusion class:
  1. class Diffusion(nn.Module):
  2.     def __init__(self):
  3.         super().__init__()
  4.         self.time_embedding = TimeEmbedding(320)
  5.         self.unet = UNET()
  6.         self.final = UNET_OutputLayer(320, 4)
  7.    
  8.     def forward(self, latent, context, time):
  9.         # latent: (Batch_Size, 4, Height / 8, Width / 8)
  10.         # context: (Batch_Size, Seq_Len, Dim)
  11.         # time: (1, 320)
  12.         # (1, 320) -> (1, 1280)
  13.         time = self.time_embedding(time)
  14.         # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
  15.         output = self.unet(latent, context, time)
  16.         # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
  17.         output = self.final(output)
  18.         # (Batch, 4, Height / 8, Width / 8)
  19.         return output
复制代码
First, recall that the current timestep has already been given a sine/cosine encoding:
  1. time_embedding = get_time_embedding(timestep).to(device)
复制代码
Here the time_embedding is passed through the TimeEmbedding module for a further learned mapping:
  1. class TimeEmbedding(nn.Module):
  2.     def __init__(self, n_embd):
  3.         super().__init__()
  4.         self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
  5.         self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
  6.     def forward(self, x):
  7.         # x: (1, 320)
  8.         # (1, 320) -> (1, 1280)
  9.         x = self.linear_1(x)
  10.         x = F.silu(x)
  11.         x = self.linear_2(x)
  12.         return x
复制代码
As with the Transformer positional encoding, the first-stage time embedding is a fixed value. Before it is used, it goes through a learnable projection, two linear layers with a SiLU activation in between, to obtain a more expressive time representation.
With the time representation ready, we formally enter the UNET.
  First, be clear about the UNET's goal: given the current latent, the current timestep and the prompt, it predicts the noise that should be removed at this step.
  

Looking at the UNET diagram, it can be split into three parts: the encoders, the bottleneck and the decoders.
  Another notable feature is that each encoder level is connected to the corresponding decoder level by a skip connection (the encoder output is concatenated onto the decoder input).
  So the UNET forward pass can be written as:
  1.     def forward(self, x, context, time):
  2.         # x: (Batch_Size, 4, Height / 8, Width / 8)
  3.         # context: (Batch_Size, Seq_Len, Dim)
  4.         # time: (1, 1280)
  5.         skip_connections = []
  6.         for layers in self.encoders:
  7.             x = layers(x, context, time)
  8.             skip_connections.append(x)
  9.         x = self.bottleneck(x, context, time)
  10.         for layers in self.decoders:
  11.             # Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
  12.             x = torch.cat((x, skip_connections.pop()), dim=1)
  13.             x = layers(x, context, time)
  14.         return x
复制代码
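A quick shape check of the concatenation along dim=1 (toy tensors, hypothetical sizes for a 512x512 image whose latents are 64x64): the deepest encoder output and the bottleneck output both have 1280 channels at Height/64 x Width/64, so the first decoder block receives 1280 + 1280 = 2560 channels, which is why it is UNET_ResidualBlock(2560, 1280).

    import torch

    # Toy illustration of the skip-connection concatenation in the UNET forward pass.
    bottleneck_out = torch.randn(1, 1280, 8, 8)           # output of the bottleneck
    skip = torch.randn(1, 1280, 8, 8)                     # popped encoder skip connection
    decoder_in = torch.cat((bottleneck_out, skip), dim=1) # channels add up: 1280 + 1280
    print(decoder_in.shape)                               # torch.Size([1, 2560, 8, 8])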
Next, let's look at the encoders:
  1.         self.encoders = nn.ModuleList([
  2.             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
  3.             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
  4.             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
  5.             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
  6.             # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
  7.             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
  8.             # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
  9.             SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
  10.             SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
  11.             # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
  12.             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
  13.             # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
  14.             SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
  15.             SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
  16.             # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
  17.             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
  18.             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
  19.             SwitchSequential(UNET_ResidualBlock(1280, 1280)),
  20.         ])
复制代码
Note the SwitchSequential wrapper. It is very simple: it just passes each layer the arguments it expects, depending on the layer's type:
  1. class SwitchSequential(nn.Sequential):
  2.     def forward(self, x, context, time):
  3.         for layer in self:
  4.             if isinstance(layer, UNET_AttentionBlock):
  5.                 x = layer(x, context)
  6.             elif isinstance(layer, UNET_ResidualBlock):
  7.                 x = layer(x, time)
  8.             else:
  9.                 x = layer(x)
  10.         return x
复制代码
The residual block, UNET_ResidualBlock, is:
  1. class UNET_ResidualBlock(nn.Module):
  2.     def __init__(self, in_channels, out_channels, n_time=1280):
  3.         super().__init__()
  4.         self.groupnorm_feature = nn.GroupNorm(32, in_channels)
  5.         self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
  6.         self.linear_time = nn.Linear(n_time, out_channels)
  7.         self.groupnorm_merged = nn.GroupNorm(32, out_channels)
  8.         self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
  9.         if in_channels == out_channels:
  10.             self.residual_layer = nn.Identity()
  11.         else:
  12.             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
  13.    
  14.     def forward(self, feature, time):
  15.         # feature: (Batch_Size, In_Channels, Height, Width)
  16.         # time: (1, 1280)
  17.         residue = feature
  18.         feature = self.groupnorm_feature(feature)
  19.         feature = F.silu(feature)
  20.         # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
  21.         feature = self.conv_feature(feature)
  22.         time = F.silu(time)
  23.         # (1, 1280) -> (1, Out_Channels)
  24.         time = self.linear_time(time)
  25.         # Add width and height dimension to time.
  26.         # Broadcast: (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
  27.         merged = feature + time.unsqueeze(-1).unsqueeze(-1)
  28.         merged = self.groupnorm_merged(merged)
  29.         merged = F.silu(merged)
  30.         merged = self.conv_merged(merged)
  31.         return merged + self.residual_layer(residue)
复制代码
The block first saves the input feature map as residue, then applies group normalization, SiLU and a convolution to the features. The time embedding goes through SiLU and a linear layer, gets height and width dimensions added so it can broadcast against the feature map, and the two are added together. The result is group-normalized, activated with SiLU and convolved once more to give merged. Finally, merged is added to residual_layer(residue) and returned. In effect, this block fuses the feature map with the timestep information.
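A quick shape check of the broadcasting in merged = feature + time.unsqueeze(-1).unsqueeze(-1) (toy tensors, sizes chosen for illustration):

    import torch

    feature = torch.randn(2, 320, 64, 64)               # (Batch_Size, Out_Channels, Height, Width)
    time = torch.randn(1, 320)                          # (1, Out_Channels), i.e. after linear_time
    merged = feature + time.unsqueeze(-1).unsqueeze(-1) # (1, 320, 1, 1) broadcasts over batch, H, W
    print(merged.shape)                                 # torch.Size([2, 320, 64, 64])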
The other building block is the attention block:
  1. class UNET_AttentionBlock(nn.Module):
  2.     def __init__(self, n_head: int, n_embd: int, d_context=768):
  3.         super().__init__()
  4.         channels = n_head * n_embd
  5.         
  6.         self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
  7.         self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
  8.         self.layernorm_1 = nn.LayerNorm(channels)
  9.         self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
  10.         self.layernorm_2 = nn.LayerNorm(channels)
  11.         self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
  12.         self.layernorm_3 = nn.LayerNorm(channels)
  13.         self.linear_geglu_1  = nn.Linear(channels, 4 * channels * 2)
  14.         self.linear_geglu_2 = nn.Linear(4 * channels, channels)
  15.         self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
  16.    
  17.     def forward(self, x, context):
  18.         # x: (Batch_Size, Features, Height, Width)
  19.         # context: (Batch_Size, Seq_Len, Dim)
  20.         residue_long = x
  21.         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
  22.         x = self.groupnorm(x)
  23.         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
  24.         x = self.conv_input(x)
  25.         n, c, h, w = x.shape
  26.         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
  27.         x = x.view((n, c, h * w)).transpose(-1, -2)
  28.         # Normalization + Self-Attention with skip connection
  29.         # (Batch_Size, Height * Width, Features)
  30.         residue_short = x
  31.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  32.         x = self.layernorm_1(x)
  33.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  34.         x = self.attention_1(x)
  35.         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  36.         x += residue_short
  37.         # (Batch_Size, Height * Width, Features)
  38.         residue_short = x
  39.         # Normalization + Cross-Attention with skip connection
  40.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  41.         x = self.layernorm_2(x)
  42.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  43.         x = self.attention_2(x, context)
  44.         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  45.         x += residue_short
  46.         # (Batch_Size, Height * Width, Features)
  47.         residue_short = x
  48.         # Normalization + FFN with GeGLU and skip connection
  49.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  50.         x = self.layernorm_3(x)
  51.         # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
  52.         # (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
  53.         x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
  54.         # Element-wise product: (Batch_Size, Height * Width, Features * 4) * (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features * 4)
  55.         x = x * F.gelu(gate)
  56.         # (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
  57.         x = self.linear_geglu_2(x)
  58.         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  59.         x += residue_short
  60.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
  61.         x = x.transpose(-1, -2)
  62.         # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
  63.         x = x.view((n, c, h, w))
  64.         # Final skip connection between initial input and output of the block
  65.         # (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
  66.         return self.conv_output(x) + residue_long
复制代码
First there is a group normalization followed by a 1x1 convolution (conv_input) that re-projects the incoming x.
  Then comes a small residual sub-block whose main operation is self-attention:
  1.         residue_short = x
  2.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  3.         x = self.layernorm_1(x)
  4.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  5.         x = self.attention_1(x)
  6.         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  7.         x += residue_short
复制代码
attention_1 is a multi-head self-attention block without input-projection bias (8 heads here); it lets the feature/time mixture from the previous step attend to itself and work out its internal correlations:
  1. self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
复制代码
After that comes another residual sub-block whose main operation is cross-attention:
  1.         residue_short = x
  2.         # Normalization + Cross-Attention with skip connection
  3.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  4.         x = self.layernorm_2(x)
  5.         # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  6.         x = self.attention_2(x, context)
  7.         # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
  8.         x += residue_short
复制代码
attention_2 is a multi-head cross-attention block, likewise without input-projection bias:
  1. self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
复制代码
Its implementation is:
  1.     def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
  2.         super().__init__()
  3.         self.q_proj   = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
  4.         self.k_proj   = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
  5.         self.v_proj   = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
  6.         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
  7.         self.n_heads = n_heads
  8.         self.d_head = d_embed // n_heads
  9.    
  10.     def forward(self, x, y):
  11.         # x (latent): # (Batch_Size, Seq_Len_Q, Dim_Q)
  12.         # y (context): # (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
  13.         input_shape = x.shape
  14.         batch_size, sequence_length, d_embed = input_shape
  15.         # Divide each embedding of Q into multiple heads such that d_heads * n_heads = Dim_Q
  16.         interim_shape = (batch_size, -1, self.n_heads, self.d_head)
  17.         # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
  18.         q = self.q_proj(x)
  19.         # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
  20.         k = self.k_proj(y)
  21.         # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
  22.         v = self.v_proj(y)
  23.         # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
  24.         q = q.view(interim_shape).transpose(1, 2)
  25.         # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
  26.         k = k.view(interim_shape).transpose(1, 2)
  27.         # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
  28.         v = v.view(interim_shape).transpose(1, 2)
  29.         # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
  30.         weight = q @ k.transpose(-1, -2)
  31.         # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
  32.         weight /= math.sqrt(self.d_head)
  33.         # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
  34.         weight = F.softmax(weight, dim=-1)
  35.         # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
  36.         output = weight @ v
  37.         # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
  38.         output = output.transpose(1, 2).contiguous()
  39.         # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
  40.         output = output.view(input_shape)
  41.         # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
  42.         output = self.out_proj(output)
  43.         # (Batch_Size, Seq_Len_Q, Dim_Q)
  44.         return output
复制代码
Overall it looks much like self-attention, except that here the output of the previous sub-block (the latent features, already fused with the time embedding and self-attended) only provides the query matrix Q, while the context obtained earlier from the prompt provides the keys K and values V. In effect, it computes how the current latent, at the current timestep, relates to what the prompt describes.
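In formula form, the only difference from self-attention is where Q, K and V come from: Q is projected from the latent tokens, while K and V are projected from the prompt context:

$$\mathrm{CrossAttention}(x, c) = \mathrm{softmax}\!\Big(\frac{QK^\top}{\sqrt{d_{\mathrm{head}}}}\Big)V, \qquad Q = xW_Q,\ K = cW_K,\ V = cW_V$$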
That completes the UNET encoders. The bottleneck in the middle is very simple: fuse the time information, run attention, then fuse the time information again:
  1.         self.bottleneck = SwitchSequential(
  2.             UNET_ResidualBlock(1280, 1280),
  3.             UNET_AttentionBlock(8, 160),
  4.             UNET_ResidualBlock(1280, 1280),
  5.         )
复制代码
After that we move into the UNET decoders:
  1. self.decoders = nn.ModuleList([
  2.             # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
  3.             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
  4.             # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
  5.             SwitchSequential(UNET_ResidualBlock(2560, 1280)),
  6.             # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
  7.             SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
  8.             # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
  9.             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
  10.             # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
  11.             SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
  12.             # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
  13.             SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
  14.             # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
  15.             SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
  16.             # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
  17.             SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
  18.             # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
  19.             SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
  20.             # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
  21.             SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
  22.             # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
  23.             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
  24.             # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
  25.             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
  26.         ])
复制代码
The modules mirror the encoder side. The input channel counts are larger (for example 2560 = 1280 from the previous decoder stage + 1280 from the matching encoder skip connection) because of the concatenation in the forward pass. The only new piece is the Upsample module:
  1. class Upsample(nn.Module):
  2.     def __init__(self, channels):
  3.         super().__init__()
  4.         self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
  5.    
  6.     def forward(self, x):
  7.         # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
  8.         x = F.interpolate(x, scale_factor=2, mode='nearest')
  9.         return self.conv(x)
复制代码
Its implementation is also very simple: F.interpolate with nearest-neighbour mode doubles the spatial size of the feature map, and a convolution then refines the upsampled features.
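A quick shape check of the nearest-neighbour upsampling (toy tensor, hypothetical sizes):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 640, 8, 8)
    y = F.interpolate(x, scale_factor=2, mode='nearest')  # doubles H and W, keeps channels
    print(y.shape)                                        # torch.Size([1, 640, 16, 16])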
At this point we have completed the following call:
  1.         output = self.unet(latent, context, time)
复制代码
The output still has 320 channels rather than the 4 latent channels we need, so a final layer adjusts the output dimensions:
  1.         # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
  2.         output = self.final(output)
复制代码
where final is:
  1. self.final = UNET_OutputLayer(320, 4)
复制代码
  1. class UNET_OutputLayer(nn.Module):
  2.     def __init__(self, in_channels, out_channels):
  3.         super().__init__()
  4.         self.groupnorm = nn.GroupNorm(32, in_channels)
  5.         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
  6.    
  7.     def forward(self, x):
  8.         # x: (Batch_Size, 320, Height / 8, Width / 8)
  9.         # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
  10.         x = self.groupnorm(x)
  11.         # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
  12.         x = F.silu(x)
  13.         # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
  14.         x = self.conv(x)
  15.         # (Batch_Size, 4, Height / 8, Width / 8)
  16.         return x
复制代码
It simply applies group normalization and SiLU, then a single convolution that changes the channel count from 320 to 4.
With that, the noise prediction for the current step is formally complete:
  1. model_output = diffusion(model_input, context, time_embedding)
复制代码
However, when do_cfg is enabled, the input batch effectively has size 2, one half conditioned on the prompt and one half unconditioned, and the two predictions must be combined into a single noise estimate. Hence the following handling (the guidance formula is written out after the code):
  1.             if do_cfg:
  2.                 output_cond, output_uncond = model_output.chunk(2)
  3.                 model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
复制代码
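This is the classifier-free guidance combination; with guidance scale s = cfg_scale, the final noise estimate is

$$\hat\epsilon = \epsilon_\theta(x_t, \varnothing) + s\,\big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\big)$$

which is exactly cfg_scale * (output_cond - output_uncond) + output_uncond. With s = 1 this reduces to the conditional prediction; larger values push the result further toward the prompt and away from the unconditional prediction.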
At this point model_output is the actual prediction of the noise present in the current latent.
  Next, this predicted noise is removed from the latent:
  1. latents = sampler.step(timestep, latents, model_output)
复制代码
Let's look at DDPMSampler's step:
  1.     def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
  2.         t = timestep
  3.         prev_t = self._get_previous_timestep(t)
  4.         # 1. compute alphas, betas
  5.         alpha_prod_t = self.alphas_cumprod[t]
  6.         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
  7.         beta_prod_t = 1 - alpha_prod_t
  8.         beta_prod_t_prev = 1 - alpha_prod_t_prev
  9.         current_alpha_t = alpha_prod_t / alpha_prod_t_prev
  10.         current_beta_t = 1 - current_alpha_t
  11.         # 2. compute predicted original sample from predicted noise also called
  12.         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
  13.         pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
  14.         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
  15.         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
  16.         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
  17.         current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
  18.         # 5. Compute predicted previous sample µ_t
  19.         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
  20.         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
  21.         # 6. Add noise
  22.         variance = 0
  23.         if t > 0:
  24.             noise = torch.randn(model_output.shape, generator=self.generator, device=model_output.device, dtype=model_output.dtype)
  25.             variance = (self._get_variance(t) ** 0.5) * noise
  26.         pred_prev_sample = pred_prev_sample + variance
  27.         return pred_prev_sample
复制代码
Let's first understand the denoising formula. The DDPM paper gives two equivalent ways of writing one reverse (denoising) step.
  The first (eq. (11)) samples x_{t-1} directly from the predicted noise:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big) + \sigma_t z$$

  The second uses the mean of the posterior q(x_{t-1} | x_t, x_0) (eq. (7)):

$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t$$

  Here x_0 is not the real original image but the model's estimate of it, obtained by inverting the closed-form forward equation (eq. (15)):

$$\hat x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$$

  So the denoised latent can be read as a linear combination of the current latent and the predicted noise-free latent.
  We implement the second formulation here.
  First we need the alpha-related quantities at the previous timestep t-1:
  1.         t = timestep
  2.         prev_t = self._get_previous_timestep(t)
复制代码
  1. def _get_previous_timestep(self, timestep: int) -> int:
  2.     prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
  3.     return prev_t
复制代码
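A worked example of this stride, assuming the usual defaults of 1000 training timesteps and 50 inference steps (plain ints are used here for illustration; the real values live on the sampler):

    num_train_timesteps = 1000
    num_inference_steps = 50
    stride = num_train_timesteps // num_inference_steps  # 20
    t = 999
    prev_t = t - stride                                   # 979
    # The inference schedule is then 999, 979, ..., 39, 19; at the last step t = 19 we get
    # prev_t = -1, which is caught by the `if prev_t >= 0 else self.one` fallback in step().
    print(stride, prev_t)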
  1.         # 1. compute alphas, betas
  2.         alpha_prod_t = self.alphas_cumprod[t]
  3.         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
  4.         beta_prod_t = 1 - alpha_prod_t
  5.         beta_prod_t_prev = 1 - alpha_prod_t_prev
  6.         current_alpha_t = alpha_prod_t / alpha_prod_t_prev
  7.         current_beta_t = 1 - current_alpha_t
复制代码
Then we need to compute the predicted noise-free latent; we simply plug into the formula for x_0:
  1.         pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
复制代码
Next, since the posterior mean is a linear combination of the predicted original and the current latent, the two coefficients must be computed:
  1.         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
  2.         current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
复制代码
  1.         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
复制代码
Note that this is only the mean of the denoised latent; the variance term still has to be added, so:
  1.         # 6. Add noise
  2.         variance = 0
  3.         if t > 0:
  4.             noise = torch.randn(model_output.shape, generator=self.generator, device=model_output.device, dtype=model_output.dtype)
  5.             variance = (self._get_variance(t) ** 0.5) * noise
  6.         pred_prev_sample = pred_prev_sample + variance
  7.         return pred_prev_sample
复制代码
where _get_variance is:
  1.     def _get_variance(self, timestep: int) -> torch.Tensor:
  2.         prev_t = self._get_previous_timestep(timestep)
  3.         alpha_prod_t = self.alphas_cumprod[timestep]
  4.         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
  5.         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
  6.         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
  7.         # and sample from it to get previous sample
  8.         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
  9.         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
  10.         # we always take the log of variance, so clamp it to ensure it's not 0
  11.         variance = torch.clamp(variance, min=1e-20)
  12.         return variance
复制代码
It too simply follows the formula, and clamp bounds the result below by 1e-20, because the log of the variance is taken elsewhere and must never hit zero.
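For reference, the quantity it returns is the posterior variance from eqs. (6)/(7) of the DDPM paper:

$$\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$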
So at this point we have pred_prev_sample, and the loop
  1.         timesteps = tqdm(sampler.timesteps)
  2.         for i, timestep in enumerate(timesteps):
  3.             ...
  4.             latents = sampler.step(timestep, latents, model_output)
复制代码
eventually yields the final predicted noise-free latent. That concludes the diffusion part, and we move the model to idle_device:
  1. to_idle(diffusion)
复制代码
The only step left is the
  VAE_Decoder:

  which maps the latent back into an image:
  1.         decoder = models["decoder"]
  2.         decoder.to(device)
  3.         # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
  4.         images = decoder(latents)
  5.         to_idle(decoder)
复制代码
Let's look at the decoder's forward:
  1.     def forward(self, x):
  2.         # x: (Batch_Size, 4, Height / 8, Width / 8)
  3.         # Remove the scaling added by the Encoder.
  4.         x /= 0.18215
  5.         for module in self:
  6.             x = module(x)
  7.         # (Batch_Size, 3, Height, Width)
  8.         return x
复制代码
First, because the encoder scales the latent down with x *= 0.18215 for engineering reasons, the decoder undoes this with x /= 0.18215 before the latent enters the decoder modules; after that it simply iterates over the modules:
  1. class VAE_Decoder(nn.Sequential):
  2.     def __init__(self):
  3.         super().__init__(
  4.             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
  5.             nn.Conv2d(4, 4, kernel_size=1, padding=0),
  6.             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
  7.             nn.Conv2d(4, 512, kernel_size=3, padding=1),
  8.             VAE_ResidualBlock(512, 512),
  9.             VAE_AttentionBlock(512),
  10.             VAE_ResidualBlock(512, 512),
  11.             VAE_ResidualBlock(512, 512),
  12.             VAE_ResidualBlock(512, 512),
  13.             VAE_ResidualBlock(512, 512),
  14.             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
  15.             nn.Upsample(scale_factor=2),
  16.             nn.Conv2d(512, 512, kernel_size=3, padding=1),
  17.             VAE_ResidualBlock(512, 512),
  18.             VAE_ResidualBlock(512, 512),
  19.             VAE_ResidualBlock(512, 512),
  20.             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
  21.             nn.Upsample(scale_factor=2),
  22.             nn.Conv2d(512, 512, kernel_size=3, padding=1),
  23.             VAE_ResidualBlock(512, 256),
  24.             VAE_ResidualBlock(256, 256),
  25.             VAE_ResidualBlock(256, 256),
  26.             # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height, Width)
  27.             nn.Upsample(scale_factor=2),
  28.             nn.Conv2d(256, 256, kernel_size=3, padding=1),
  29.             VAE_ResidualBlock(256, 128),
  30.             VAE_ResidualBlock(128, 128),
  31.             VAE_ResidualBlock(128, 128),
  32.             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
  33.             nn.GroupNorm(32, 128),
  34.             nn.SiLU(),
  35.             # (Batch_Size, 128, Height, Width) -> (Batch_Size, 3, Height, Width)
  36.             nn.Conv2d(128, 3, kernel_size=3, padding=1),
  37.         )
复制代码
All of these modules were introduced earlier, and the decoder's structure mirrors the encoder's, so we will not repeat the description here.
Finally, the images produced by the decoder are converted back into the standard image storage format:
  1.         images = rescale(images, (-1, 1), (0, 255), clamp=True)
  2.         # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
  3.         images = images.permute(0, 2, 3, 1)
  4.         images = images.to("cpu", torch.uint8).numpy()
  5.         return images[0]
复制代码
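For completeness, the rescale helper used above simply remaps values linearly from one range to another; a sketch consistent with the reference repository (verify against the actual code):

    def rescale(x, old_range, new_range, clamp=False):
        # Linearly remap x from old_range to new_range, optionally clamping to the new bounds.
        old_min, old_max = old_range
        new_min, new_max = new_range
        x -= old_min
        x *= (new_max - new_min) / (old_max - old_min)
        x += new_min
        if clamp:
            x = x.clamp(new_min, new_max)
        return x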
And with that, we finally have the image we wanted!
  Last of all, here is a demo that runs the whole pipeline end to end:
import model_loader
import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer
import torch

DEVICE = "cpu"
ALLOW_CUDA = False
ALLOW_MPS = False

if torch.cuda.is_available() and ALLOW_CUDA:
    DEVICE = "cuda"
elif (torch.has_mps or torch.backends.mps.is_available()) and ALLOW_MPS:
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

tokenizer = CLIPTokenizer("../data/tokenizer_vocab.json", merges_file="../data/tokenizer_merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

## TEXT TO IMAGE
prompt = "A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 8k resolution"
uncond_prompt = ""  # Also known as negative prompt
do_cfg = True
cfg_scale = 8  # min: 1, max: 14

## IMAGE TO IMAGE
input_image = None
# Uncomment the line below to enable image to image
image_path = "../images/yibo.jpg"
# input_image = Image.open(image_path)
# Higher values mean more noise is added to the input image, so the result will be further from the input image.
# Lower values mean less noise is added to the input image, so the output will be closer to the input image.
strength = 0.9

## SAMPLER
sampler = "ddpm"
num_inference_steps = 50
seed = 42

output_image = pipeline.generate(
    prompt=prompt,
    uncond_prompt=uncond_prompt,
    input_image=input_image,
    strength=strength,
    do_cfg=do_cfg,
    cfg_scale=cfg_scale,
    sampler_name=sampler,
    n_inference_steps=num_inference_steps,
    seed=seed,
    models=models,
    device=DEVICE,
    idle_device="cpu",
    tokenizer=tokenizer,
)

# Display the output image.
Image.fromarray(output_image)
复制代码
In the end, this produces the image shown on the article's cover:
  

And that brings the entire walkthrough to a close, hooray!