AIGC利器:Flux模子代码底子学习

打印 上一主题 下一主题

主题 2091|帖子 2091|积分 6273

模子链接:black-forest-labs/FLUX.1-dev - Hugging Face

1:模子团体设置:

   {
    "_class_name": "FluxPipeline",
    "_diffusers_version": "0.30.0.dev0",
    "scheduler": [
      "diffusers",
  "FlowMatchEulerDiscreteScheduler"(调理器)
  (天生步骤控制: 调理器负责界说天生图像时的迭代步骤和时间进程。它确定在每个迭代中如何更新潜在表示,以实现去噪和图像天生。
  去噪过程指导: 在扩散模子中,调理器控制去噪过程的步长和方向,确保每一步都朝着天生更清楚图像的目标前进。
  参数调整: 根据天生过程的须要,调理器可以调整学习率、噪声程度等参数,以优化天生结果。)
    ],
    "text_encoder": [
      "transformers",
      "CLIPTextModel"(得当须要视觉上下文的任务。)
    ],
    "text_encoder_2": [
      "transformers",
      "T5EncoderModel"(更广泛的文本任务)
    ],
    "tokenizer": [
      "transformers",
      "CLIPTokenizer"(将原始文本转化为模子可以理解的token)
    ],
    "tokenizer_2": [
      "transformers",
      "T5TokenizerFast"(将原始文本转化为模子可以理解的token)
    ],
    "transformer": [
      "diffusers",
  "FluxTransformer2DModel"(负责图像天生过程的核心优化)
  (重要功能: 作为天生模子的核心,FluxTransformer2DModel负责在潜在空间中举行图像天生。
  特性融合: 它吸收来自文本编码器的潜在表示,并将其与图像的潜在表示结合,从而在天生过程中思量文本的语义信息。
  迭代优化: 通过调理器的指导,该模子举行多次迭代,渐渐天生和细化图像,确保天生的内容与输入文本同等。也可称之为去噪过程。)
    ],
    "vae": [
      "diffusers",
  "AutoencoderKL"(变分自编码器)
  (重要功能: 作为解码器,其负责将天生的潜在表示转换回图像空间,输出最终的图像。
  重构本领: 该模块通过学习从潜在表示重构图像,确保天生的图像在视觉上逼真且与输入文本相关。)
    ]
  2:各模块功能:


团体流程(来自Stable Diffusion 3「精神续作」FLUX.1 源码深度前瞻解读_flowmatcheulerdiscretescheduler-CSDN博客)

2.1 CLIPTextModel:

   {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "architectures": [
      "CLIPTextModel"
    ],
    "attention_dropout": 0.0,(注意力机制中的dropout比例)
    "bos_token_id": 0,(句子开头的标记ID)
    "dropout": 0.0,(模子团体的dropout比例)
    "eos_token_id": 2,(句子结尾的标记ID)
    "hidden_act": "quick_gelu",
    "hidden_size": 768,隐藏层的维度
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,(中间层的大小)
    "layer_norm_eps": 1e-05,
    "max_position_embeddings": 77,模子可以处理的最大输入长度
    "model_type": "clip_text_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,(隐藏层数目)
    "pad_token_id": 1,(填充标记的ID)
    "projection_dim": 768,(投影维度,将文本表示映射到与图像表示相同的空间)
    "torch_dtype": "bfloat16",
    "transformers_version": "4.43.3",
    "vocab_size": 49408
  }
         该模块重要负责将输入的prompt转化为池化层嵌入,之后在去噪过程中使得模子可以或许更容易地(低维,复杂性低)学习到文本与时间步之间的关系。
张量形状变化:
        输入:input_ids:[batch_size,max_length]
        嵌入层:里面有一个文本嵌入层和一个位置嵌入层(均为可学习参数的Embedding层),返回二者的和,形状为[batch_size,seq_length,hidden_size]
        input_ids先辈入嵌入层,得到hidden_states之后进入12个编码层,每一层中先归一化,后颠末一个注意力层,然后残差连接,再归一化,之后进入一个多层感知机,再残差连接,最后归一化并输出。背面的模子使用的是其池化层输出pooled_output,形状为[batch_size hidden_size]
        如果一个文本天生多个图像(注意参数num_images_per_prompt),须要重复天生的嵌入向量:调整形状为[batch_size,hidden_size * num_images_per_prompt],之后调整形状为:[batch_size *num_images_per_prompthidden_size]

2.2 T5EncoderModel:

   {
    "_name_or_path": "google/t5-v1_1-xxl",
    "architectures": [
      "T5EncoderModel"
    ],
    "classifier_dropout": 0.0,
    "d_ff": 10240, 前馈层的维度,表示每个前馈网络的隐藏层大小。
    "d_kv": 64,
    "d_model": 4096,隐藏层的维度,表示每个隐藏状态向量的大小。
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "feed_forward_proj": "gated-gelu",
    "initializer_factor": 1.0,
    "is_encoder_decoder": true,
    "is_gated_act": true,
    "layer_norm_epsilon": 1e-06,
    "model_type": "t5",
    "num_decoder_layers": 24,
    "num_heads": 64,
    "num_layers": 24,
    "output_past": true,
    "pad_token_id": 0,
    "relative_attention_max_distance": 128,(相对注意力机制的最大距离)
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": false,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.43.3",
    "use_cache": true,
    "vocab_size": 32128
  }
         该模块重要负责将输入的prompt转化为高维的文本嵌入,其丰富的语义信息使得在去噪过程中,模子在天生图像时,可以利用这些语义信息与潜在的图像特性举行关联,从而提高天生图像的相关性和正确性。
张量形状变化:
        输入:input_ids:[batch_size,max_length]
        嵌入层:将input_ids变为嵌入向量[batch_size,max_length,d_model]
        进入一个,栈内是24个编码层的循环。在每个编码层内,先辈入一个自注意力层,在自注意力层中,先归一化,之后颠末自注意力机制计算得到输出,将嵌入向量与注意力机制输出相加得到新的向量(残差连接),此时向量形状:[batch_size,max_length,d_model]。之后输入一个前馈层,在该层中,先归一化,之后传入前馈网络,再将结果与之前的向量残差连接并输出。每个编码层输出一个元组:(嵌入向量,位置偏移、注意力权重),最后栈中嵌入向量颠末归一化返回。最后prompt_embeds形状为[batch_size sequence_lengthd_model]
        如果一个文本天生多个图像,须要重复天生的嵌入向量:prompt_embeds形状为[batch_size, sequence_length* num_images_per_prompt,d_model]之后调整为:[batch_size *num_images_per_promptsequence_lengthd_model]

2.3 FlowMatchEulerDiscreteScheduler:

   {
  "_class_name": "FlowMatchEulerDiscreteScheduler",
    "_diffusers_version": "0.30.0.dev0",
    "base_image_seq_len": 256,(输入图像序列的基准长度)
    "base_shift": 0.5,
    "max_image_seq_len": 4096, 输入图像序列的最大长度)
    "max_shift": 1.15,
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": true
  }
         该模块重要负责确定训练以及推理过程的时间步,以及在去噪过程中调治图像天生方向。
    获取时间步
    天生影响扩散过程中的噪声强度sigmas,每一个时间步对应一个sigma,训练时这个值控制了在该时间步中添加的噪声量。在推理时,这个值决定了模子如何从噪声规复出逼近真实数据的样本:
  1. sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
