qidao123.com Tech Community - IT Enterprise Service Reviews · App Marketplace

Title: AIGC Notes -- Using LoRA with the PEFT Library

Author: 郭卫东    Time: 2024-11-11 23:27
1--Related Reading

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS
Three Applications of LoRA in Stable Diffusion: Principles and Code Examples
PEFT-LoRA
2--Basic Principles



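The original layers are frozen. The model is fine-tuned by adding two low-rank matrices, A and B, to each targeted layer and training only those two matrices. For a layer whose weight W has shape [output_dim, input_dim], A has shape [r, input_dim] and B has shape [output_dim, r], with r much smaller than either dimension. The adapted layer computes y = Wx + (lora_alpha / r) · BAx.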
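The forward pass can be written out without any framework. The sketch below uses NumPy and is an illustration of the formula, not PEFT's implementation. The function name `lora_forward` is hypothetical, and the Gaussian standard deviation for A (1/r) is an assumption. B starts at zero, so the adapted layer initially behaves exactly like the frozen one.

```python
import numpy as np

def lora_forward(x, W, A, B, lora_alpha, r):
    """Hypothetical helper: y = W x + (lora_alpha / r) * B A x for a batch x of shape (n, d_in)."""
    scaling = lora_alpha / r
    base = x @ W.T               # frozen path: (n, d_out)
    delta = (x @ A.T) @ B.T      # low-rank path: (n, r) -> (n, d_out)
    return base + scaling * delta

rng = np.random.default_rng(0)
d_in, d_out, r, lora_alpha = 64, 128, 32, 32   # same sizes as linear1 in the post

W = rng.standard_normal((d_out, d_in))         # frozen pretrained weight
A = rng.normal(0.0, 1.0 / r, size=(r, d_in))   # trainable; Gaussian init (std = 1/r is an assumption)
B = np.zeros((d_out, r))                       # trainable; zero init

x = rng.random((2, d_in))
y = lora_forward(x, W, A, B, lora_alpha, r)

# With B == 0 the LoRA path adds nothing, so the output equals the frozen layer's output
print(np.allclose(y, x @ W.T))                 # True

# Trainable parameters for this one layer: full fine-tuning vs. LoRA
print(W.size, A.size + B.size)                 # 8192 6144
```

For a 64×128 layer with r = 32 the saving is modest. It grows quickly as r becomes small relative to the layer widths.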
3--Simple Code

```python
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict

# Build a small model
class Simple_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(64, 128)
        self.linear2 = nn.Linear(128, 256)

    def forward(self, x: torch.Tensor):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

if __name__ == "__main__":
    # Initialize the original model
    origin_model = Simple_Model()
    # Configure LoRA
    model_lora_config = LoraConfig(
        r = 32,
        lora_alpha = 32,  # scaling = lora_alpha / r; lora_alpha is often set equal to r, i.e. scaling = 1
        init_lora_weights = "gaussian",  # initialization scheme for the LoRA matrices
        target_modules = ["linear1", "linear2"],  # add a LoRA adapter to these layers
        lora_dropout = 0.1
    )
    # Test data
    input_data = torch.rand(2, 64)
    origin_output = origin_model(input_data)
    # Weights of the original model
    origin_state_dict = origin_model.state_dict()
    # Two ways to build the LoRA model. Both modify origin_model in place (LoRA layers are
    # injected into it), which is why origin_output is computed beforehand.
    new_model1 = get_peft_model(origin_model, model_lora_config)
    new_model2 = LoraModel(origin_model, model_lora_config, "default")
    output1 = new_model1(input_data)
    output2 = new_model2(input_data)
    # lora_B is initialized to all zeros, so initially y = Wx + (lora_alpha / r) * BAx == Wx
    assert torch.allclose(origin_output, output1)
    assert torch.allclose(origin_output, output2)
    # Get the LoRA weights; the key names differ between the two models
    new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)
    new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)
    # origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]
    # new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]
    # new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]
    print("All Done!")
```
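It is worth checking how many parameters the config above adds to Simple_Model. The plain-Python sketch below assumes PEFT's default bias="none", so the adapters carry no bias terms. The helper names are hypothetical. On the real model, new_model1.print_trainable_parameters() reports PEFT's own count.

```python
def linear_params(d_in, d_out, bias=True):
    """Hypothetical helper: parameter count of an nn.Linear(d_in, d_out)-style layer."""
    return d_in * d_out + (d_out if bias else 0)

def lora_params(d_in, d_out, r):
    """Hypothetical helper: one LoRA adapter, A is [r, d_in] and B is [d_out, r], no bias."""
    return r * d_in + d_out * r

layers = {"linear1": (64, 128), "linear2": (128, 256)}   # Simple_Model from the post
r = 32

base = sum(linear_params(i, o) for i, o in layers.values())
lora = sum(lora_params(i, o, r) for i, o in layers.values())
print(base, lora)   # 41344 18432
```

With r = 32 the adapters are still almost half the size of this toy model. The savings only become large when r is far smaller than the layer widths, as in the big models LoRA is usually applied to.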
4--Saving and Merging Weights

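The core formula is new_weights = origin_weights + alpha * (BA). Here alpha plays the role of PEFT's scaling factor lora_alpha / r, which equals 1 for the r = lora_alpha = 32 configuration above. The code below saves the LoRA weights with diffusers and then merges them back into the layers by hand.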
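BA has the same shape as W, so after training the LoRA path can be folded into the weight and inference costs a single matmul. The minimal NumPy check below demonstrates that equivalence. Random matrices stand in for trained weights, and this is an illustration, not PEFT code.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, lora_alpha = 128, 256, 32, 32   # sizes of linear2 in the post
scaling = lora_alpha / r                        # 1.0 here, i.e. the "alpha" used in the merge code

W = rng.standard_normal((d_out, d_in))          # original weight: [output_dim, input_dim]
A = rng.standard_normal((r, d_in))              # lora_A.weight: [r, input_dim]
B = rng.standard_normal((d_out, r))             # lora_B.weight: [output_dim, r], non-zero as after training

W_merged = W + scaling * (B @ A)                # new_weights = origin_weights + scaling * (BA)

x = rng.random((4, d_in))
unmerged = x @ W.T + scaling * (x @ A.T) @ B.T  # frozen path plus a separate LoRA path
merged = x @ W_merged.T                         # one matmul with the merged weight
print(np.allclose(unmerged, merged))            # True
```

Once the delta has been added to W, the separate LoRA path must no longer run, or the update is applied twice. PEFT's merge_and_unload() on a LoRA model performs the merge and removes the adapter modules in one step.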
```python
    # Continues inside the __main__ block above; uses new_model1 and new_model1_lora_state_dict.
    import os
    from diffusers import StableDiffusionPipeline
    from safetensors import safe_open

    # Save the LoRA weights with diffusers' save_lora_weights
    save_path = "./"
    global_step = 0
    StableDiffusionPipeline.save_lora_weights(
        save_directory = save_path,
        unet_lora_layers = new_model1_lora_state_dict,  # saved keys get a "unet." prefix
        safe_serialization = True,
        weight_name = f"checkpoint-{global_step}.safetensors",
    )

    # Load the LoRA weights and merge them (modeled on Stable Diffusion's loader; a simpler version is possible)
    alpha = 1.  # merge factor; should equal lora_alpha / r, which is 32 / 32 = 1 here
    lora_path = os.path.join(save_path, f"checkpoint-{global_step}.safetensors")
    state_dict = {}
    with safe_open(lora_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    all_lora_weights = []
    for key in state_dict:
        # Only process lora_A (down) keys; derive the matching lora_B (up) key from each one
        if "lora_B." in key:
            continue
        up_key = key.replace(".lora_A.", ".lora_B.")
        model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")
        # Walk the attribute path, e.g. base_model -> model -> linear1
        layer_infos = model_key.split(".")[:-1]
        curr_layer = new_model1
        while len(layer_infos) > 0:
            curr_layer = getattr(curr_layer, layer_infos.pop(0))
        weight_down = state_dict[key].to(curr_layer.weight.device)   # lora_A: [r, input_dim]
        weight_up = state_dict[up_key].to(curr_layer.weight.device)  # lora_B: [output_dim, r]
        # Merge the LoRA parameters into the original weight: new_W = origin_W + alpha * (BA)
        delta = torch.mm(weight_up, weight_down)
        curr_layer.weight.data += alpha * delta
        all_lora_weights.append([model_key, delta.t()])
    # new_model1 still has its LoRA branch active, so once lora_B is non-zero this merge would
    # make its forward pass apply the update twice; merge into a model without LoRA layers instead.
    print('Load Lora Done')
```
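The key handling in the merge loop is plain string manipulation, so it can be exercised on its own. The sample key below is an assumption: diffusers' "unet." prefix followed by the PeftModel key format shown in the shape comments of section 3. The helper name `parse_lora_key` is hypothetical.

```python
def parse_lora_key(key):
    """Hypothetical helper: split a saved lora_A key into (lora_B key, merged-weight key, attribute path)."""
    up_key = key.replace(".lora_A.", ".lora_B.")
    model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")
    layer_path = model_key.split(".")[:-1]   # drop the trailing "weight"
    return up_key, model_key, layer_path

up_key, model_key, layer_path = parse_lora_key("unet.base_model.model.linear1.lora_A.weight")
print(up_key)      # unet.base_model.model.linear1.lora_B.weight
print(model_key)   # base_model.model.linear1.weight
print(layer_path)  # ['base_model', 'model', 'linear1']
```

replace("unet.", "") is a plain substring replacement, so it would also alter a module name that happened to contain "unet.". On Python 3.9+, str.removeprefix("unet.") is the stricter choice.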
5--Full Code

PEFT_LoRA

