OPPO open-sources multilingual Diffusion adapters: MultilingualSD3-adapter and ChineseFLUX.1-adapter

MultilingualSD3-adapter is a multilingual adapter tailored for SD3; it is derived from the ECCV 2024 paper PEA-Diffusion. ChineseFLUX.1-adapter is a multilingual adapter tailored for the FLUX.1 family of models; in principle it inherits ByT5's coverage of more than 100 languages, with additional optimization for Chinese. It, too, is derived from the ECCV 2024 paper PEA-Diffusion.

PEA-Diffusion

The authors first introduce Knowledge Distillation (KD), a process in which a smaller machine-learning model (the student) learns to mimic a larger, more complex model (the teacher). The goal is for the student's outputs or predictions to match the teacher's as closely as possible, which is described as forcing the student distribution to match the teacher distribution.
However, the authors also want the teacher and student to not only produce similar outputs but to process information in a similar way. This is where feature alignment comes in: the intermediate computations, or feature maps, of the two models should also be similar. They therefore introduce a technique that strengthens this feature alignment at the models' intermediate layers, which reduces the distribution shift and makes the student behave more like the teacher.
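A minimal sketch of what such a combined objective can look like in PyTorch (the loss form, weighting, and tensor names here are illustrative assumptions, not the paper's exact formulation):

import torch.nn.functional as F

def distillation_loss(student_out, teacher_out, student_feats, teacher_feats, feat_weight=0.5):
    # Output-level KD: push the student's prediction toward the teacher's.
    out_loss = F.mse_loss(student_out, teacher_out)
    # Feature-level alignment: match intermediate feature maps layer by layer.
    feat_loss = sum(F.mse_loss(s, t) for s, t in zip(student_feats, teacher_feats)) / len(student_feats)
    return out_loss + feat_weight * feat_loss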
This is where an adapter such as OPPOer/ChineseFLUX.1-adapter comes into play: it bridges the dimensional mismatch between models. Here two CLIP (Contrastive Language-Image Pre-training) text encoders are involved, one trained on English data and one on non-English data. The adapter is a small add-on module that aligns, or transforms, the features (intermediate computations) so that their dimensions match. It is trained to capture language-specific information and is designed to be efficient, with only about 6 million parameters.
In short, OPPOer/ChineseFLUX.1-adapter is a technique for aligning feature dimensions across different models (here, the English and non-English CLIP models) so that knowledge distillation can proceed and the student model's performance improves. It is a small but essential part of this larger system.

Paper: PEA-Diffusion: Parameter-Efficient Adapter with Knowledge Distillation in non-English Text-to-Image Generation
Code: https://github.com/OPPO-Mente-Lab/PEA-Diffusion
MultilingualSD3-adapter

We use the multilingual encoders umt5-xxl, Mul-OpenCLIP, and HunyuanDiT_CLIP, and we apply a reverse denoising process during distillation training.
import os
import torch
import torch.nn as nn
from typing import Any, Callable, Dict, List, Optional, Union
import inspect
from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from tqdm import tqdm
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel, BertModel, BertTokenizer
import open_clip


class MLP(nn.Module):
    """Projects the concatenated CLIP hidden states into the SD3 text-embedding space."""
    def __init__(self, in_dim=1024, out_dim=2048, hidden_dim=2048, out_dim1=4096, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )
        self.fc = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual

    def forward(self, x):
        x = self.layernorm(x)
        x = self.projector(x)
        x2 = nn.GELU()(x)
        x2 = self.fc(x2)
        return x2


class Transformer(nn.Module):
    """Maps umt5-xxl token embeddings to a pooled embedding and a per-token embedding for SD3."""
    def __init__(self, d_model, n_heads, out_dim1, out_dim2, num_layers=1) -> None:
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2048, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.linear1 = nn.Linear(d_model, out_dim1)
        self.linear2 = nn.Linear(d_model, out_dim2)

    def forward(self, x):
        x = self.transformer_encoder(x)
        x1 = self.linear1(x)
        x1 = torch.mean(x1, 1)  # pooled embedding
        x2 = self.linear2(x)    # per-token embedding
        return x1, x2


def image_grid(imgs, rows, cols):
    """Paste a list of PIL images into a rows x cols grid."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Set the scheduler timesteps from custom timesteps, custom sigmas, or a step count."""
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class StableDiffusionTest():
    def __init__(self, model_path, text_encoder_path, text_encoder_path1, text_encoder_path2, proj_path, proj_t5_path):
        super().__init__()
        # SD3 transformer, VAE and flow-matching scheduler
        self.transformer = SD3Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype).to(device)
        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device, dtype=dtype)
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        # Multilingual text encoders: umt5-xxl, the HunyuanDiT CLIP (BERT) encoder, and the XLM-RoBERTa OpenCLIP encoder
        self.text_encoder_t5 = T5EncoderModel.from_pretrained(text_encoder_path).to(device, dtype=dtype)
        self.tokenizer_t5 = T5Tokenizer.from_pretrained(text_encoder_path)
        self.text_encoder = BertModel.from_pretrained(f"{text_encoder_path1}/clip_text_encoder", False, revision=None).to(device, dtype=dtype)
        self.tokenizer = BertTokenizer.from_pretrained(f"{text_encoder_path1}/tokenizer")
        self.text_encoder2, _, _ = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path2)
        self.tokenizer2 = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
        self.text_encoder2.text.output_tokens = True
        self.text_encoder2 = self.text_encoder2.to(device, dtype=dtype)
        # Adapter projections trained with PEA-Diffusion-style distillation
        self.proj = MLP(2048, 2048, 2048, 4096, use_residual=False).to(device, dtype=dtype)
        self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))
        self.proj_t5 = Transformer(d_model=4096, n_heads=8, out_dim1=2048, out_dim2=4096).to(device, dtype=dtype)
        self.proj_t5.load_state_dict(torch.load(proj_t5_path, map_location="cpu"))

    def encode_prompt(self, prompt, device, do_classifier_free_guidance=True, negative_prompt=None):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        text_input_ids_t5 = self.tokenizer_t5(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = self.text_encoder_t5(text_input_ids_t5)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)
        encoder_hidden_states = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
        text_input_ids = self.tokenizer2(prompt).to(device)
        _, encoder_hidden_states2 = self.text_encoder2.encode_text(text_input_ids)
        # Concatenate the two CLIP streams and project everything into the SD3 embedding space
        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states2], dim=-1)
        encoder_hidden_states_t5 = text_embeddings[0]
        encoder_hidden_states = self.proj(encoder_hidden_states)
        add_text_embeds, encoder_hidden_states_t5 = self.proj_t5(encoder_hidden_states_t5.half())
        prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=-2)
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            else:
                uncond_tokens = negative_prompt
            text_input_ids_t5 = self.tokenizer_t5(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                add_special_tokens=False,
                return_tensors="pt",
            ).input_ids.to(device)
            text_embeddings = self.text_encoder_t5(text_input_ids_t5)
            text_inputs = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = text_inputs.input_ids.to(device)
            attention_mask = text_inputs.attention_mask.to(device)
            encoder_hidden_states = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
            text_input_ids = self.tokenizer2(uncond_tokens).to(device)
            _, encoder_hidden_states2 = self.text_encoder2.encode_text(text_input_ids)
            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states2], dim=-1)
            encoder_hidden_states_t5 = text_embeddings[0]
            encoder_hidden_states_uncond = self.proj(encoder_hidden_states)
            add_text_embeds_uncond, encoder_hidden_states_t5_uncond = self.proj_t5(encoder_hidden_states_t5.half())
            prompt_embeds_uncond = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_t5_uncond], dim=-2)
            prompt_embeds = torch.cat([prompt_embeds_uncond, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([add_text_embeds_uncond, add_text_embeds], dim=0)
        return prompt_embeds, pooled_prompt_embeds

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        latents = torch.randn(shape, generator=generator, dtype=dtype).to(device)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        prompt_embeds, pooled_prompt_embeds = self.encode_prompt(prompt, device)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # Denoising loop
        for i, t in tqdm(enumerate(timesteps)):
            if self.interrupt:
                continue
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            timestep = t.expand(latent_model_input.shape[0]).to(dtype=dtype)
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds.to(dtype=self.transformer.dtype),
                pooled_projections=pooled_prompt_embeds.to(dtype=self.transformer.dtype),
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
        if output_type == "latent":
            image = latents
        else:
            # Undo the SD3 latent scaling/shift before decoding with the VAE
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        return image


if __name__ == '__main__':
    device = "cuda"
    dtype = torch.float16
    text_encoder_path = 'google/umt5-xxl'
    text_encoder_path1 = "Tencent-Hunyuan/HunyuanDiT/t2i"
    text_encoder_path2 = 'laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/open_clip_pytorch_model.bin'
    model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
    proj_path = "OPPOer/MultilingualSD3-adapter/pytorch_model.bin"
    proj_t5_path = "OPPOer/MultilingualSD3-adapter/pytorch_model_t5.bin"
    sdt = StableDiffusionTest(model_path, text_encoder_path, text_encoder_path1, text_encoder_path2, proj_path, proj_t5_path)
    batch = 2
    height = 1024
    width = 1024
    while True:
        raw_text = input("\nPlease Input Query (stop to exit) >>> ")
        if not raw_text:
            print('Query should not be empty!')
            continue
        if raw_text == "stop":
            break
        images = sdt([raw_text] * batch, height=height, width=width)
        grid = image_grid(images, rows=1, cols=batch)
        grid.save("MultilingualSD3.png")
ChineseFLUX.1-adapter

The multilingual encoder byt5-xxl is used, and the teacher model during adaptation is FLUX.1-schnell. A reverse denoising process is again used for distillation training. In theory, the adapter can be applied to any FLUX.1-series model. A usage example follows.
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Projects ByT5 token embeddings to FLUX prompt embeddings (x2) and a pooled embedding (x1)."""
    def __init__(self, in_dim=4096, out_dim=4096, hidden_dim=4096, out_dim1=768, use_residual=True):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )
        self.fc = nn.Linear(out_dim, out_dim1)

    def forward(self, x):
        x = self.layernorm(x)
        x = self.projector(x)
        x2 = nn.GELU()(x)
        x1 = self.fc(x2)
        x1 = torch.mean(x1, 1)  # pooled embedding
        return x1, x2


dtype = torch.bfloat16
device = "cuda"
ckpt_id = "black-forest-labs/FLUX.1-schnell"
text_encoder_ckpt_id = 'google/byt5-xxl'

# ByT5 encoder plus the trained adapter projection
proj_t5 = MLP(in_dim=4672, out_dim=4096, hidden_dim=4096, out_dim1=768).to(device=device, dtype=dtype)
text_encoder_t5 = T5ForConditionalGeneration.from_pretrained(text_encoder_ckpt_id).get_encoder().to(device=device, dtype=dtype)
tokenizer_t5 = AutoTokenizer.from_pretrained(text_encoder_ckpt_id)

proj_t5_save_path = "diffusion_pytorch_model.bin"
state_dict = torch.load(proj_t5_save_path, map_location="cpu")
state_dict_new = {}
for k, v in state_dict.items():
    k_new = k.replace("module.", "")  # strip the "module." prefix left by parallel training
    state_dict_new[k_new] = v
proj_t5.load_state_dict(state_dict_new)

# FLUX pipeline without its original text encoders and VAE; the adapter supplies the embeddings
pipeline = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, text_encoder_2=None,
    tokenizer=None, tokenizer_2=None, vae=None,
    torch_dtype=torch.bfloat16
).to(device)
vae = AutoencoderKL.from_pretrained(
    ckpt_id,
    subfolder="vae",
    torch_dtype=torch.bfloat16
).to(device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

while True:
    raw_text = input("\nPlease Input Query (stop to exit) >>> ")
    if not raw_text:
        print('Query should not be empty!')
        continue
    if raw_text == "stop":
        break
    with torch.no_grad():
        text_inputs = tokenizer_t5(
            raw_text,
            padding="max_length",
            max_length=256,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = text_encoder_t5(text_inputs)[0]
        pooled_prompt_embeds, prompt_embeds = proj_t5(text_embeddings)
        height, width = 1024, 1024
        latents = pipeline(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=4, guidance_scale=0,
            height=height, width=width,
            output_type="latent",
        ).images
        # Unpack and rescale the latents, then decode with the FLUX VAE
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        image = image_processor.postprocess(image, output_type="pil")
        image[0].save("ChineseFLUX.jpg")
MultilingualOpenFLUX.1

OpenFLUX.1 is a fine-tune of the FLUX.1-schnell model that has had the distillation trained out of it. Be sure to update the path to the downloaded fast-lora.safetensors in the code below.
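If the LoRA file is not already on disk, one way to fetch it is with huggingface_hub (a minimal sketch; the returned local path is what the load_lora_weights call below should point at):

from huggingface_hub import hf_hub_download

# Download the fast LoRA released with OpenFLUX.1 and print its local cache path.
lora_path = hf_hub_download(repo_id="ostris/OpenFLUX.1", filename="openflux1-v0.1.0-fast-lora.safetensors")
print(lora_path)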
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Projects ByT5 token embeddings to FLUX prompt embeddings (x2) and a pooled embedding (x1)."""
    def __init__(self, in_dim=4096, out_dim=4096, hidden_dim=4096, out_dim1=768, use_residual=True):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )
        self.fc = nn.Linear(out_dim, out_dim1)

    def forward(self, x):
        x = self.layernorm(x)
        x = self.projector(x)
        x2 = nn.GELU()(x)
        x1 = self.fc(x2)
        x1 = torch.mean(x1, 1)  # pooled embedding
        return x1, x2


dtype = torch.bfloat16
device = "cuda"
ckpt_id = "ostris/OpenFLUX.1"
text_encoder_ckpt_id = 'google/byt5-xxl'

# ByT5 encoder plus the trained adapter projection
proj_t5 = MLP(in_dim=4672, out_dim=4096, hidden_dim=4096, out_dim1=768).to(device=device, dtype=dtype)
text_encoder_t5 = T5ForConditionalGeneration.from_pretrained(text_encoder_ckpt_id).get_encoder().to(device=device, dtype=dtype)
tokenizer_t5 = AutoTokenizer.from_pretrained(text_encoder_ckpt_id)

proj_t5_save_path = "diffusion_pytorch_model.bin"
state_dict = torch.load(proj_t5_save_path, map_location="cpu")
state_dict_new = {}
for k, v in state_dict.items():
    k_new = k.replace("module.", "")  # strip the "module." prefix left by parallel training
    state_dict_new[k_new] = v
proj_t5.load_state_dict(state_dict_new)

# OpenFLUX.1 pipeline without its original text encoders and VAE; the adapter supplies the embeddings
pipeline = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, text_encoder_2=None,
    tokenizer=None, tokenizer_2=None, vae=None,
    torch_dtype=torch.bfloat16
).to(device)
# Update this path to the downloaded fast-lora.safetensors
pipeline.load_lora_weights("ostris/OpenFLUX.1/openflux1-v0.1.0-fast-lora.safetensors")
vae = AutoencoderKL.from_pretrained(
    ckpt_id,
    subfolder="vae",
    torch_dtype=torch.bfloat16
).to(device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

while True:
    raw_text = input("\nPlease Input Query (stop to exit) >>> ")
    if not raw_text:
        print('Query should not be empty!')
        continue
    if raw_text == "stop":
        break
    with torch.no_grad():
        text_inputs = tokenizer_t5(
            raw_text,
            padding="max_length",
            max_length=256,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = text_encoder_t5(text_inputs)[0]
        pooled_prompt_embeds, prompt_embeds = proj_t5(text_embeddings)
        height, width = 1024, 1024
        latents = pipeline(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=4, guidance_scale=0,
            height=height, width=width,
            output_type="latent",
        ).images
        # Unpack and rescale the latents, then decode with the VAE
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        image = image_processor.postprocess(image, output_type="pil")
        image[0].save("ChineseOpenFLUX.jpg")