复制代码
        sigmas是一个等差数列,一开始的值最大,代表一开始要添加的噪声越强,随时间的推移添加的噪音越来越小。随后调用retrieve_timesteps函数获取去噪过程的时间步。该部门重要调用FlowMatchEulerDiscreteScheduler的set_timesteps方法,本模子在确定时间步时会根据输入图像的大小动态调整时间步,这个调整依靠于根据输入图像的尺寸来计算一个偏置参数mu:
  1. def calculate_shift(
  2.     image_seq_len,
  3.     base_seq_len: int = 256,
  4.     max_seq_len: int = 4096,
  5.     base_shift: float = 0.5,
  6.     max_shift: float = 1.16,
  7. ):
  8.     # 斜率m:
  9.     m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
  10.     # 截距:
  11.     b = base_shift - m * base_seq_len
  12.     # 给定图像序列长度的偏移量
  13.     mu = image_seq_len * m + b
  14.     return mu
复制代码

        之后调用time_shift函数调整sigmas中的每一个元素:
                                             

        最后将sigmas举行缩放:将其映射到一个与训练阶段相匹配的时间步范围(这样可以确保在推理阶段与训练阶段保持同等性)。将缩放的结果作为推理阶段的时间步返回。
去噪调治
        在获取到第t个时间步的噪音noise_pred后,怎么将潜在图像由第t个时间步阶段回退到第t-1个时间步阶段呢,这时间须要调用FlowMatchEulerDiscreteScheduler的step函数更新潜在图像。
     在step函数中,须要获取第t个时间步的噪声强度sigma和第t-1个时间步的噪声强度sigma_next,之后举行缩放:
  1. new_latents=noise_pred+(sigma_next - sigma) * latents
