stable diffusion中的UNet2DConditionModel代码解读

打印 上一主题 下一主题

主题 1659|帖子 1659|积分 4977

UNet2DConditionModel总体布局

图片来自于 https://zhuanlan.zhihu.com/p/635204519
stable diffusion 运行unet部分的代码。
  1. noise_pred = self.unet(
  2.     sample=latent_model_input,  #(2,4,64,64) 生成的latent
  3.     timestep=t,  #时刻t
  4.     encoder_hidden_states=prompt_embeds, #(2,77,768) #输入的prompt和negative prompt 生成的embedding
  5.     timestep_cond=timestep_cond,#默认空
  6.     cross_attention_kwargs=self.cross_attention_kwargs, #默认空
  7.     added_cond_kwargs=added_cond_kwargs, #默认空
  8.     return_dict=False,
  9. )[0]
复制代码
1.time

get_time_embed使用了sinusoidal timestep embeddings,
time_embedding 使用了两个线性层和激活层举行映射,将320转换到1280。
如果还有class_labels,added_cond_kwargs等参数,也转换为embedding,并且相加。
  1. t_emb = self.get_time_embed(sample=sample, timestep=timestep)  #(2,320)
  2. emb = self.time_embedding(t_emb, timestep_cond)  #(2,1280)
复制代码
2.pre-process

卷积转换,输入latent从(2,4,64,64) 到(2,320,64,64)
  1. sample = self.conv_in(sample)  #(2,320,64,64)
  2. self.conv_in = nn.Conv2d(
  3.             in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
  4.         )
复制代码
3.down

down_block 由三个CrossAttnDownBlock2D和一个DownBlock2D组成。输入包罗:


  • hidden_states:latent
  • temb:时刻t的embdedding
  • encoder_hidden_states:prompt和negative prompt的embedding
网络布局
  1. CrossAttnDownBlock2D(
  2.     ResnetBlock2D()
  3.     Transformer2DModel()
  4.     ResnetBlock2D()
  5.     Transformer2DModel()   
  6.     Downsample2D()  #(2,320,32,32)
  7. )
  8. CrossAttnDownBlock2D(
  9.     ResnetBlock2D()
  10.     Transformer2DModel()
  11.     ResnetBlock2D()
  12.     Transformer2DModel()   
  13.     Downsample2D()  #(2,640,16,16)
  14. )
  15. CrossAttnDownBlock2D(
  16.     ResnetBlock2D()
  17.     Transformer2DModel()
  18.     ResnetBlock2D()
  19.     Transformer2DModel()   
  20.     Downsample2D()  #(2,1280,8,8)
  21. )
  22. DownBlock2D(
  23.     ResnetBlock2D()
  24.     ResnetBlock2D()  #(2,1280,8,8)
  25. )
复制代码
4.mid

UNetMidBlock2DCrossAttn 包含 resnet,attn,resnet三个模块,输入输出维度稳固。输入包罗:


  • hidden_states:latent
  • temb,时刻t的embdedding
  • encoder_hidden_states:prompt和negative prompt的embedding
  1. UNetMidBlock2DCrossAttn(
  2.     ResnetBlock2D()
  3.     Transformer2DModel()
  4.     ResnetBlock2D()
  5. )
复制代码
5.up

up由一个UpBlock2D和三个CrossAttnUpBlock2D组成,输入包罗:


  • hidden_states:latent
  • temb: 时刻t的embdedding
  • encoder_hidden_states:prompt和negative prompt的embedding
  • res_hidden_states_tupleL:下采样时每个block的效果,skip connection。
  1. UpBlock2D(
  2.     ResnetBlock2D()
  3.     ResnetBlock2D()
  4.     ResnetBlock2D()
  5.     Upsample2D()  #(2,1280,16,16)
  6. )
  7. CrossAttnUpBlock2D(
  8.     ResnetBlock2D()
  9.     Transformer2DModel()
  10.     ResnetBlock2D()
  11.     Transformer2DModel()   
  12.     ResnetBlock2D()
  13.     Transformer2DModel()
  14.     Downsample2D()  #(2,1280,32,32)
  15. )
  16. CrossAttnUpBlock2D(
  17.     ResnetBlock2D()
  18.     Transformer2DModel()
  19.     ResnetBlock2D()
  20.     Transformer2DModel()   
  21.     ResnetBlock2D()
  22.     Transformer2DModel()
  23.     Downsample2D()  #(2,640,64,64)
  24. )
  25. CrossAttnUpBlock2D(
  26.     ResnetBlock2D() #(2,320,64,64)
  27.     Transformer2DModel()
  28.     ResnetBlock2D()
  29.     Transformer2DModel()   
  30.     ResnetBlock2D()
  31.     Transformer2DModel()
  32. )  
复制代码
6.post-process

卷积变动通道数,得到最终效果
  1. if self.conv_norm_out:
  2.      sample = self.conv_norm_out(sample)
  3.      sample = self.conv_act(sample)
  4. sample = self.conv_out(sample) #(2,4,64,64)
复制代码
时刻t,种别class等参数作用在resnet部分,都是和输入直接相加。
由prompt,negative prompt 盘算得到的encoder_hidden_states,作用在attention部分,作为key和value,参与盘算。
ResnetBlock2D

x在标准化、激活、卷积之后,和temb相加,再次标准化、激活、卷积之后作为残差,与x相加。
  1. hidden_states = input_tensor
  2. hidden_states = self.norm1(hidden_states)
  3. hidden_states = self.nonlinearity(hidden_states) #激活函数
  4. hidden_states = self.conv1(hidden_states)
  5. if self.time_emb_proj is not None:
  6.     if not self.skip_time_act:
  7.         temb = self.nonlinearity(temb)
  8.     temb = self.time_emb_proj(temb)[:, :, None, None] #(2,320,1,1)
  9. if self.time_embedding_norm == "default":
  10.     if temb is not None:
  11.         hidden_states = hidden_states + temb  #与temb相加
  12.     hidden_states = self.norm2(hidden_states)            
  13. hidden_states = self.nonlinearity(hidden_states)
  14. hidden_states = self.dropout(hidden_states)
  15. hidden_states = self.conv2(hidden_states)
  16. output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
  17. return output_tensor
复制代码
Transformer2DModel attentions部分

每个attention 包罗 Self-Attention 和Cross-Attention两部分。
  1. #Self-Attention ,encoder_hidden_states=None
  2. attn_output = self.attn1(
  3.     norm_hidden_states,
  4.     encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
  5.     attention_mask=attention_mask,
  6.     **cross_attention_kwargs,
  7.        )
  8. #Cross-Attention,encoder_hidden_states由prompt计算得来,在这里和latent交互。
  9. attn_output = self.attn2(
  10.     norm_hidden_states,
  11.     encoder_hidden_states=encoder_hidden_states,
  12.     attention_mask=encoder_attention_mask,
  13.     **cross_attention_kwargs,
  14. )
  15. #query由norm_hidden_states计算而来,
  16. #key、value由encoder_hidden_states计算而来。
  17. query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  #(2,8,4096,40)
  18. key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  #(2,8,77,40)
  19. value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) #(2,8,77,40)
  20. hidden_states = F.scaled_dot_product_attention(
  21.     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  #(2,8,4096,40)
  22. )
复制代码
参考:stable diffusion 中使用的 UNet 2D Condition Model 布局解析(diffusers库)

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

九天猎人

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表