Datawhale X 魔搭 AI夏令营 “AIGC”方向 task1

[复制链接]
发表于 2026-2-4 17:05:06 | 显示全部楼层 |阅读模式
一、任务要求

task1 的任务和上一期的类似,都是跑通给出的代码即可,没有太大难度。
具体要求是练习 Lora 模子,实现文生图,额外的要求是8张图片必须构成一个连贯的故事,必要肯定的“写小作文”本领。
二、代码剖析


  • 下载数据集
    这一步不消分析,就是生存图片和元数据。
  • 数据处理处罚
    标题已经提示了利用 data-juicer 处理处罚数据。之以是选择 data-juicer,大概由于有主动化、批量处理处罚、同一设置管理以及允许多核并行的上风。
    代码实在很简单,只必要编写设置文件,就可以主动处理处罚数据。具体表明如下:
    1. # data_juicer 配置
    2. data_juicer_config = """
    3. # global parameters(全局参数)
    4. project_name: 'data-process'
    5. dataset_path: './data/data-juicer/input/metadata.jsonl'  # path to your dataset directory or file
    6. # 指定用于处理数据集的子进程数目
    7. np: 4  # number of subprocess to process your dataset
    8. # 定义了数据集中用于存储文本的键
    9. text_keys: 'text'
    10. # 定义了数据集中用于存储图像的键
    11. image_key: 'image'
    12. # 定义了图像的特殊标记,用于标记或识别图像数据
    13. image_special_token: '<__dj__image>'
    14. # 指定处理后的数据将导出的路径
    15. export_path: './data/data-juicer/output/result.jsonl'
    16. # process schedule
    17. # a list of several process operators with their arguments(定义了处理数据的操作列表,这里包含两个操作)
    18. process:
    19.     #  过滤图像尺寸
    20.     - image_shape_filter:
    21.         # 要求图像的最小宽度和高度分别为 1024 像素
    22.         min_width: 1024
    23.         min_height: 1024
    24.         # 选择 any 表示只要宽度或高度任意一项满足条件即通过
    25.         any_or_all: any
    26.     # 过滤图像的宽高比
    27.     - image_aspect_ratio_filter:
    28.         # 要求宽高比在 0.5 到 2.0 之间
    29.         min_ratio: 0.5
    30.         max_ratio: 2.0
    31.         # 选择 any 表示只要宽高比任意一项满足条件即通过
    32.         any_or_all: any
    33. """
    34. # 将配置写入文件
    35. with open("data/data-juicer/data_juicer_config.yaml", "w") as file:
    36.     file.write(data_juicer_config.strip())
    37. # 这行代码使用了 dj-process 命令,并指定了配置文件路径,用来启动数据处理任务
    38. # 其中 dj-process 是一个命令行工具,读取指定的配置文件,根据其中定义的参数和操作来处理数据集
    39. !dj-process --config data/data-juicer/data_juicer_config.yaml
    复制代码
  • 练习模子
    这一部门的焦点内容是通过一句下令行来实行的,具体代码如下:
    1. import os
    2. cmd = """
    3. python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py \
    4.   --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
    5.   --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
    6.   --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
    7.   --lora_rank 16 \
    8.   --lora_alpha 4.0 \
    9.   --dataset_path data/lora_dataset_processed \
    10.   --output_path ./models \
    11.   --max_epochs 1 \
    12.   --center_crop \
    13.   --use_gradient_checkpointing \
    14.   --precision "16-mixed"
    15. """.strip()
    16. os.system(cmd)
    复制代码
    参数的寄义:
    python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py:运行指定的 Python 脚本。
    --pretrained_unet_path:指定了预练习的 UNet 模子的路径。
    --pretrained_text_encoder_path:指定了预练习的文本编码器路径。
    --pretrained_fp16_vae_path:指定了预练习的 16 位肴杂精度 VAE 模子的路径。
    --lora_rank:决定了新到场的低秩矩阵的维度。它控制了新参数矩阵的复杂度,较高的 lora_rank 值增长了模子的表达本领,但也会增长盘算复杂度和内存利用;较低的 lora_rank 值则大概镌汰表达本领,但会更加高效。
    --lora_alpha:低秩矩阵更新的缩放系数,用来平衡原始模子参数和 LoRA 新到场的参数之间的影响。较高的 lora_alpha 值会放大 LoRA 参数的影响力,大概导致模子更快收敛或更容易过拟合;较低的值则相反。
    --dataset_path:指定了练习数据集的路径。
    --output_path:指定了练习输出模子的生存路径。
    --max_epochs:设置了最大练习轮数。
    --center_crop:启用居中裁剪图像的选项。
    --use_gradient_checkpointing:启用了梯度查抄点,用于镌汰内存利用。
    --precision "16-mixed":利用肴杂精度练习,以进步盘算服从。

    具体的模子代码显然就在 train_kolors_lora.py 中了,简单分析一下。
    1. from diffsynth import ModelManager, SDXLImagePipeline
    2. from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
    3. import torch, os, argparse
    4. # 设置环境变量 TOKENIZERS_PARALLELISM 为 True,以允许多线程并行化分词处理
    5. os.environ["TOKENIZERS_PARALLELISM"] = "True"
    6. # LightningModel 继承自 LightningModelForT2ILoRA,这是一个用于文本到图像任务的 LoRA 训练模型类
    7. class LightningModel(LightningModelForT2ILoRA):
    8.     # 初始化:
    9.     # torch_dtype 指定了数据类型(如 float16 或 float32),主要用于控制模型的精度。
    10.     # pretrained_weights 是一个包含预训练模型路径的列表,用于加载预训练的 UNet、文本编码器、VAE 等模型。
    11.     # learning_rate 是学习率,用于训练过程中的优化。
    12.     # use_gradient_checkpointing 控制是否使用梯度检查点来节省内存。
    13.     # lora_rank 和 lora_alpha 是 LoRA 相关的参数,用于调整低秩矩阵的秩和缩放系数。
    14.     # lora_target_modules 指定了哪些模块(如 to_q、to_k 等)将应用 LoRA 技术。
    15.     def __init__(
    16.         self,
    17.         torch_dtype=torch.float16, pretrained_weights=[],
    18.         learning_rate=1e-4, use_gradient_checkpointing=True,
    19.         lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
    20.     ):
    21.         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
    22.         # Load models
    23.         # 模型管理和加载:
    24.         # ModelManager 用于管理模型加载和设备分配
    25.         # SDXLImagePipeline 用于定义图像生成的流水线
    26.         # self.pipe.scheduler.set_timesteps(1100) 设置了调度器的时间步数
    27.         model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
    28.         model_manager.load_models(pretrained_weights)
    29.         self.pipe = SDXLImagePipeline.from_model_manager(model_manager)
    30.         self.pipe.scheduler.set_timesteps(1100)
    31.         # Convert the vae encoder to torch.float16
    32.         # 参数调整和冻结:
    33.         # 将 VAE 编码器转换为指定的数据类型
    34.         # 冻结模型的部分参数,以减少训练过程中的计算需求
    35.         # 将 LoRA 层添加到指定的模型模块中
    36.         self.pipe.vae_encoder.to(torch_dtype)
    37.         self.freeze_parameters()
    38.         self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
    39. # 解析命令行参数
    40. def parse_args():
    41.     # 定义了一个 argparse.ArgumentParser 对象,用于解析命令行参数
    42.     parser = argparse.ArgumentParser(description="Simple example of a training script.")
    43.     # 下面是若干必需的参数
    44.     parser.add_argument(
    45.         "--pretrained_unet_path",
    46.         type=str,
    47.         default=None,
    48.         required=True,
    49.         help="Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.",
    50.     )
    51.     parser.add_argument(
    52.         "--pretrained_text_encoder_path",
    53.         type=str,
    54.         default=None,
    55.         required=True,
    56.         help="Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.",
    57.     )
    58.     parser.add_argument(
    59.         "--pretrained_fp16_vae_path",
    60.         type=str,
    61.         default=None,
    62.         required=True,
    63.         help="Path to pretrained model (VAE). For example, `models/kolors/Kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.",
    64.     )
    65.     parser.add_argument(
    66.         "--lora_target_modules",
    67.         type=str,
    68.         default="to_q,to_k,to_v,to_out",
    69.         help="Layers with LoRA modules.",
    70.     )
    71.     parser = add_general_parsers(parser)
    72.     args = parser.parse_args()
    73.     return args
    74. # 启动训练
    75. if __name__ == '__main__':
    76.     args = parse_args()
    77.     model = LightningModel(
    78.         torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
    79.         pretrained_weights=[
    80.             args.pretrained_unet_path,
    81.             args.pretrained_text_encoder_path,
    82.             args.pretrained_fp16_vae_path,
    83.         ],
    84.         learning_rate=args.learning_rate,
    85.         use_gradient_checkpointing=args.use_gradient_checkpointing,
    86.         lora_rank=args.lora_rank,
    87.         lora_alpha=args.lora_alpha,
    88.         lora_target_modules=args.lora_target_modules
    89.     )
    90.     launch_training_task(model, args)
    复制代码
  • 加载模子
    顾名思义,将前面练习好的模子加载,用于下面天生图像的任务。
    1. def load_lora(model, lora_rank, lora_alpha, lora_path):
    2.     # 配置 LoRA 模块,包括秩、缩放系数、初始化方式等
    3.     lora_config = LoraConfig(
    4.         r=lora_rank,
    5.         lora_alpha=lora_alpha,
    6.         init_lora_weights="gaussian",
    7.         target_modules=["to_q", "to_k", "to_v", "to_out"],
    8.     )
    9.     # 将 LoRA 模块注入到指定的模型中
    10.     model = inject_adapter_in_model(lora_config, model)
    11.     # 加载训练好的 LoRA 权重
    12.     state_dict = torch.load(lora_path, map_location="cpu")
    13.     # 将加载的权重应用到模型中,并返回更新后的模型
    14.     model.load_state_dict(state_dict, strict=False)
    15.     return model
    16. # Load models
    17. # 管理和加载多个模型组件,比如文本编码器(Text Encoder)、UNet、VAE 等
    18. model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
    19.                              file_path_list=[
    20.                                  "models/kolors/Kolors/text_encoder",
    21.                                  "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
    22.                                  "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
    23.                              ])
    24. pipe = SDXLImagePipeline.from_model_manager(model_manager)
    25. # Load LoRA
    26. # 加载 LoRA 模块到 UNet 模型
    27. # 两者的关系:LoRA 提供了一种高效的方法来微调大型预训练模型,比如 UNet,使其能够更好地适应特定任务或数据集
    28. pipe.unet = load_lora(
    29.     pipe.unet,
    30.     lora_rank=16, # This parameter should be consistent with that in your training script.(参数需要与训练时的一致,以确保模型的行为一致)
    31.     lora_alpha=2.0, # lora_alpha can control the weight of LoRA.
    32.     lora_path="models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt"
    33. )
    复制代码
  • 天生图像
    这里指的是文生图,8张图片的天生方式是一样的,以第一张为例,代码如下:
    1. # 设置随机数生成器的种子,保证结果的可重复性
    2. torch.manual_seed(0)
    3. image = pipe(
    4.     # 模型生成图像的描述性文本(即正向提示)
    5.     prompt="水墨画风格,一个黑头发的青年男子,站在窗户前,双手背在身后,很惆怅,全身,白色长袍",
    6.     # 反向提示,指定了生成图像时需要避免的特征
    7.     negative_prompt="丑陋、变形、嘈杂、模糊、低对比度、现代",
    8.     # 指导生成模型在多大程度上遵循提示的配置参数;值越大,模型越倾向于严格遵循 prompt 提示,可能会导致图像风格更加确定但多样性降低;值较小时,生成的图像可能更具多样性但可能偏离提示
    9.     cfg_scale=4,
    10.     # 生成图像时进行的推理步骤数,数值越高,图像质量通常越高,但生成时间也会增加;指定生成图像的分辨率为 1024x1024 像素
    11.     num_inference_steps=50, height=1024, width=1024,
    12. )
    13. # 保存图片
    14. image.save("1.jpg")
    复制代码
  • 展示图像
    将多个图像拼接成一张大图,并调解其尺寸。
    1. images = [np.array(Image.open(f"{i}.jpg")) for i in range(1, 9)]
    2. # 拼接图像,每两张拼接为一行,形成4*2的大图
    3. image = np.concatenate([
    4.     np.concatenate(images[0:2], axis=1),
    5.     np.concatenate(images[2:4], axis=1),
    6.     np.concatenate(images[4:6], axis=1),
    7.     np.concatenate(images[6:8], axis=1),
    8. ], axis=0)
    9. # 调整大图尺寸为宽1024,高2048
    10. image = Image.fromarray(image).resize((1024, 2048))
    11. image
    复制代码
三、结果展示

我原来想以某个汗青故事为原型,风格是古风,但是我发现给定的练习集险些都是二次元的图片,我以为结果不会太好,不外照旧挺有感觉的。固然我自己爬虫搞了几张古风图片,之后试试看,盼望能进步图片质量。
照旧展示一张吧!


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!qidao123.com:ToB企服之家,中国第一个企服评测及软件市场,开放入驻,技术点评得现金

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表