复制代码
2.4 FluxTransformer2DModel:

   {
    "_class_name": "FluxTransformer2DModel",
    "_diffusers_version": "0.30.0.dev0",
    "_name_or_path": "../checkpoints/flux-dev/transformer",
    "attention_head_dim": 128,
    "guidance_embeds": true,
    "in_channels": 64,
    "joint_attention_dim": 4096,(联合注意力维度)
    "num_attention_heads": 24,
    "num_layers": 19,
    "num_single_layers": 38,
    "patch_size": 1,
    "pooled_projection_dim": 768(池化投影维度)
  }
          该模块重要负责融合时间步、引导强度、池化文本嵌入、高维文本嵌入和潜在图像,在去噪过程中猜测每一步的噪音图像,对潜在图像举行更新:
        在该模块须要注意位置嵌入、时间-引导强度-池化文本融合嵌入和模态融合。
       位置嵌入(本模块使用三维旋转位置嵌入):
  1. def get_1d_rotary_pos_embed(
  2.     dim: int,
  3.     pos: Union[np.ndarray, int],
  4.     theta: float = 10000.0,#频率计算的缩放因子
  5.     use_real=False,# 是否返回实部和虚部分开的结果
  6.     linear_factor=1.0,# 上下文外推的缩放因子
  7.     ntk_factor=1.0,#缩放因子
  8.     repeat_interleave_real=True,#以特定方式重复真实部分和虚拟部分
  9.     freqs_dtype=torch.float32,  #  torch.float32, torch.float64 (flux)
  10. ):
  11.     assert dim % 2 == 0
  12.     if isinstance(pos, int):
  13.         pos = torch.arange(pos)
  14.     if isinstance(pos, np.ndarray):
  15.         pos = torch.from_numpy(pos)  # [seq_length+h/2*w/2]
  16.     theta = theta * ntk_factor
  17.     # (1)间隔为2,生成0到dim的张量:[dim/2]
  18.     # (2)取前dim/2个:每个除以dim得到频率
  19.     freqs = (
  20.         1.0
  21.         / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
  22.         / linear_factor
  23.     )
  24.     # 位置id与频率做外积:freqs包含了每个位置的频率信息
  25.     freqs = torch.outer(pos, freqs)  # [seq_length+h/2*w/2, dim/2]
  26.     if use_real and repeat_interleave_real:
  27.         #.repeat_interleave(2, dim=1)在第一维(列)重复2遍
  28.         freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [seq_length+h/2*w/2, dim]
  29.         freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [seq_length+h/2*w/2, dim]
  30.         return freqs_cos, freqs_sin
  31.     elif use_real:
  32.         # stable audio
  33.         freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [seq_length+h/2*w/2, dim]
  34.         freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [seq_length+h/2*w/2, dim]
  35.         return freqs_cos, freqs_sin
  36.     else:
  37.         #freqs:相位
  38.         #x·x的旋转位置编码=将x逆时针旋转一定度数,有助于保持元素的相对顺序
  39.         freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 
  40. # [seq_length+h/2*w/2, dim/2]
  41.         return freqs_cis
  42. #三维旋转位置编码
  43. class FluxPosEmbed(nn.Module):
  44.        def __init__(self, theta: int, axes_dim: List[int]):
  45.         super().__init__()
  46.         self.theta = theta
  47.         self.axes_dim = axes_dim
  48.     def forward(self, ids: torch.Tensor) -> torch.Tensor:
  49.         # 输入张量 ids 的最后一个维度的大小,即轴的数量[16,56,56]
  50.         n_axes = ids.shape[-1]
  51.         cos_out = []
  52.         sin_out = []
  53.         pos = ids.float()
  54.         is_mps = ids.device.type == "mps"
  55.         freqs_dtype = torch.float32 if is_mps else torch.float64
  56.         # 轴的数量必然和位置id的第二维列数相同
  57.         for i in range(n_axes):
  58.             cos, sin = get_1d_rotary_pos_embed(
  59.                 self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
  60.             )
  61.             #cos, sin除最后一维外保持原样,最后一维是axes_dim[i]
  62.             cos_out.append(cos)
  63.             sin_out.append(sin)
  64.         # 在最后一维拼接
  65.         freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
  66.         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
  67.         #形状:[seq_length+h/2*w/2 , sum(n_axes)]
  68.         return freqs_cos, freqs_sin
