AIGC笔记--Stable Diffusion源码剖析之UNetModel

打印 上一主题 下一主题

主题 386|帖子 386|积分 1158

1--媒介

           以论文《High-Resolution Image Synthesis with Latent Diffusion Models》  开源的项目为例,剖析Stable Diffusion经典组成部分,巩固学习加深印象。
  2--UNetModel

一个可以debug的小demo:SD_UNet
           以文生图为例,剖析UNetModel核心组成模块。
  2-1--Forward统辖

   提供的文生图Demo中,实际传入的参数只有x、timesteps和context三个,此中:
          x 表现随机初始化的噪声Tensor(shape: [B*2, 4, 64, 64],*2表现利用Classifier-Free Diffusion Guidance)。
          timesteps 表现去噪过程中每一轮传入的timestep(shape: [B*2])。
          context表现颠末CLIP编码后对应的文本Prompt(shape: [B*2, 77, 768])。
  1.     def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
  2.         """
  3.         Apply the model to an input batch.
  4.         :param x: an [N x C x ...] Tensor of inputs.
  5.         :param timesteps: a 1-D batch of timesteps.
  6.         :param context: conditioning plugged in via crossattn
  7.         :param y: an [N] Tensor of labels, if class-conditional.
  8.         :return: an [N x C x ...] Tensor of outputs.
  9.         """
  10.         assert (y is not None) == (
  11.             self.num_classes is not None
  12.         ), "must specify y if and only if the model is class-conditional"
  13.         hs = []
  14.         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # Create sinusoidal timestep embeddings.
  15.         emb = self.time_embed(t_emb) # MLP
  16.         if self.num_classes is not None:
  17.             assert y.shape == (x.shape[0],)
  18.             emb = emb + self.label_emb(y)
  19.         h = x.type(self.dtype)
  20.         for module in self.input_blocks:
  21.             h = module(h, emb, context)
  22.             hs.append(h)
  23.         h = self.middle_block(h, emb, context)
  24.         for module in self.output_blocks:
  25.             h = th.cat([h, hs.pop()], dim=1)
  26.             h = module(h, emb, context)
  27.         h = h.type(x.dtype)
  28.         if self.predict_codebook_ids:
  29.             return self.id_predictor(h)
  30.         else:
  31.             return self.out(h)
复制代码
2-2--timestep embedding生成

           利用函数 timestep_embedding() 和 self.time_embed() 对传入的timestep进行位置编码,生成sinusoidal timestep embeddings。
          此中 timestep_embedding() 函数定义如下,而self.time_embed()是一个MLP函数。
  1. def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
  2.     """
  3.     Create sinusoidal timestep embeddings.
  4.     :param timesteps: a 1-D Tensor of N indices, one per batch element.
  5.                       These may be fractional.
  6.     :param dim: the dimension of the output.
  7.     :param max_period: controls the minimum frequency of the embeddings.
  8.     :return: an [N x dim] Tensor of positional embeddings.
  9.     """
  10.     if not repeat_only:
  11.         half = dim // 2
  12.         freqs = torch.exp(
  13.             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
  14.         ).to(device=timesteps.device)
  15.         args = timesteps[:, None].float() * freqs[None]
  16.         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
  17.         if dim % 2:
  18.             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
  19.     else:
  20.         embedding = repeat(timesteps, 'b -> b d', d=dim)
  21.     return embedding
复制代码
  1. self.time_embed = nn.Sequential(
  2.     linear(model_channels, time_embed_dim),
  3.     nn.SiLU(),
  4.     linear(time_embed_dim, time_embed_dim),
  5. )
