勿忘初心做自己 发表于 2024-9-22 23:02:50

Datawhale X 魔搭 AI夏令营第四期 AIGC方向 学习笔记(一)

本期主要任务是相识AI文生图的原理并进行相关实践
下面是对baseline部分代码的功能先容:
安装Data-juicere和DiffSynth-Studio

!pip install simple-aesthetics-predictor
!pip install -v -e data-juicer
!pip uninstall pytorch-lightning -y
!pip install peft lightning pandas torchvision
!pip install -e DiffSynth-Studio 基本的通过pip安装,"!"控制语句在终端进行操作。simple-aesthetics-predictor 这个包,参考pypi上项目描述,是一个基于CLIP的美学猜测器,用于猜测图片的美学质量。"-v"、"-e"命令用于设定安装模式. data-juicer ,参考github上的原项目Readme文件,是一个“用于大语言模子的一站式数据处理系统”。peft 与参数高效微调相关,lightning 是用于简化训练过程的库,pandas和torchvision就不多说了。DiffSynth-Studio 则是一种用于实现图片和视频风格转换的引擎。
下载数据集

从modelscope上下载某个数据集,指定了目的数据集的路径,子集名称,拆分部分(训练集)和下载完成后的缓存目次。
生存数据会合的图片和元数据

os.makedirs("./data/lora_dataset/train", exist_ok=True)
os.makedirs("./data/data-juicer/input", exist_ok=True)
with open("./data/data-juicer/input/metadata.jsonl", "w") as f:
    for data_id, data in enumerate(tqdm(ds)):
      image = data["image"].convert("RGB")
      image.save(f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg")
      metadata = {"text": "二次元", "image": }
      f.write(json.dumps(metadata))
      f.write("\n") 这部分主要进行对下载得到的数据集的遍历,将此中的图片转化成RGB格式后生存到指定路径(../data/lora_dataset/train)。别的创建由文本和对应图片构成的字典作为元数据写入json文件生存
数据处理

在变量 data_juicer_config 中定义了数据处理的各项配置信息,并将其写入yaml文件中。之后调用dj-process命令开启数据处理,并通过该配置文件传入相关参数。
生存处理好的数据

主要是从 result.jsonl 文件中进行文本和图像的生存,并将文件名和文本信息存至csv文件中
训练模子 

from diffsynth import download_models
download_models(["Kolors", "SDXL-vae-fp16-fix"])

!python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py -h 下载模子;终端查察训练脚本输入参数
cmd = """
python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--lora_rank 16 \
--lora_alpha 4.0 \
--dataset_path data/lora_dataset_processed \
--output_path ./models \
--max_epochs 1 \
--center_crop \
--use_gradient_checkpointing \
--precision "16-mixed"
""".strip()

os.system(cmd)  这一段定义了训练过程需要在终端实行的命令,主要包含以下内容:指定了预训练需要的Unet模子路径、文本编码器模子路径和fp16VAE模子路径;指定lora的品级和alpha值相关参数;指定命据集路径、输出路径;指定最大训练轮数,利用中心裁剪、梯度查抄点,和精度参数。
加载模子 
def load_lora(model, lora_rank, lora_alpha, lora_path):
    lora_config = LoraConfig(
      r=lora_rank,
      lora_alpha=lora_alpha,
      init_lora_weights="gaussian",
      target_modules=["to_q", "to_k", "to_v", "to_out"],
    )
    model = inject_adapter_in_model(lora_config, model)
    state_dict = torch.load(lora_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model

# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                           file_path_list=[
                                 "models/kolors/Kolors/text_encoder",
                                 "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
                                 "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
                           ])
pipe = SDXLImagePipeline.from_model_manager(model_manager)

# Load LoRA
pipe.unet = load_lora(
    pipe.unet,
    lora_rank=16, # This parameter should be consistent with that in your training script.
    lora_alpha=2.0, # lora_alpha can control the weight of LoRA.
    lora_path="models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt"
)  load_lora 函数加载loRA模子并进行相关参数配置(
model:要注入 LoRA 适配器的原始模子。
lora_rank:LoRA 适配器的秩,用于控制适配器的复杂度。
lora_alpha:LoRA 适配器的缩放因子,用于控制其权重。
lora_path:包含预训练 LoRA 权重的文件路径。) 
利用 inject_adapt_in_model 将loRA注入原始模子,加载loRA预训练的权重字典并应用至模子中。
后续部分并不熟悉各实例的作用,暂且一放。
生成图像

torch.manual_seed(0)
image = pipe(
    prompt="二次元,一个紫色短发小女孩,在家中沙发上坐着,双手托着腮,很无聊,全身,粉色连衣裙",
    negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
    cfg_scale=4,
    num_inference_steps=50, height=1024, width=1024,
)
image.save("1.jpg") 设置随机种子值使随机操作具有可重复性。利用pipe对象进行生成,给出正负面提示词,配置标准,推理步数和图像尺寸参数。
在实验了自己的一系列提示词后得到如下八张图,内容类似baseline原本给的,主线换成了足球:
主要存在的标题:部分画风不统一;部分图细节不佳;对“足球”一词的表现错误(应该是中文输入翻译标题);部分提示词的信息未有用表现出来
https://i-blog.csdnimg.cn/direct/59b06b42ef36444f89dd904c6406f5cf.jpeghttps://i-blog.csdnimg.cn/direct/2435d22886cb407b99cc7231e18f8436.jpeghttps://i-blog.csdnimg.cn/direct/e0667de906084008afac373fbd792512.jpeghttps://i-blog.csdnimg.cn/direct/5c9c03d9b6a945669ec260ad7dfab8b4.jpeghttps://i-blog.csdnimg.cn/direct/a8270f556dba4497afe40116a3a2c472.jpeghttps://i-blog.csdnimg.cn/direct/a2993cd7e28940289c67e4ac6b66aa66.jpeghttps://i-blog.csdnimg.cn/direct/1446d81d50404925a5739a12b75c23f1.jpeghttps://i-blog.csdnimg.cn/direct/b941b39d15154d218c5d8d149ca6455f.jpeg

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: Datawhale X 魔搭 AI夏令营第四期 AIGC方向 学习笔记(一)