复制代码
       位置编码的输入是文本id和图像id,文本id为全零的形状为[seq_length,3]的张量,图像id天生过程为:
  1. def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
  2.     #创建一个高度和宽度均为height和width一半的全零张量,第三个维度为3,代表RGB图像的三个颜色通道。
  3.     latent_image_ids = torch.zeros(height // 2, width // 2, 3)
  4.     #更改第二个通道,为每一行的所有像素点添加一个递增的值
  5.     #latent_image_ids[..., 1]的形状是[height // 2, width // 2]。
  6.     #torch.arange(height // 2)[:, None]的形状是[height // 2, 1]。
  7.     latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
  8.     # 更改第三个通道,为每一列的所有像素点添加一个递增的值
  9.     # torch.arange(width // 2)[ None,:]的形状是[ 1,width//2]。
  10.     latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
  11.     latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
  12.     #原本形状为[h,w,3],形状变为[h*w,3]
  13.     latent_image_ids = latent_image_ids.reshape(
  14.         latent_image_id_height * latent_image_id_width, latent_image_id_channels
  15.     )
  16.     #[h*w,3]
  17.     return latent_image_ids.to(device=device, dtype=dtype)
复制代码
       二者在第0维拼接为位置编码的输入id,形状为[seq_length+h/2*w/2,3],该嵌入在注意力层被应用于q和k上,应用方式为:
  1. # 该函数将旋转嵌入应用于输入张量 x,通常表示注意力机制中的查询或键张量
  2. def apply_rotary_emb(
  3. #x: [batch_size, heads, seq_length+h//2*w//2, head_dim]
  4.     x: torch.Tensor,   
  5. freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],# 一个包含两个PyTorch张量的元组,这两个张量分别代表余弦和正弦的频率。
  6.     use_real: bool = True,
  7.     use_real_unbind_dim: int = -1,
  8. ) -> Tuple[torch.Tensor, torch.Tensor]:
  9.     if use_real:
  10.         cos, sin = freqs_cis  # [seq_length+h//2*w//2, sum(n_axes)]
  11.         # 最前面添加两维,以符合x的维度:[batch_size , heads , seq_length+h//2*w//2 , head_dim]
  12.         cos = cos[None, None] # [1 , 1 , seq_length+h//2*w//2 , sum(n_axes)]
  13.         sin = sin[None, None]
  14.         cos, sin = cos.to(x.device), sin.to(x.device)
  15.         #在最后一维解绑
  16.         if use_real_unbind_dim == -1:
  17.             #*x.shape[:-1]表示将x的前三维维度作为新的x的前三维维度,第四维是自适应,第五维是2,之后取消第五维,分为两部分
  18.             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
  19.             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
  20.         elif use_real_unbind_dim == -2:
  21.             # Used for Stable Audio
  22.             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
  23.             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
  24.         else:
  25.             raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
  26.         out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
  27.         return out
复制代码
时间-引导强度-池化文本融合嵌入:
        使用CombinedTimestepGuidanceTextProjEmbeddings:
  1. class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
  2.     def __init__(self, embedding_dim, pooled_projection_dim):
  3.         super().__init__()
  4.         #生成每个时间步和频率的组合并将其正弦值和余弦值拼接,最终形状为[batch_size,num_channels]
  5.         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
  6.         #embedding=inner_dim
  7.         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
  8.         self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
  9.         self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
  10.     def forward(self, timestep, guidance, pooled_projection):
  11.         timesteps_proj = self.time_proj(timestep)
  12.         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
  13.         guidance_proj = self.time_proj(guidance)
  14.         guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) 
  15.         time_guidance_emb = timesteps_emb + guidance_emb
  16.         pooled_projections = self.text_embedder(pooled_projection)
  17.         conditioning = time_guidance_emb + pooled_projections
  18.         #形状为[batch_size,num_channels]
  19.         return conditioning
  20. class Timesteps(nn.Module):
  21.     def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
  22.         super().__init__()
  23.         self.num_channels = num_channels
  24.         self.flip_sin_to_cos = flip_sin_to_cos
  25.         self.downscale_freq_shift = downscale_freq_shift
  26.         self.scale = scale
  27.     def forward(self, timesteps):
  28.         t_emb = get_timestep_embedding(
  29.             timesteps,
  30.             self.num_channels,
  31.             flip_sin_to_cos=self.flip_sin_to_cos,
  32.             downscale_freq_shift=self.downscale_freq_shift,
  33.             scale=self.scale,
  34.         )
  35.         return t_emb
  36. def get_timestep_embedding(
  37.     timesteps: torch.Tensor,
  38.     embedding_dim: int,
  39.     flip_sin_to_cos: bool = False,#否将嵌入顺序从 sin, cos 翻转为 cos, sin。
  40.     downscale_freq_shift: float = 1,#控制维度间频率变化的缩放因子。
  41.     scale: float = 1,#应用于嵌入的缩放因子。
  42.     max_period: int = 10000,#控制嵌入的最大频率。
  43. ):
  44.        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
  45.     half_dim = embedding_dim // 2
  46.     exponent = -math.log(max_period) * torch.arange(
  47.         start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
  48.     )
  49.     exponent = exponent / (half_dim - downscale_freq_shift)
  50.     emb = torch.exp(exponent)
  51.     #生成每个时间步和频率的组合:[batch_size, half_dim]
  52.     emb = timesteps[:, None].float() * emb[None, :]
  53.     # scale embeddings
  54.     emb = scale * emb
  55.     # concat sine and cosine embeddings
  56.     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
  57.     # flip sine and cosine embeddings
  58.     if flip_sin_to_cos:
  59.         emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
  60.     # zero pad
  61.     #如果 embedding_dim 是奇数,使用 torch.nn.functional.pad 在第二个维度上补零,以确保输出张量的形状是偶数维度。
  62.     if embedding_dim % 2 == 1:
  63.         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
  64.     return emb