复制代码
2-3--self.input_blocks下采样

           在 Forward() 中,利用 self.input_blocks 将输入噪声进行分辨率下采样,颠末下采样详细维度变化为:[B*2, 4, 64, 64] > [B*2, 1280, 8, 8];
          下采样模块共有12个 module,其组成如下:
  1. ModuleList(
  2.   (0): TimestepEmbedSequential(
  3.     (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  4.   )
  5.   (1-2): 2 x TimestepEmbedSequential(
  6.     (0): ResBlock(
  7.       (in_layers): Sequential(
  8.         (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  9.         (1): SiLU()
  10.         (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  11.       )
  12.       (h_upd): Identity()
  13.       (x_upd): Identity()
  14.       (emb_layers): Sequential(
  15.         (0): SiLU()
  16.         (1): Linear(in_features=1280, out_features=320, bias=True)
  17.       )
  18.       (out_layers): Sequential(
  19.         (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  20.         (1): SiLU()
  21.         (2): Dropout(p=0, inplace=False)
  22.         (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  23.       )
  24.       (skip_connection): Identity()
  25.     )
  26.     (1): SpatialTransformer(
  27.       (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  28.       (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  29.       (transformer_blocks): ModuleList(
  30.         (0): BasicTransformerBlock(
  31.           (attn1): CrossAttention(
  32.             (to_q): Linear(in_features=320, out_features=320, bias=False)
  33.             (to_k): Linear(in_features=320, out_features=320, bias=False)
  34.             (to_v): Linear(in_features=320, out_features=320, bias=False)
  35.             (to_out): Sequential(
  36.               (0): Linear(in_features=320, out_features=320, bias=True)
  37.               (1): Dropout(p=0.0, inplace=False)
  38.             )
  39.           )
  40.           (ff): FeedForward(
  41.             (net): Sequential(
  42.               (0): GEGLU(
  43.                 (proj): Linear(in_features=320, out_features=2560, bias=True)
  44.               )
  45.               (1): Dropout(p=0.0, inplace=False)
  46.               (2): Linear(in_features=1280, out_features=320, bias=True)
  47.             )
  48.           )
  49.           (attn2): CrossAttention(
  50.             (to_q): Linear(in_features=320, out_features=320, bias=False)
  51.             (to_k): Linear(in_features=768, out_features=320, bias=False)
  52.             (to_v): Linear(in_features=768, out_features=320, bias=False)
  53.             (to_out): Sequential(
  54.               (0): Linear(in_features=320, out_features=320, bias=True)
  55.               (1): Dropout(p=0.0, inplace=False)
  56.             )
  57.           )
  58.           (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  59.           (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  60.           (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  61.         )
  62.       )
  63.       (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  64.     )
  65.   )
  66.   (3): TimestepEmbedSequential(
  67.     (0): Downsample(
  68.       (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  69.     )
  70.   )
  71.   (4): TimestepEmbedSequential(
  72.     (0): ResBlock(
  73.       (in_layers): Sequential(
  74.         (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  75.         (1): SiLU()
  76.         (2): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  77.       )
  78.       (h_upd): Identity()
  79.       (x_upd): Identity()
  80.       (emb_layers): Sequential(
  81.         (0): SiLU()
  82.         (1): Linear(in_features=1280, out_features=640, bias=True)
  83.       )
  84.       (out_layers): Sequential(
  85.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  86.         (1): SiLU()
  87.         (2): Dropout(p=0, inplace=False)
  88.         (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  89.       )
  90.       (skip_connection): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
  91.     )
  92.     (1): SpatialTransformer(
  93.       (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
  94.       (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  95.       (transformer_blocks): ModuleList(
  96.         (0): BasicTransformerBlock(
  97.           (attn1): CrossAttention(
  98.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  99.             (to_k): Linear(in_features=640, out_features=640, bias=False)
  100.             (to_v): Linear(in_features=640, out_features=640, bias=False)
  101.             (to_out): Sequential(
  102.               (0): Linear(in_features=640, out_features=640, bias=True)
  103.               (1): Dropout(p=0.0, inplace=False)
  104.             )
  105.           )
  106.           (ff): FeedForward(
  107.             (net): Sequential(
  108.               (0): GEGLU(
  109.                 (proj): Linear(in_features=640, out_features=5120, bias=True)
  110.               )
  111.               (1): Dropout(p=0.0, inplace=False)
  112.               (2): Linear(in_features=2560, out_features=640, bias=True)
  113.             )
  114.           )
  115.           (attn2): CrossAttention(
  116.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  117.             (to_k): Linear(in_features=768, out_features=640, bias=False)
  118.             (to_v): Linear(in_features=768, out_features=640, bias=False)
  119.             (to_out): Sequential(
  120.               (0): Linear(in_features=640, out_features=640, bias=True)
  121.               (1): Dropout(p=0.0, inplace=False)
  122.             )
  123.           )
  124.           (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  125.           (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  126.           (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  127.         )
  128.       )
  129.       (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  130.     )
  131.   )
  132.   (5): TimestepEmbedSequential(
  133.     (0): ResBlock(
  134.       (in_layers): Sequential(
  135.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  136.         (1): SiLU()
  137.         (2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  138.       )
  139.       (h_upd): Identity()
  140.       (x_upd): Identity()
  141.       (emb_layers): Sequential(
  142.         (0): SiLU()
  143.         (1): Linear(in_features=1280, out_features=640, bias=True)
  144.       )
  145.       (out_layers): Sequential(
  146.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  147.         (1): SiLU()
  148.         (2): Dropout(p=0, inplace=False)
  149.         (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  150.       )
  151.       (skip_connection): Identity()
  152.     )
  153.     (1): SpatialTransformer(
  154.       (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
  155.       (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  156.       (transformer_blocks): ModuleList(
  157.         (0): BasicTransformerBlock(
  158.           (attn1): CrossAttention(
  159.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  160.             (to_k): Linear(in_features=640, out_features=640, bias=False)
  161.             (to_v): Linear(in_features=640, out_features=640, bias=False)
  162.             (to_out): Sequential(
  163.               (0): Linear(in_features=640, out_features=640, bias=True)
  164.               (1): Dropout(p=0.0, inplace=False)
  165.             )
  166.           )
  167.           (ff): FeedForward(
  168.             (net): Sequential(
  169.               (0): GEGLU(
  170.                 (proj): Linear(in_features=640, out_features=5120, bias=True)
  171.               )
  172.               (1): Dropout(p=0.0, inplace=False)
  173.               (2): Linear(in_features=2560, out_features=640, bias=True)
  174.             )
  175.           )
  176.           (attn2): CrossAttention(
  177.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  178.             (to_k): Linear(in_features=768, out_features=640, bias=False)
  179.             (to_v): Linear(in_features=768, out_features=640, bias=False)
  180.             (to_out): Sequential(
  181.               (0): Linear(in_features=640, out_features=640, bias=True)
  182.               (1): Dropout(p=0.0, inplace=False)
  183.             )
  184.           )
  185.           (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  186.           (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  187.           (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  188.         )
  189.       )
  190.       (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  191.     )
  192.   )
  193.   (6): TimestepEmbedSequential(
  194.     (0): Downsample(
  195.       (op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  196.     )
  197.   )
  198.   (7): TimestepEmbedSequential(
  199.     (0): ResBlock(
  200.       (in_layers): Sequential(
  201.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  202.         (1): SiLU()
  203.         (2): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  204.       )
  205.       (h_upd): Identity()
  206.       (x_upd): Identity()
  207.       (emb_layers): Sequential(
  208.         (0): SiLU()
  209.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  210.       )
  211.       (out_layers): Sequential(
  212.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  213.         (1): SiLU()
  214.         (2): Dropout(p=0, inplace=False)
  215.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  216.       )
  217.       (skip_connection): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
  218.     )
  219.     (1): SpatialTransformer(
  220.       (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
  221.       (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  222.       (transformer_blocks): ModuleList(
  223.         (0): BasicTransformerBlock(
  224.           (attn1): CrossAttention(
  225.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  226.             (to_k): Linear(in_features=1280, out_features=1280, bias=False)
  227.             (to_v): Linear(in_features=1280, out_features=1280, bias=False)
  228.             (to_out): Sequential(
  229.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  230.               (1): Dropout(p=0.0, inplace=False)
  231.             )
  232.           )
  233.           (ff): FeedForward(
  234.             (net): Sequential(
  235.               (0): GEGLU(
  236.                 (proj): Linear(in_features=1280, out_features=10240, bias=True)
  237.               )
  238.               (1): Dropout(p=0.0, inplace=False)
  239.               (2): Linear(in_features=5120, out_features=1280, bias=True)
  240.             )
  241.           )
  242.           (attn2): CrossAttention(
  243.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  244.             (to_k): Linear(in_features=768, out_features=1280, bias=False)
  245.             (to_v): Linear(in_features=768, out_features=1280, bias=False)
  246.             (to_out): Sequential(
  247.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  248.               (1): Dropout(p=0.0, inplace=False)
  249.             )
  250.           )
  251.           (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  252.           (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  253.           (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  254.         )
  255.       )
  256.       (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  257.     )
  258.   )
  259.   (8): TimestepEmbedSequential(
  260.     (0): ResBlock(
  261.       (in_layers): Sequential(
  262.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  263.         (1): SiLU()
  264.         (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  265.       )
  266.       (h_upd): Identity()
  267.       (x_upd): Identity()
  268.       (emb_layers): Sequential(
  269.         (0): SiLU()
  270.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  271.       )
  272.       (out_layers): Sequential(
  273.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  274.         (1): SiLU()
  275.         (2): Dropout(p=0, inplace=False)
  276.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  277.       )
  278.       (skip_connection): Identity()
  279.     )
  280.     (1): SpatialTransformer(
  281.       (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
  282.       (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  283.       (transformer_blocks): ModuleList(
  284.         (0): BasicTransformerBlock(
  285.           (attn1): CrossAttention(
  286.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  287.             (to_k): Linear(in_features=1280, out_features=1280, bias=False)
  288.             (to_v): Linear(in_features=1280, out_features=1280, bias=False)
  289.             (to_out): Sequential(
  290.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  291.               (1): Dropout(p=0.0, inplace=False)
  292.             )
  293.           )
  294.           (ff): FeedForward(
  295.             (net): Sequential(
  296.               (0): GEGLU(
  297.                 (proj): Linear(in_features=1280, out_features=10240, bias=True)
  298.               )
  299.               (1): Dropout(p=0.0, inplace=False)
  300.               (2): Linear(in_features=5120, out_features=1280, bias=True)
  301.             )
  302.           )
  303.           (attn2): CrossAttention(
  304.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  305.             (to_k): Linear(in_features=768, out_features=1280, bias=False)
  306.             (to_v): Linear(in_features=768, out_features=1280, bias=False)
  307.             (to_out): Sequential(
  308.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  309.               (1): Dropout(p=0.0, inplace=False)
  310.             )
  311.           )
  312.           (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  313.           (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  314.           (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  315.         )
  316.       )
  317.       (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  318.     )
  319.   )
  320.   (9): TimestepEmbedSequential(
  321.     (0): Downsample(
  322.       (op): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  323.     )
  324.   )
  325.   (10-11): 2 x TimestepEmbedSequential(
  326.     (0): ResBlock(
  327.       (in_layers): Sequential(
  328.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  329.         (1): SiLU()
  330.         (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  331.       )
  332.       (h_upd): Identity()
  333.       (x_upd): Identity()
  334.       (emb_layers): Sequential(
  335.         (0): SiLU()
  336.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  337.       )
  338.       (out_layers): Sequential(
  339.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  340.         (1): SiLU()
  341.         (2): Dropout(p=0, inplace=False)
  342.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  343.       )
  344.       (skip_connection): Identity()
  345.     )
  346.   )
  347. )
复制代码
          12个 module 都利用了 TimestepEmbedSequential 类进行封装,根据不同的网络层,将输入噪声x与timestep embedding和prompt context进行运算。
  1. class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
  2.     """
  3.     A sequential module that passes timestep embeddings to the children that
  4.     support it as an extra input.
  5.     """
  6.     def forward(self, x, emb, context=None):
  7.         for layer in self:
  8.             if isinstance(layer, TimestepBlock):
  9.                 x = layer(x, emb)
  10.             elif isinstance(layer, SpatialTransformer):
  11.                 x = layer(x, context)
  12.             else:
  13.                 x = layer(x)
  14.         return x
复制代码
2-3-1--Module0

    Module 0 是一个2D卷积层,重要对输入噪声进行特性提取;
  1. # init 初始化
  2. self.input_blocks = nn.ModuleList(
  3.     [
  4.         TimestepEmbedSequential(
  5.             conv_nd(dims, in_channels, model_channels, 3, padding=1)
  6.         )
  7.     ]
  8. )
  9. # 打印 self.input_blocks[0]
  10. TimestepEmbedSequential(
  11.   (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  12. )
复制代码
2-3-2--Module1和Module2

           Module1和Module2的结构相同,都由一个ResBlock和一个SpatialTransformer组成;
  1. # init 初始化
  2. for _ in range(num_res_blocks):
  3.                 layers = [
  4.                     ResBlock(
  5.                         ch,
  6.                         time_embed_dim,
  7.                         dropout,
  8.                         out_channels=mult * model_channels,
  9.                         dims=dims,
  10.                         use_checkpoint=use_checkpoint,
  11.                         use_scale_shift_norm=use_scale_shift_norm,
  12.                     )
  13.                 ]
  14.                 ch = mult * model_channels
  15.                 if ds in attention_resolutions:
  16.                     if num_head_channels == -1:
  17.                         dim_head = ch // num_heads
  18.                     else:
  19.                         num_heads = ch // num_head_channels
  20.                         dim_head = num_head_channels
  21.                     if legacy:
  22.                         #num_heads = 1
  23.                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
  24.                     layers.append(
  25.                         AttentionBlock(
  26.                             ch,
  27.                             use_checkpoint=use_checkpoint,
  28.                             num_heads=num_heads,
  29.                             num_head_channels=dim_head,
  30.                             use_new_attention_order=use_new_attention_order,
  31.                         ) if not use_spatial_transformer else SpatialTransformer(
  32.                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
  33.                         )
  34.                     )
  35.                 self.input_blocks.append(TimestepEmbedSequential(*layers))
  36.                 self._feature_size += ch
  37.                 input_block_chans.append(ch)
  38. # 打印 self.input_blocks[1]
  39. TimestepEmbedSequential(
  40.   (0): ResBlock(
  41.     (in_layers): Sequential(
  42.       (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  43.       (1): SiLU()
  44.       (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  45.     )
  46.     (h_upd): Identity()
  47.     (x_upd): Identity()
  48.     (emb_layers): Sequential(
  49.       (0): SiLU()
  50.       (1): Linear(in_features=1280, out_features=320, bias=True)
  51.     )
  52.     (out_layers): Sequential(
  53.       (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  54.       (1): SiLU()
  55.       (2): Dropout(p=0, inplace=False)
  56.       (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  57.     )
  58.     (skip_connection): Identity()
  59.   )
  60.   (1): SpatialTransformer(
  61.     (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  62.     (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  63.     (transformer_blocks): ModuleList(
  64.       (0): BasicTransformerBlock(
  65.         (attn1): CrossAttention(
  66.           (to_q): Linear(in_features=320, out_features=320, bias=False)
  67.           (to_k): Linear(in_features=320, out_features=320, bias=False)
  68.           (to_v): Linear(in_features=320, out_features=320, bias=False)
  69.           (to_out): Sequential(
  70.             (0): Linear(in_features=320, out_features=320, bias=True)
  71.             (1): Dropout(p=0.0, inplace=False)
  72.           )
  73.         )
  74.         (ff): FeedForward(
  75.           (net): Sequential(
  76.             (0): GEGLU(
  77.               (proj): Linear(in_features=320, out_features=2560, bias=True)
  78.             )
  79.             (1): Dropout(p=0.0, inplace=False)
  80.             (2): Linear(in_features=1280, out_features=320, bias=True)
  81.           )
  82.         )
  83.         (attn2): CrossAttention(
  84.           (to_q): Linear(in_features=320, out_features=320, bias=False)
  85.           (to_k): Linear(in_features=768, out_features=320, bias=False)
  86.           (to_v): Linear(in_features=768, out_features=320, bias=False)
  87.           (to_out): Sequential(
  88.             (0): Linear(in_features=320, out_features=320, bias=True)
  89.             (1): Dropout(p=0.0, inplace=False)
  90.           )
  91.         )
  92.         (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  93.         (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  94.         (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  95.       )
  96.     )
  97.     (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  98.   )
  99. )
  100. # 打印 self.input_blocks[2]
  101. TimestepEmbedSequential(
  102.   (0): ResBlock(
  103.     (in_layers): Sequential(
  104.       (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  105.       (1): SiLU()
  106.       (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  107.     )
  108.     (h_upd): Identity()
  109.     (x_upd): Identity()
  110.     (emb_layers): Sequential(
  111.       (0): SiLU()
  112.       (1): Linear(in_features=1280, out_features=320, bias=True)
  113.     )
  114.     (out_layers): Sequential(
  115.       (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  116.       (1): SiLU()
  117.       (2): Dropout(p=0, inplace=False)
  118.       (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  119.     )
  120.     (skip_connection): Identity()
  121.   )
  122.   (1): SpatialTransformer(
  123.     (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  124.     (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  125.     (transformer_blocks): ModuleList(
  126.       (0): BasicTransformerBlock(
  127.         (attn1): CrossAttention(
  128.           (to_q): Linear(in_features=320, out_features=320, bias=False)
  129.           (to_k): Linear(in_features=320, out_features=320, bias=False)
  130.           (to_v): Linear(in_features=320, out_features=320, bias=False)
  131.           (to_out): Sequential(
  132.             (0): Linear(in_features=320, out_features=320, bias=True)
  133.             (1): Dropout(p=0.0, inplace=False)
  134.           )
  135.         )
  136.         (ff): FeedForward(
  137.           (net): Sequential(
  138.             (0): GEGLU(
  139.               (proj): Linear(in_features=320, out_features=2560, bias=True)
  140.             )
  141.             (1): Dropout(p=0.0, inplace=False)
  142.             (2): Linear(in_features=1280, out_features=320, bias=True)
  143.           )
  144.         )
  145.         (attn2): CrossAttention(
  146.           (to_q): Linear(in_features=320, out_features=320, bias=False)
  147.           (to_k): Linear(in_features=768, out_features=320, bias=False)
  148.           (to_v): Linear(in_features=768, out_features=320, bias=False)
  149.           (to_out): Sequential(
  150.             (0): Linear(in_features=320, out_features=320, bias=True)
  151.             (1): Dropout(p=0.0, inplace=False)
  152.           )
  153.         )
  154.         (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  155.         (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  156.         (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  157.       )
  158.     )
  159.     (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  160.   )
  161. )
复制代码
2-3-3--Module3

           Module3是一个下采样2D卷积层。
  1. # init 初始化
  2. if level != len(channel_mult) - 1:
  3.     out_ch = ch
  4.     self.input_blocks.append(
  5.         TimestepEmbedSequential(
  6.             ResBlock(
  7.                 ch,
  8.                 time_embed_dim,
  9.                 dropout,
  10.                 out_channels=out_ch,
  11.                 dims=dims,
  12.                 use_checkpoint=use_checkpoint,
  13.                 use_scale_shift_norm=use_scale_shift_norm,
  14.                 down=True,
  15.             )
  16.             if resblock_updown
  17.             else Downsample(
  18.                 ch, conv_resample, dims=dims, out_channels=out_ch
  19.             )
  20.         )
  21.     )
  22. # 打印 self.input_blocks[3]
  23. TimestepEmbedSequential(
  24.   (0): Downsample(
  25.     (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  26.   )
  27. )
复制代码
2-3-4--Module4、Module5、Module7和Module8

           与Module1和Module2的结构相同,都由一个ResBlock和一个SpatialTransformer组成,只有特性维度上的区别;
  2-3-4--Module6和Module9

           与Module3的结构相同,是一个下采样2D卷积层。
  2-3--5-Module10和Module11

           Module10和Module12的结构相同,只由一个ResBlock组成。
  1. # 打印 self.input_blocks[10]
  2. TimestepEmbedSequential(
  3.   (0): ResBlock(
  4.     (in_layers): Sequential(
  5.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  6.       (1): SiLU()
  7.       (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  8.     )
  9.     (h_upd): Identity()
  10.     (x_upd): Identity()
  11.     (emb_layers): Sequential(
  12.       (0): SiLU()
  13.       (1): Linear(in_features=1280, out_features=1280, bias=True)
  14.     )
  15.     (out_layers): Sequential(
  16.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  17.       (1): SiLU()
  18.       (2): Dropout(p=0, inplace=False)
  19.       (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  20.     )
  21.     (skip_connection): Identity()
  22.   )
  23. )
  24. # 打印 self.input_blocks[11]
  25. TimestepEmbedSequential(
  26.   (0): ResBlock(
  27.     (in_layers): Sequential(
  28.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  29.       (1): SiLU()
  30.       (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  31.     )
  32.     (h_upd): Identity()
  33.     (x_upd): Identity()
  34.     (emb_layers): Sequential(
  35.       (0): SiLU()
  36.       (1): Linear(in_features=1280, out_features=1280, bias=True)
  37.     )
  38.     (out_layers): Sequential(
  39.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  40.       (1): SiLU()
  41.       (2): Dropout(p=0, inplace=False)
  42.       (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  43.     )
  44.     (skip_connection): Identity()
  45.   )
  46. )
复制代码
2-3-6--ResBlock

           ResBlock的输入是噪声图x和timestep embedding,通过卷积处理和残差毗连等方式将timestep embedding融入噪声图特性中,核心代码如下:
  1. class ResBlock(TimestepBlock):
  2.     """
  3.     A residual block that can optionally change the number of channels.
  4.     :param channels: the number of input channels.
  5.     :param emb_channels: the number of timestep embedding channels.
  6.     :param dropout: the rate of dropout.
  7.     :param out_channels: if specified, the number of out channels.
  8.     :param use_conv: if True and out_channels is specified, use a spatial
  9.         convolution instead of a smaller 1x1 convolution to change the
  10.         channels in the skip connection.
  11.     :param dims: determines if the signal is 1D, 2D, or 3D.
  12.     :param use_checkpoint: if True, use gradient checkpointing on this module.
  13.     :param up: if True, use this block for upsampling.
  14.     :param down: if True, use this block for downsampling.
  15.     """
  16.     def __init__(
  17.         self,
  18.         channels,
  19.         emb_channels,
  20.         dropout,
  21.         out_channels=None,
  22.         use_conv=False,
  23.         use_scale_shift_norm=False,
  24.         dims=2,
  25.         use_checkpoint=False,
  26.         up=False,
  27.         down=False,
  28.     ):
  29.         super().__init__()
  30.         self.channels = channels
  31.         self.emb_channels = emb_channels
  32.         self.dropout = dropout
  33.         self.out_channels = out_channels or channels
  34.         self.use_conv = use_conv
  35.         self.use_checkpoint = use_checkpoint
  36.         self.use_scale_shift_norm = use_scale_shift_norm
  37.         self.in_layers = nn.Sequential(
  38.             normalization(channels),
  39.             nn.SiLU(),
  40.             conv_nd(dims, channels, self.out_channels, 3, padding=1),
  41.         )
  42.         self.updown = up or down
  43.         if up:
  44.             self.h_upd = Upsample(channels, False, dims)
  45.             self.x_upd = Upsample(channels, False, dims)
  46.         elif down:
  47.             self.h_upd = Downsample(channels, False, dims)
  48.             self.x_upd = Downsample(channels, False, dims)
  49.         else:
  50.             self.h_upd = self.x_upd = nn.Identity()
  51.         self.emb_layers = nn.Sequential(
  52.             nn.SiLU(),
  53.             linear(
  54.                 emb_channels,
  55.                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
  56.             ),
  57.         )
  58.         self.out_layers = nn.Sequential(
  59.             normalization(self.out_channels),
  60.             nn.SiLU(),
  61.             nn.Dropout(p=dropout),
  62.             zero_module(
  63.                 conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
  64.             ),
  65.         )
  66.         if self.out_channels == channels:
  67.             self.skip_connection = nn.Identity()
  68.         elif use_conv:
  69.             self.skip_connection = conv_nd(
  70.                 dims, channels, self.out_channels, 3, padding=1
  71.             )
  72.         else:
  73.             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
  74.     def forward(self, x, emb):
  75.         """
  76.         Apply the block to a Tensor, conditioned on a timestep embedding.
  77.         :param x: an [N x C x ...] Tensor of features.
  78.         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
  79.         :return: an [N x C x ...] Tensor of outputs.
  80.         """
  81.         return checkpoint(
  82.             self._forward, (x, emb), self.parameters(), self.use_checkpoint
  83.         )
  84.     def _forward(self, x, emb):
  85.         if self.updown:
  86.             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
  87.             h = in_rest(x)
  88.             h = self.h_upd(h)
  89.             x = self.x_upd(x)
  90.             h = in_conv(h)
  91.         else:
  92.             h = self.in_layers(x) # [6, 320, 64, 64] -> [6, 320, 64, 64]
  93.         emb_out = self.emb_layers(emb).type(h.dtype) # [6, 1280] -> [6, 320]
  94.         while len(emb_out.shape) < len(h.shape): # [6, 320] -> [6, 320, 1, 1]
  95.             emb_out = emb_out[..., None]
  96.         if self.use_scale_shift_norm:
  97.             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
  98.             scale, shift = th.chunk(emb_out, 2, dim=1)
  99.             h = out_norm(h) * (1 + scale) + shift
  100.             h = out_rest(h)
  101.         else:
  102.             h = h + emb_out # [6, 320, 64, 64] + [6, 320, 1, 1] -> [6, 320, 64, 64]
  103.             h = self.out_layers(h) # [6, 320, 64, 64]
  104.         return self.skip_connection(x) + h
复制代码
2-3-7--SpatialTransformer

           SpatialTransformer的输入是噪声图x和文本特性context,通过CrossAttention机制将文本特性融入到噪声图x中,完成条件驱动文生图,核心代码如下:
  1. from inspect import isfunction
  2. import math
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn, einsum
  6. from einops import rearrange, repeat
  7. from util import checkpoint
  8. def exists(val):
  9.     return val is not None
  10. def uniq(arr):
  11.     return{el: True for el in arr}.keys()
  12. def default(val, d):
  13.     if exists(val):
  14.         return val
  15.     return d() if isfunction(d) else d
  16. def max_neg_value(t):
  17.     return -torch.finfo(t.dtype).max
  18. def init_(tensor):
  19.     dim = tensor.shape[-1]
  20.     std = 1 / math.sqrt(dim)
  21.     tensor.uniform_(-std, std)
  22.     return tensor
  23. # feedforward
  24. class GEGLU(nn.Module):
  25.     def __init__(self, dim_in, dim_out):
  26.         super().__init__()
  27.         self.proj = nn.Linear(dim_in, dim_out * 2)
  28.     def forward(self, x):
  29.         x, gate = self.proj(x).chunk(2, dim=-1)
  30.         return x * F.gelu(gate)
  31. class FeedForward(nn.Module):
  32.     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
  33.         super().__init__()
  34.         inner_dim = int(dim * mult)
  35.         dim_out = default(dim_out, dim)
  36.         project_in = nn.Sequential(
  37.             nn.Linear(dim, inner_dim),
  38.             nn.GELU()
  39.         ) if not glu else GEGLU(dim, inner_dim)
  40.         self.net = nn.Sequential(
  41.             project_in,
  42.             nn.Dropout(dropout),
  43.             nn.Linear(inner_dim, dim_out)
  44.         )
  45.     def forward(self, x):
  46.         return self.net(x)
  47. def zero_module(module):
  48.     """
  49.     Zero out the parameters of a module and return it.
  50.     """
  51.     for p in module.parameters():
  52.         p.detach().zero_()
  53.     return module
  54. def Normalize(in_channels):
  55.     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
  56. class CrossAttention(nn.Module):
  57.     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
  58.         super().__init__()
  59.         inner_dim = dim_head * heads # dim_head: 40, heads: 8
  60.         context_dim = default(context_dim, query_dim)
  61.         self.scale = dim_head ** -0.5
  62.         self.heads = heads
  63.         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
  64.         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
  65.         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
  66.         self.to_out = nn.Sequential(
  67.             nn.Linear(inner_dim, query_dim),
  68.             nn.Dropout(dropout)
  69.         )
  70.     def forward(self, x, context=None, mask=None):
  71.         h = self.heads # 8
  72.         q = self.to_q(x) # [6, 4096, 320] -> [6, 4096, 320]
  73.         context = default(context, x) # return context [6, 77, 768]
  74.         k = self.to_k(context) # [6, 77, 768] -> [6, 77, 320]
  75.         v = self.to_v(context) # [6, 77, 768] -> [6, 77, 320]
  76.         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) # [6, 4096, 320] -> [48, 4096, 40] # [6, 77, 320] -> [48, 77, 40]
  77.         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # [48, 4096, 40] * [48, 77, 40] -> [48, 4096, 77]
  78.         if exists(mask):
  79.             mask = rearrange(mask, 'b ... -> b (...)')
  80.             max_neg_value = -torch.finfo(sim.dtype).max
  81.             mask = repeat(mask, 'b j -> (b h) () j', h=h)
  82.             sim.masked_fill_(~mask, max_neg_value)
  83.         # attention, what we cannot get enough of
  84.         attn = sim.softmax(dim=-1) # softmax
  85.         out = einsum('b i j, b j d -> b i d', attn, v) # [48, 4096, 77] * [48, 77, 40] -> [48, 4096, 40]
  86.         out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # [48, 4096, 40] -> [6, 4096, 320]
  87.         return self.to_out(out)
  88. class BasicTransformerBlock(nn.Module):
  89.     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
  90.         super().__init__()
  91.         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
  92.         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
  93.         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
  94.                                     heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
  95.         self.norm1 = nn.LayerNorm(dim)
  96.         self.norm2 = nn.LayerNorm(dim)
  97.         self.norm3 = nn.LayerNorm(dim)
  98.         self.checkpoint = checkpoint
  99.     def forward(self, x, context=None):
  100.         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  101.     def _forward(self, x, context=None):
  102.         x = self.attn1(self.norm1(x)) + x # self Attention, [6, 4096, 320] -> [6, 4096, 320]
  103.         x = self.attn2(self.norm2(x), context=context) + x # cross Attention, [6, 4096, 320] -> [6, 4096, 320]
  104.         x = self.ff(self.norm3(x)) + x # FFN, [6, 4096, 320] -> [6, 4096, 320]
  105.         return x
  106. class SpatialTransformer(nn.Module):
  107.     """
  108.     Transformer block for image-like data.
  109.     First, project the input (aka embedding)
  110.     and reshape to b, t, d.
  111.     Then apply standard transformer action.
  112.     Finally, reshape to image
  113.     """
  114.     def __init__(self, in_channels, n_heads, d_head,
  115.                  depth=1, dropout=0., context_dim=None):
  116.         super().__init__()
  117.         self.in_channels = in_channels
  118.         inner_dim = n_heads * d_head
  119.         self.norm = Normalize(in_channels)
  120.         self.proj_in = nn.Conv2d(in_channels,
  121.                                  inner_dim,
  122.                                  kernel_size=1,
  123.                                  stride=1,
  124.                                  padding=0)
  125.         self.transformer_blocks = nn.ModuleList(
  126.             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
  127.                 for d in range(depth)]
  128.         )
  129.         self.proj_out = zero_module(nn.Conv2d(inner_dim,
  130.                                               in_channels,
  131.                                               kernel_size=1,
  132.                                               stride=1,
  133.                                               padding=0))
  134.     def forward(self, x, context=None):
  135.         # note: if no context is given, cross-attention defaults to self-attention
  136.         b, c, h, w = x.shape # [6, 320, 64, 64]
  137.         x_in = x
  138.         x = self.norm(x) # [6, 320, 64, 64]
  139.         x = self.proj_in(x) # [6, 320, 64, 64]
  140.         x = rearrange(x, 'b c h w -> b (h w) c') # [6, 4096, 320]
  141.         for block in self.transformer_blocks:
  142.             x = block(x, context=context)
  143.         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
  144.         x = self.proj_out(x)
  145.         return x + x_in
复制代码
2-4--self.middle_block

           self.middle_block由两个ResBlock和一个SpatialTransformer组成:
  1. TimestepEmbedSequential(
  2.   (0): ResBlock(
  3.     (in_layers): Sequential(
  4.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  5.       (1): SiLU()
  6.       (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  7.     )
  8.     (h_upd): Identity()
  9.     (x_upd): Identity()
  10.     (emb_layers): Sequential(
  11.       (0): SiLU()
  12.       (1): Linear(in_features=1280, out_features=1280, bias=True)
  13.     )
  14.     (out_layers): Sequential(
  15.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  16.       (1): SiLU()
  17.       (2): Dropout(p=0, inplace=False)
  18.       (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  19.     )
  20.     (skip_connection): Identity()
  21.   )
  22.   (1): SpatialTransformer(
  23.     (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
  24.     (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  25.     (transformer_blocks): ModuleList(
  26.       (0): BasicTransformerBlock(
  27.         (attn1): CrossAttention(
  28.           (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  29.           (to_k): Linear(in_features=1280, out_features=1280, bias=False)
  30.           (to_v): Linear(in_features=1280, out_features=1280, bias=False)
  31.           (to_out): Sequential(
  32.             (0): Linear(in_features=1280, out_features=1280, bias=True)
  33.             (1): Dropout(p=0.0, inplace=False)
  34.           )
  35.         )
  36.         (ff): FeedForward(
  37.           (net): Sequential(
  38.             (0): GEGLU(
  39.               (proj): Linear(in_features=1280, out_features=10240, bias=True)
  40.             )
  41.             (1): Dropout(p=0.0, inplace=False)
  42.             (2): Linear(in_features=5120, out_features=1280, bias=True)
  43.           )
  44.         )
  45.         (attn2): CrossAttention(
  46.           (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  47.           (to_k): Linear(in_features=768, out_features=1280, bias=False)
  48.           (to_v): Linear(in_features=768, out_features=1280, bias=False)
  49.           (to_out): Sequential(
  50.             (0): Linear(in_features=1280, out_features=1280, bias=True)
  51.             (1): Dropout(p=0.0, inplace=False)
  52.           )
  53.         )
  54.         (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  55.         (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  56.         (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  57.       )
  58.     )
  59.     (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  60.   )
  61.   (2): ResBlock(
  62.     (in_layers): Sequential(
  63.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  64.       (1): SiLU()
  65.       (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  66.     )
  67.     (h_upd): Identity()
  68.     (x_upd): Identity()
  69.     (emb_layers): Sequential(
  70.       (0): SiLU()
  71.       (1): Linear(in_features=1280, out_features=1280, bias=True)
  72.     )
  73.     (out_layers): Sequential(
  74.       (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  75.       (1): SiLU()
  76.       (2): Dropout(p=0, inplace=False)
  77.       (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  78.     )
  79.     (skip_connection): Identity()
  80.   )
  81. )
复制代码
2-5--self.output_blocks上采样

           在 Forward() 中,利用 self.output_blocks 将噪声图进行分辨率上采样,颠末上采样详细维度变化为:[B*2, 1280, 8, 8] > [B*2, 4, 64, 64];
          下采样模块共有12个 module,其结构与下采样模块类似,组成如下:
  1. ModuleList(
  2.   (0-1): 2 x TimestepEmbedSequential(
  3.     (0): ResBlock(
  4.       (in_layers): Sequential(
  5.         (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
  6.         (1): SiLU()
  7.         (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  8.       )
  9.       (h_upd): Identity()
  10.       (x_upd): Identity()
  11.       (emb_layers): Sequential(
  12.         (0): SiLU()
  13.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  14.       )
  15.       (out_layers): Sequential(
  16.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  17.         (1): SiLU()
  18.         (2): Dropout(p=0, inplace=False)
  19.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  20.       )
  21.       (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
  22.     )
  23.   )
  24.   (2): TimestepEmbedSequential(
  25.     (0): ResBlock(
  26.       (in_layers): Sequential(
  27.         (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
  28.         (1): SiLU()
  29.         (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  30.       )
  31.       (h_upd): Identity()
  32.       (x_upd): Identity()
  33.       (emb_layers): Sequential(
  34.         (0): SiLU()
  35.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  36.       )
  37.       (out_layers): Sequential(
  38.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  39.         (1): SiLU()
  40.         (2): Dropout(p=0, inplace=False)
  41.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  42.       )
  43.       (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
  44.     )
  45.     (1): Upsample(
  46.       (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  47.     )
  48.   )
  49.   (3-4): 2 x TimestepEmbedSequential(
  50.     (0): ResBlock(
  51.       (in_layers): Sequential(
  52.         (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
  53.         (1): SiLU()
  54.         (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  55.       )
  56.       (h_upd): Identity()
  57.       (x_upd): Identity()
  58.       (emb_layers): Sequential(
  59.         (0): SiLU()
  60.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  61.       )
  62.       (out_layers): Sequential(
  63.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  64.         (1): SiLU()
  65.         (2): Dropout(p=0, inplace=False)
  66.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  67.       )
  68.       (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
  69.     )
  70.     (1): SpatialTransformer(
  71.       (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
  72.       (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  73.       (transformer_blocks): ModuleList(
  74.         (0): BasicTransformerBlock(
  75.           (attn1): CrossAttention(
  76.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  77.             (to_k): Linear(in_features=1280, out_features=1280, bias=False)
  78.             (to_v): Linear(in_features=1280, out_features=1280, bias=False)
  79.             (to_out): Sequential(
  80.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  81.               (1): Dropout(p=0.0, inplace=False)
  82.             )
  83.           )
  84.           (ff): FeedForward(
  85.             (net): Sequential(
  86.               (0): GEGLU(
  87.                 (proj): Linear(in_features=1280, out_features=10240, bias=True)
  88.               )
  89.               (1): Dropout(p=0.0, inplace=False)
  90.               (2): Linear(in_features=5120, out_features=1280, bias=True)
  91.             )
  92.           )
  93.           (attn2): CrossAttention(
  94.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  95.             (to_k): Linear(in_features=768, out_features=1280, bias=False)
  96.             (to_v): Linear(in_features=768, out_features=1280, bias=False)
  97.             (to_out): Sequential(
  98.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  99.               (1): Dropout(p=0.0, inplace=False)
  100.             )
  101.           )
  102.           (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  103.           (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  104.           (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  105.         )
  106.       )
  107.       (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  108.     )
  109.   )
  110.   (5): TimestepEmbedSequential(
  111.     (0): ResBlock(
  112.       (in_layers): Sequential(
  113.         (0): GroupNorm32(32, 1920, eps=1e-05, affine=True)
  114.         (1): SiLU()
  115.         (2): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  116.       )
  117.       (h_upd): Identity()
  118.       (x_upd): Identity()
  119.       (emb_layers): Sequential(
  120.         (0): SiLU()
  121.         (1): Linear(in_features=1280, out_features=1280, bias=True)
  122.       )
  123.       (out_layers): Sequential(
  124.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  125.         (1): SiLU()
  126.         (2): Dropout(p=0, inplace=False)
  127.         (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  128.       )
  129.       (skip_connection): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
  130.     )
  131.     (1): SpatialTransformer(
  132.       (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
  133.       (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  134.       (transformer_blocks): ModuleList(
  135.         (0): BasicTransformerBlock(
  136.           (attn1): CrossAttention(
  137.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  138.             (to_k): Linear(in_features=1280, out_features=1280, bias=False)
  139.             (to_v): Linear(in_features=1280, out_features=1280, bias=False)
  140.             (to_out): Sequential(
  141.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  142.               (1): Dropout(p=0.0, inplace=False)
  143.             )
  144.           )
  145.           (ff): FeedForward(
  146.             (net): Sequential(
  147.               (0): GEGLU(
  148.                 (proj): Linear(in_features=1280, out_features=10240, bias=True)
  149.               )
  150.               (1): Dropout(p=0.0, inplace=False)
  151.               (2): Linear(in_features=5120, out_features=1280, bias=True)
  152.             )
  153.           )
  154.           (attn2): CrossAttention(
  155.             (to_q): Linear(in_features=1280, out_features=1280, bias=False)
  156.             (to_k): Linear(in_features=768, out_features=1280, bias=False)
  157.             (to_v): Linear(in_features=768, out_features=1280, bias=False)
  158.             (to_out): Sequential(
  159.               (0): Linear(in_features=1280, out_features=1280, bias=True)
  160.               (1): Dropout(p=0.0, inplace=False)
  161.             )
  162.           )
  163.           (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  164.           (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  165.           (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  166.         )
  167.       )
  168.       (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  169.     )
  170.     (2): Upsample(
  171.       (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  172.     )
  173.   )
  174.   (6): TimestepEmbedSequential(
  175.     (0): ResBlock(
  176.       (in_layers): Sequential(
  177.         (0): GroupNorm32(32, 1920, eps=1e-05, affine=True)
  178.         (1): SiLU()
  179.         (2): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  180.       )
  181.       (h_upd): Identity()
  182.       (x_upd): Identity()
  183.       (emb_layers): Sequential(
  184.         (0): SiLU()
  185.         (1): Linear(in_features=1280, out_features=640, bias=True)
  186.       )
  187.       (out_layers): Sequential(
  188.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  189.         (1): SiLU()
  190.         (2): Dropout(p=0, inplace=False)
  191.         (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  192.       )
  193.       (skip_connection): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1))
  194.     )
  195.     (1): SpatialTransformer(
  196.       (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
  197.       (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  198.       (transformer_blocks): ModuleList(
  199.         (0): BasicTransformerBlock(
  200.           (attn1): CrossAttention(
  201.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  202.             (to_k): Linear(in_features=640, out_features=640, bias=False)
  203.             (to_v): Linear(in_features=640, out_features=640, bias=False)
  204.             (to_out): Sequential(
  205.               (0): Linear(in_features=640, out_features=640, bias=True)
  206.               (1): Dropout(p=0.0, inplace=False)
  207.             )
  208.           )
  209.           (ff): FeedForward(
  210.             (net): Sequential(
  211.               (0): GEGLU(
  212.                 (proj): Linear(in_features=640, out_features=5120, bias=True)
  213.               )
  214.               (1): Dropout(p=0.0, inplace=False)
  215.               (2): Linear(in_features=2560, out_features=640, bias=True)
  216.             )
  217.           )
  218.           (attn2): CrossAttention(
  219.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  220.             (to_k): Linear(in_features=768, out_features=640, bias=False)
  221.             (to_v): Linear(in_features=768, out_features=640, bias=False)
  222.             (to_out): Sequential(
  223.               (0): Linear(in_features=640, out_features=640, bias=True)
  224.               (1): Dropout(p=0.0, inplace=False)
  225.             )
  226.           )
  227.           (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  228.           (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  229.           (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  230.         )
  231.       )
  232.       (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  233.     )
  234.   )
  235.   (7): TimestepEmbedSequential(
  236.     (0): ResBlock(
  237.       (in_layers): Sequential(
  238.         (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
  239.         (1): SiLU()
  240.         (2): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  241.       )
  242.       (h_upd): Identity()
  243.       (x_upd): Identity()
  244.       (emb_layers): Sequential(
  245.         (0): SiLU()
  246.         (1): Linear(in_features=1280, out_features=640, bias=True)
  247.       )
  248.       (out_layers): Sequential(
  249.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  250.         (1): SiLU()
  251.         (2): Dropout(p=0, inplace=False)
  252.         (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  253.       )
  254.       (skip_connection): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
  255.     )
  256.     (1): SpatialTransformer(
  257.       (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
  258.       (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  259.       (transformer_blocks): ModuleList(
  260.         (0): BasicTransformerBlock(
  261.           (attn1): CrossAttention(
  262.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  263.             (to_k): Linear(in_features=640, out_features=640, bias=False)
  264.             (to_v): Linear(in_features=640, out_features=640, bias=False)
  265.             (to_out): Sequential(
  266.               (0): Linear(in_features=640, out_features=640, bias=True)
  267.               (1): Dropout(p=0.0, inplace=False)
  268.             )
  269.           )
  270.           (ff): FeedForward(
  271.             (net): Sequential(
  272.               (0): GEGLU(
  273.                 (proj): Linear(in_features=640, out_features=5120, bias=True)
  274.               )
  275.               (1): Dropout(p=0.0, inplace=False)
  276.               (2): Linear(in_features=2560, out_features=640, bias=True)
  277.             )
  278.           )
  279.           (attn2): CrossAttention(
  280.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  281.             (to_k): Linear(in_features=768, out_features=640, bias=False)
  282.             (to_v): Linear(in_features=768, out_features=640, bias=False)
  283.             (to_out): Sequential(
  284.               (0): Linear(in_features=640, out_features=640, bias=True)
  285.               (1): Dropout(p=0.0, inplace=False)
  286.             )
  287.           )
  288.           (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  289.           (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  290.           (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  291.         )
  292.       )
  293.       (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  294.     )
  295.   )
  296.   (8): TimestepEmbedSequential(
  297.     (0): ResBlock(
  298.       (in_layers): Sequential(
  299.         (0): GroupNorm32(32, 960, eps=1e-05, affine=True)
  300.         (1): SiLU()
  301.         (2): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  302.       )
  303.       (h_upd): Identity()
  304.       (x_upd): Identity()
  305.       (emb_layers): Sequential(
  306.         (0): SiLU()
  307.         (1): Linear(in_features=1280, out_features=640, bias=True)
  308.       )
  309.       (out_layers): Sequential(
  310.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  311.         (1): SiLU()
  312.         (2): Dropout(p=0, inplace=False)
  313.         (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  314.       )
  315.       (skip_connection): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1))
  316.     )
  317.     (1): SpatialTransformer(
  318.       (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
  319.       (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  320.       (transformer_blocks): ModuleList(
  321.         (0): BasicTransformerBlock(
  322.           (attn1): CrossAttention(
  323.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  324.             (to_k): Linear(in_features=640, out_features=640, bias=False)
  325.             (to_v): Linear(in_features=640, out_features=640, bias=False)
  326.             (to_out): Sequential(
  327.               (0): Linear(in_features=640, out_features=640, bias=True)
  328.               (1): Dropout(p=0.0, inplace=False)
  329.             )
  330.           )
  331.           (ff): FeedForward(
  332.             (net): Sequential(
  333.               (0): GEGLU(
  334.                 (proj): Linear(in_features=640, out_features=5120, bias=True)
  335.               )
  336.               (1): Dropout(p=0.0, inplace=False)
  337.               (2): Linear(in_features=2560, out_features=640, bias=True)
  338.             )
  339.           )
  340.           (attn2): CrossAttention(
  341.             (to_q): Linear(in_features=640, out_features=640, bias=False)
  342.             (to_k): Linear(in_features=768, out_features=640, bias=False)
  343.             (to_v): Linear(in_features=768, out_features=640, bias=False)
  344.             (to_out): Sequential(
  345.               (0): Linear(in_features=640, out_features=640, bias=True)
  346.               (1): Dropout(p=0.0, inplace=False)
  347.             )
  348.           )
  349.           (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  350.           (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  351.           (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
  352.         )
  353.       )
  354.       (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
  355.     )
  356.     (2): Upsample(
  357.       (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  358.     )
  359.   )
  360.   (9): TimestepEmbedSequential(
  361.     (0): ResBlock(
  362.       (in_layers): Sequential(
  363.         (0): GroupNorm32(32, 960, eps=1e-05, affine=True)
  364.         (1): SiLU()
  365.         (2): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  366.       )
  367.       (h_upd): Identity()
  368.       (x_upd): Identity()
  369.       (emb_layers): Sequential(
  370.         (0): SiLU()
  371.         (1): Linear(in_features=1280, out_features=320, bias=True)
  372.       )
  373.       (out_layers): Sequential(
  374.         (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  375.         (1): SiLU()
  376.         (2): Dropout(p=0, inplace=False)
  377.         (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  378.       )
  379.       (skip_connection): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
  380.     )
  381.     (1): SpatialTransformer(
  382.       (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  383.       (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  384.       (transformer_blocks): ModuleList(
  385.         (0): BasicTransformerBlock(
  386.           (attn1): CrossAttention(
  387.             (to_q): Linear(in_features=320, out_features=320, bias=False)
  388.             (to_k): Linear(in_features=320, out_features=320, bias=False)
  389.             (to_v): Linear(in_features=320, out_features=320, bias=False)
  390.             (to_out): Sequential(
  391.               (0): Linear(in_features=320, out_features=320, bias=True)
  392.               (1): Dropout(p=0.0, inplace=False)
  393.             )
  394.           )
  395.           (ff): FeedForward(
  396.             (net): Sequential(
  397.               (0): GEGLU(
  398.                 (proj): Linear(in_features=320, out_features=2560, bias=True)
  399.               )
  400.               (1): Dropout(p=0.0, inplace=False)
  401.               (2): Linear(in_features=1280, out_features=320, bias=True)
  402.             )
  403.           )
  404.           (attn2): CrossAttention(
  405.             (to_q): Linear(in_features=320, out_features=320, bias=False)
  406.             (to_k): Linear(in_features=768, out_features=320, bias=False)
  407.             (to_v): Linear(in_features=768, out_features=320, bias=False)
  408.             (to_out): Sequential(
  409.               (0): Linear(in_features=320, out_features=320, bias=True)
  410.               (1): Dropout(p=0.0, inplace=False)
  411.             )
  412.           )
  413.           (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  414.           (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  415.           (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  416.         )
  417.       )
  418.       (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  419.     )
  420.   )
  421.   (10-11): 2 x TimestepEmbedSequential(
  422.     (0): ResBlock(
  423.       (in_layers): Sequential(
  424.         (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
  425.         (1): SiLU()
  426.         (2): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  427.       )
  428.       (h_upd): Identity()
  429.       (x_upd): Identity()
  430.       (emb_layers): Sequential(
  431.         (0): SiLU()
  432.         (1): Linear(in_features=1280, out_features=320, bias=True)
  433.       )
  434.       (out_layers): Sequential(
  435.         (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  436.         (1): SiLU()
  437.         (2): Dropout(p=0, inplace=False)
  438.         (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  439.       )
  440.       (skip_connection): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))
  441.     )
  442.     (1): SpatialTransformer(
  443.       (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  444.       (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  445.       (transformer_blocks): ModuleList(
  446.         (0): BasicTransformerBlock(
  447.           (attn1): CrossAttention(
  448.             (to_q): Linear(in_features=320, out_features=320, bias=False)
  449.             (to_k): Linear(in_features=320, out_features=320, bias=False)
  450.             (to_v): Linear(in_features=320, out_features=320, bias=False)
  451.             (to_out): Sequential(
  452.               (0): Linear(in_features=320, out_features=320, bias=True)
  453.               (1): Dropout(p=0.0, inplace=False)
  454.             )
  455.           )
  456.           (ff): FeedForward(
  457.             (net): Sequential(
  458.               (0): GEGLU(
  459.                 (proj): Linear(in_features=320, out_features=2560, bias=True)
  460.               )
  461.               (1): Dropout(p=0.0, inplace=False)
  462.               (2): Linear(in_features=1280, out_features=320, bias=True)
  463.             )
  464.           )
  465.           (attn2): CrossAttention(
  466.             (to_q): Linear(in_features=320, out_features=320, bias=False)
  467.             (to_k): Linear(in_features=768, out_features=320, bias=False)
  468.             (to_v): Linear(in_features=768, out_features=320, bias=False)
  469.             (to_out): Sequential(
  470.               (0): Linear(in_features=320, out_features=320, bias=True)
  471.               (1): Dropout(p=0.0, inplace=False)
  472.             )
  473.           )
  474.           (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  475.           (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  476.           (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  477.         )
  478.       )
  479.       (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  480.     )
  481.   )
  482. )
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

正序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

梦应逍遥

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表