复制代码
模态融合:
        在自注意力层中须要举行模态融合,本模子重要包括文本模态与图像模态,在模态融合前须要举行维度对齐,图像模态由初始的潜在图像[batch_size, num_channels_latents, height, width],颠末图像压缩,变为[batch_size, (h/2) * (w /2), in_channels],之后通过一个线性层形状变化为[batch_size (h/2) * (w /2) inner_dim],此中inner_dim= num_attention_heads * attention_head_dim。
        文本模态由形状[batch_size, sequence_length,d_dim]颠末一个线性层变化为[batch_size sequence_lengthinner_dim]
        本模子的模态融合具体实现为:先将潜在图像颠末映射得到query、key、value,之后调整形状为[batch_size,heads,seq_length,head_dim],同理将文本嵌入映射得到encoder_hidden_states_query_proj、encoder_hidden_states_key_proj、encoder_hidden_states_value_proj,之后调整形状为[batch_size,heads,seq_length,head_dim],最后举行拼接
  1. query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
  2. key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
  3. value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
复制代码
     之后应用位置嵌入,得到注意力机制的结果,调整形状为[batch_size, seq_length+h//2*w//2, inner_dim],取结果的第一维的前seq_length为encoder_hidden_states,取结果的第一维的后h//2*w//2为hidden_states。
张量形状变化:
        模态对齐:潜入图像形状变为[batch_size, (h//2) * (w// 2),inner_dim],融合时间步、引导嵌入、文本池化嵌入变为[batch_size,256],之后变为[batch_size,inner_dim],将高维文本嵌入形状变为[batch_size, sequence_length,inner_dim] ,融合文本id和图像id并举行旋转位置嵌入为[(h/2)*(w/2)+seq_length,16+56+56]。
        之后进入19个MMDiT模块循环:

MMDiT模块实行流程

       此中,y是时间步-引导嵌入-池化文本嵌入,c是高维文本嵌入,x是潜入图像,具体为:潜在图像hidden_states和高维文本嵌入encoder_hidden_states分别举行(归一化->注意力模块(模态融合)->残差连接->归一化->前馈网络->残差连接),循环后将二者在第二维拼接为hidden_states:[batch_size, (h//2) * (w// 2)+ sequence_length,inner_dim],这是新的hidden_states。
        在MMDiT中,须要注意的是自顺应归一化AdaLayerNormZero和归一化后利用天生的门控和偏移调整注意力权重:
  1. class AdaLayerNormZero(nn.Module):
  2.        def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
  3.         super().__init__()
  4.         if num_embeddings is not None:
  5.             self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
  6.         else:
  7.             self.emb = None
  8.         self.silu = nn.SiLU()
  9.         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
  10.         if norm_type == "layer_norm":
  11.             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
  12.         elif norm_type == "fp32_layer_norm":
  13.             self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
  14.         else:
  15.             raise ValueError(
  16.                 f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
  17.             )
  18.     def forward(
  19.         self,
  20.         x: torch.Tensor,
  21.         timestep: Optional[torch.Tensor] = None,
  22.         class_labels: Optional[torch.LongTensor] = None,
  23.         hidden_dtype: Optional[torch.dtype] = None,
  24.         emb: Optional[torch.Tensor] = None,
  25.     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  26.         if self.emb is not None:
  27.             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
  28.         emb = self.linear(self.silu(emb))
  29.         # 分割成6块:每一块形状【batch_size,embedding_dim(也就是inner_dim)】
  30.         #shift_msa:调整多头自注意力输出的偏移量(用法:+)
  31.         #scale_msa:缩放多头自注意力输出(用法:扩展维度后,*)
  32.         #gate_msa:控制多头自注意力输出的激活程度(用法:扩展维度后,*)
  33.         #shift_mlp:用于调整前馈网络(MLP)输出的偏移量
  34.         #scale_mlp:用于缩放前馈网络输出
  35.         #gate_mlp:用于控制前馈网络输出的激活程度
  36.         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
  37.         #scale_msa[:, None]形状【batch_size,1,embedding_dim】,*操作在第一维广播(即重复4096次变成【batch_size,4096,inner_dim】)
  38.         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
  39.         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
复制代码
        在AdaLayerNormZero类中,flux模子须要将输入的emb(即融合的时间步、引导嵌入、文本池化嵌入)和x(即潜在图像)举行处理,创建输出维度为6倍输入维度的线性层用于天生shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,其功能在代码中均有描述。
       在MMDiT模块中,当举行多头自注意力计算后,会天生attn_output和 context_attn_output,之后举行如下操纵(二者同理,可看图MMDiT模块实行流程):
  1. #gate_msa表示一个门控信号,用于控制attn_output中的信息流。
  2. attn_output = gate_msa.unsqueeze(1) * attn_output
  3. #残差连接
  4. hidden_states = hidden_states + attn_output
  5. #归一化,以后进行放缩与平移
  6. norm_hidden_states = self.norm2(hidden_states)
  7. norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
  8. #通过前馈网络并使用门控信号控制信息流
  9. ff_output = self.ff(norm_hidden_states)
  10. ff_output = gate_mlp.unsqueeze(1) * ff_output
  11. #残差连接
  12. hidden_states = hidden_states + ff_output
复制代码
        之后进入38个DiT模块:
        hidden_states举行(归一化->线性层->注意力模块->注意力结果与线性层结果第二维拼接[batch_size, (h//2) * (w// 2)+ sequence_length,inner_dim+ inner_dim*4]->线性层->残差连接),输出为hidden_states,形状为[batch_size, (h//2) * (w// 2)+ sequence_length,inner_dim]。
        在DiT模块中,与MMDiT模块差别的是,一个是输入只有hidden_states不用举行模态融合;其二是自顺应归一化AdaLayerNormZeroSingle只输出x和x, gate_msa,但是自顺应归一化的原理和流程基本同等;其三是用了并行注意力层:

并行自注意层( 来自Stable Diffusion 3「精神续作」FLUX.1 源码深度前瞻解读_flowmatcheulerdiscretescheduler-CSDN博客 ) 

         可以看代码发现注意力层和多层感知机的输入同等,可看作并行,对比MMDiT先辈入注意力层后进入多层感知机,是串行的:
  1. #用于后续残差连接
  2. residual = hidden_states
  3. #自适应归一化
  4. norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
  5. #输入norm_hidden_states进入多层感知机
  6. mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
  7. joint_attention_kwargs = joint_attention_kwargs or {}
  8. #输入norm_hidden_states进行多头自主力计算,体现并行性
  9. attn_output = self.attn(
  10.     hidden_states=norm_hidden_states,
  11.     image_rotary_emb=image_rotary_emb,
  12.     **joint_attention_kwargs,
  13. )
  14. #拼接
  15. hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
  16. gate = gate.unsqueeze(1)
  17. #使用门控信号控制信息流
  18. hidden_states = gate * self.proj_out(hidden_states)
  19. #残差连接
  20. hidden_states = residual + hidden_states
复制代码
        最后取hidden_states第一维的后(h//2) * (w// 2)部门举行归一化,之后颠末线性层作为本模块的输出noise_pred:[batch_size,(h//2) * (w// 2),patch_size * patch_size * out_channels](与一开始潜在图像hidden_states形状完全同等)
        去噪循环:先辈入FluxTransformer2Dmodel变更器得到noise_pred,之后根据调理器的step方法更新潜在图像latents。
2.5 AutoencoderKL:

   {
    "_class_name": "AutoencoderKL",
    "_diffusers_version": "0.30.0.dev0",
    "_name_or_path": "../checkpoints/flux-dev",
    "act_fn": "silu",
    "block_out_channels": [(每个编码块的输出通道
      128,
      256,
      512,
      512
    ],
    "down_block_types": [
      "DownEncoderBlock2D",
      "DownEncoderBlock2D",
      "DownEncoderBlock2D",
      "DownEncoderBlock2D"
    ],
    "force_upcast": true,
    "in_channels": 3,
    "latent_channels": 16,潜在空间特性维度
    "latents_mean": null,
    "latents_std": null,
    "layers_per_block": 2,
    "mid_block_add_attention": true,
    "norm_num_groups": 32,
    "out_channels": 3,输出通道
    "sample_size": 1024,输出尺寸
    "scaling_factor": 0.3611,
    "shift_factor": 0.1159,
    "up_block_types": [
      "UpDecoderBlock2D",
      "UpDecoderBlock2D",
      "UpDecoderBlock2D",
      "UpDecoderBlock2D"
    ],
    "use_post_quant_conv": false,
    "use_quant_conv": false
  }
         本模块是将颠末去噪天生的潜在图像举行解码,得到天生的图像表示。
        首先须要将得到的潜在图像latents解压缩,这时的形状为[batch_size, num_channels_latents, h/ vae_scale_factor , w/vae_scale_factor]
  1. def _unpack_latents(latents, height, width, vae_scale_factor):
  2.     #vae_scale_factor:用于调整潜在图像分辨率的因子
  3.     #num_patches:被分割的块数
  4.     batch_size, num_patches, channels = latents.shape
  5.     height = height // vae_scale_factor
  6.     width = width // vae_scale_factor
  7.     latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
  8.     #新形状:[batch_size,  num_channels_latents, height, 2 , width, 2 ]
  9.     latents = latents.permute(0, 3, 1, 4, 2, 5)
  10.     latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
  11.     return latents
复制代码
        之后根据设置的调治因子对潜在图像放缩宁静移,然后调用vae的decode方法解码。
        使用的类为:
  1. class Decoder(nn.Module):
  2.     def __init__(
  3.         self,
  4.         ch: int,# 解码器中的初始通道数为 128
  5.         out_ch: int,# 输出图像的通道数为 3
  6.         ch_mult: list[int],# 通道的倍增系数表,定义不同分辨率层的通道数扩展情况
  7.         num_res_blocks: int,# 每个分辨率层包含两个残差块,用于特征提取
  8.         in_channels: int,# 输入的图像通道数为 3
  9.         resolution: int,# 输入图像的目标分辨率
  10.         z_channels: int,#输入图像的通道数
  11.     ):
  12.         super().__init__()
  13.         self.ch = ch
  14.         self.num_resolutions = len(ch_mult)
  15.         self.num_res_blocks = num_res_blocks
  16.         self.resolution = resolution
  17.         self.in_channels = in_channels
  18.         self.ffactor = 2 ** (self.num_resolutions - 1)
  19.         # compute in_ch_mult, block_in and curr_res at lowest res
  20.         block_in = ch * ch_mult[self.num_resolutions - 1]
  21.         curr_res = resolution // 2 ** (self.num_resolutions - 1)
  22.         self.z_shape = (1, z_channels, curr_res, curr_res)
  23.         # 卷积
  24.         self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
  25.         # 中间层
  26.         self.mid = nn.Module()
  27.         self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
  28.         self.mid.attn_1 = AttnBlock(block_in)
  29.         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
  30.         #上采样
  31.         self.up = nn.ModuleList()
  32.         for i_level in reversed(range(self.num_resolutions)):
  33.             block = nn.ModuleList()
  34.             attn = nn.ModuleList()
  35.             block_out = ch * ch_mult[i_level]# #依次为512,512,256,128
  36.             for _ in range(self.num_res_blocks + 1):
  37.                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
  38.                 block_in = block_out
  39.             up = nn.Module()
  40.             up.block = block
  41.             up.attn = attn
  42.             if i_level != 0:
  43.                 up.upsample = Upsample(block_in)#除最顶层外h与w倍增
  44.                 curr_res = curr_res * 2
  45.             self.up.insert(0, up)  # 将该层插入到上采样模块的第一层
  46.         # end
  47.         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  48.         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
  49.     def forward(self, z: Tensor) -> Tensor:
  50.         # z to block_in
  51.         h = self.conv_in(z)
  52.         # middle
  53.         h = self.mid.block_1(h)
  54.         h = self.mid.attn_1(h)
  55.         h = self.mid.block_2(h)
  56.         # upsampling
  57.         for i_level in reversed(range(self.num_resolutions)):
  58.             for i_block in range(self.num_res_blocks + 1):
  59.                 h = self.up[i_level].block[i_block](h)
  60.                 if len(self.up[i_level].attn) > 0:
  61.                     h = self.up[i_level].attn[i_block](h)
  62.             if i_level != 0:
  63.                 h = self.up[i_level].upsample(h)
  64.         # end
  65.         h = self.norm_out(h)
  66.         h = swish(h)
  67.         h = self.conv_out(h)
  68.         return h
复制代码
        给定参数为:
  1. ae_params=AutoEncoderParams(
  2.     resolution=256,
  3.     in_channels=3,
  4.     ch=128,
  5.     out_ch=3,
  6.     ch_mult=[1, 2, 4, 4],
  7.     num_res_blocks=2,
  8.     z_channels=16,
  9.     scale_factor=0.3611,
  10.     shift_factor=0.1159,
  11. )
复制代码
        输入的潜在图像z形状为[batch_size, num_channels_latents, h/ vae_scale_factor , w/vae_scale_factor]:[1,16,128,128]
        首先辈行卷积操纵:conv_in将z从【1, 16, 128, 128】映射到【1, 512, 128, 128】。这是通过一个卷积层(nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1))实现的,此中z_channels = 16,block_in = 512。输出形状: 【1, 512, 128, 128】。
        之后颠末中间层,中间层包含两个ResnetBlock和一个注意力层,颠末处理形状稳定:【1, 512, 128, 128】。
        最后是上采样层,一共有四层,每层有3个ResnetBlock和一个注意力层,除了上采样层的第一层(最顶层)外,都有一个上采样模块:
  1. class Upsample(nn.Module):
  2.     def __init__(self, in_channels: int):
  3.         super().__init__()
  4.         self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  5.     def forward(self, x: Tensor):
  6.         x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
  7.         x = self.conv(x)
  8.         return x
复制代码
       在上采样层的潜在图像形状变化为:
       第1层(i_level = 3,最低分辨率层)
输入形状: 1, 512, 128, 128。该层有3个ResnetBlock:
第一个ResnetBlock:输入为【1, 512, 128, 128】,输出为【1, 512, 128, 128】。
第二个ResnetBlock:输入为【1, 512, 128, 128】,输出为【1, 512, 128, 128】。
第三个ResnetBlock:输入为【1, 512, 128, 128】,输出为【1, 512, 128, 128】。
上采样操纵:将分辨率从128×128扩大到256×256。
输出形状: 1, 512, 256, 256
        第2层(i_level = 2)
输入形状: 1, 512, 256, 256。该层有3个ResnetBlock:
第一个ResnetBlock:输入为【1, 512, 256, 256】,输出为【1, 512, 256, 256】。
第二个ResnetBlock:输入为【1, 512, 256, 256】,输出为【1, 512, 256, 256】。
第三个ResnetBlock:输入为【1, 512, 256, 256】,输出为【1, 512, 256, 256】。
上采样操纵:将分辨率从256×256上采样到512×512。
输出形状: 1, 512, 512, 512
        第3层(i_level = 1)
输入形状: 【1, 512, 512, 512】。该层有3个ResnetBlock:
第一个ResnetBlock:输入为【1, 512, 512, 512】,输出为【1, 256, 512, 512】。
第二个ResnetBlock:输入为【1, 256, 512, 512】,输出为【1,256, 512, 512】。
第三个ResnetBlock:输入为【1, 256, 512, 512】,输出为【1, 256, 512, 512】。
上采样操纵:将分辨率从512×512上采样到1024×1024。
输出形状: 1, 256, 1024, 1024
        第4层(i_level = 0,最高分辨率层)
输入形状: 【1, 128, 1024, 1024】。该层有3个ResnetBlock:
第一个ResnetBlock:输入为【1, 256, 1024, 1024】,输出为【1, 128, 1024, 1024】。
第二个ResnetBlock:输入为【1, 128, 1024, 1024】,输出为【1, 128, 1024, 1024】。
第三个ResnetBlock:输入为【1, 128, 1024, 1024】,输出为【1, 128, 1024, 1024】。
不举行上采样,由于分辨率已经是目标大小。
输出形状: 1, 128, 1024, 1024
       最后再举行卷积操纵,将z从【1, 128, 1024, 1024】映射到【1,3, 1024, 1024】。
在解码中用到的ResnetBlock如下,它用于在神经网络中引入跳跃连接使得网络可以或许更容易地举行梯度传播,
  1. class ResnetBlock(nn.Module):
  2.     def __init__(self, in_channels: int, out_channels: int):
  3.         super().__init__()
  4.         self.in_channels = in_channels
  5.         out_channels = in_channels if out_channels is None else out_channels
  6.         self.out_channels = out_channels
  7.         #组归一,不依赖于批量大小,因此适合较小批次或动态批次大小的训练
  8.         self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
  9.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
  10.         self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
  11.         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
  12.         if self.in_channels != self.out_channels:
  13.             self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
  14.     def forward(self, x):
  15.         h = x
  16.         h = self.norm1(h)
  17.         h = swish(h)
  18.         h = self.conv1(h)
  19.         h = self.norm2(h)
  20.         h = swish(h)
  21.         h = self.conv2(h)
  22.         #当输入输出通道数不一致时,通过该卷积层调整通道数使其一致,方便残差连接
  23.         if self.in_channels != self.out_channels:
  24.             x = self.nin_shortcut(x)
  25.         return x + h
复制代码
        完成解码以后,利用vae的图像天生方法将图像表示(形状【1,3, 1024, 1024】)转化为PIL图像。
        首先将图像表示由张量范例变为numpy数组,之后将图像的像素值从[0, 1]范围(对应于浮点型图像)转换到[0, 255]范围(对应于uint8范例的图像),然后取整为整数范例(uint8)。随后调用Image库的fromarray方法传入图像表示,得到一个图像对象,即为最后的结果:一个分辨率为1024x1024的RGB图像。
3.示例代码:

  1. import torch
  2. from diffusers import FluxPipeline
  3. pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
  4. pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
  5. prompt = "A boy with short hair, glasses, 1.8 meters tall, wearing a light green short down jacket, blue jeans, beige Martin boots, and a gray shirt inside the down jacket"
  6. image = pipe(
  7.     prompt,
  8.     height=1024,
  9.     width=1024,
  10.     guidance_scale=3.5,
  11.     num_inference_steps=50,
  12.     max_sequence_length=512,
  13.     generator=torch.Generator("cpu").manual_seed(0)
  14. ).images[0]
  15. image.save("flux-dev.png")
复制代码
运行结果:
                         
 

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

守听

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表