ToB企服应用市场:ToB评测及商务社交产业平台

标题: Stable Diffusion的微调方法原理总结 [打印本页]

作者: 张春    时间: 2024-8-27 20:30
标题: Stable Diffusion的微调方法原理总结
目录
1、Textural Inversion(简易)
2、DreamBooth(完整)
3、LoRA(灵巧)
4、ControlNet(彻底)
5、其他

1、Textural Inversion(简易)


        不改变网络结构,仅改变CLIP中token embedding的字典。在字典中新增一个伪词的embedding,fine-tune这个embedding的值。其他全部可调参数都冻结。
优点:练习量极小,需要的素材就是一张图。完全不改变神经网络中的任何参数。
缺点:效果一样平常。
TI的简便激发了很多研究者的灵感,基于TI思路的研究出现了很多。
2、DreamBooth(完整)


        详细做法是,加入一个新词(sks)代表subject,embedding初始值继承原范例的词的embedding。调解了模子中全部可调参数,彻底的让模子学会subject。损失函数加入了监督功能,去监控漂移征象,防止劫难性遗忘“学会新的忘了旧的”。

在LoRA出现前,练习DreamBooth是潮流,但代价较大。
3、LoRA(灵巧)



        LoRA的网络是一种additional network,LoRA练习不改变根本模子的任何参数,只对附加网络内部参数进行调解。在生成图像时,附加网络输出与原网络输出融合,从而改变生成效果。
        由于LoRA是将矩阵压缩到低秩后练习,以是LoRA网络的参数量很小(千分之一),练习速率快。实验发现,低维矩阵对高维矩阵的替代损失不大。以是即便练习的矩阵小,练习效果仍旧很好,已成为一种customization image generation范式。LoRA厥后在结构上改进出差别的版本,比方LoHA,LyCORIS等。

LoRA详解:https://zhuanlan.zhihu.com/p/632159261
Self-Attention的LoRA微调代码:GitHub - owenliang/pytorch-diffusion: pytorch复现stable diffusion
代码分析:
用于更换的线性层 (Wq, Wk, Wv矩阵):
  1. class CrossAttention(nn.Module):
  2.     def __init__(self,channel,qsize,vsize,fsize,cls_emb_size):
  3.         super().__init__()
  4.         # Wq, Wk, Wv 矩阵使用LoRA微调降低参数量, W + WA * WB
  5.         self.w_q=nn.Linear(channel,qsize)
  6.         self.w_k=nn.Linear(cls_emb_size,qsize)
  7.         self.w_v=nn.Linear(cls_emb_size,vsize)
  8.         self.softmax=nn.Softmax(dim=-1)
  9.         self.z_linear=nn.Linear(vsize,channel)
  10.         self.norm1=nn.LayerNorm(channel)
  11.         # feed-forward结构
  12.         self.feedforward=nn.Sequential(
  13.             nn.Linear(channel,fsize),
  14.             nn.ReLU(),
  15.             nn.Linear(fsize,channel)
  16.         )
  17.         self.norm2=nn.LayerNorm(channel)
复制代码
找到模子中全部的Wq, Wk, Wv线性层并将其更换为Lora:
  1. if __name__=='__main__':   # 加入LoRA微调的训练过程
  2.     # 预训练模型
  3.     model=torch.load('model.pt')
  4.     # 向nn.Linear层注入Lora
  5.     for name,layer in model.named_modules():
  6.         name_cols=name.split('.')
  7.         # 过滤出cross attention使用的linear权重
  8.         filter_names=['w_q','w_k','w_v']
  9.         if any(n in name_cols for n in filter_names) and isinstance(layer,nn.Linear):   # module名字中存在w_q, w_k, w_v且属于线性层
  10.             # print(name)   # enc_convs.0.crossattn.w_q,enc_convs.0.crossattn.w_k,enc_convs.0.crossattn.w_v,……
  11.             inject_lora(model,name,layer)
复制代码
Lora详细实现与更换过程:
  1. # Lora实现,封装linear,替换到父module里
  2. class LoraLayer(nn.Module):
  3.     def __init__(self,raw_linear,in_features,out_features,r,alpha):
  4.         super().__init__()
  5.         self.r=r   # 秩数
  6.         self.alpha=alpha   # LoRA分支的权重比例系数
  7.         self.lora_a=nn.Parameter(torch.empty((in_features,r)))   # 可训练参数
  8.         self.lora_b=nn.Parameter(torch.zeros((r,out_features)))
  9.    
  10.         nn.init.kaiming_uniform_(self.lora_a,a=math.sqrt(5))   # WA 矩阵参数需要进行初始化
  11.         self.raw_linear=raw_linear   # 原始模型权重 W
  12.    
  13.     def forward(self,x):    # x:(batch_size,in_features)
  14.         raw_output=self.raw_linear(x)   
  15.         lora_output=x@((self.lora_a@self.lora_b)*self.alpha/self.r)    # LoRA分支:x * (WA * WB * α/r)
  16.         return raw_output+lora_output   # W + LoRA
  17. def inject_lora(model,name,layer):
  18.     name_cols=name.split('.')   # [enc_convs, 0, crossattn, w_q]
  19.     # 逐层下探到linear归属的module
  20.     children=name_cols[:-1]   # [enc_convs, 0, crossattn]
  21.     cur_layer=model
  22.     for child in children:
  23.         cur_layer=getattr(cur_layer,child)   # 逐层深入得到w_q, w_k, w_v层的属性
  24.    
  25.     #print(layer==getattr(cur_layer,name_cols[-1]))
  26.     lora_layer=LoraLayer(layer,layer.in_features,layer.out_features,LORA_R,LORA_ALPHA)
  27.     setattr(cur_layer,name_cols[-1],lora_layer)   # 把 crossattn 的 w_q/w_k/w_v层 的属性替换为LoraLayer
复制代码
模子练习过程:冻结非Lora分支的全部参数
  1.     # lora权重的加载
  2.     try:
  3.         restore_lora_state=torch.load('lora.pt')   # 加载训练好的Lora权重(lora_a, lora_b矩阵),enc_convs.0.crossattn.w_q.lora_a等
  4.         model.load_state_dict(restore_lora_state,strict=False)
  5.     except:
  6.         pass
  7.     model=model.to(DEVICE)
  8.     # 冻结非Lora参数
  9.     for name,param in model.named_parameters():
  10.         if name.split('.')[-1] not in ['lora_a','lora_b']:  # 非LoRA部分不计算梯度
  11.             param.requires_grad=False
  12.         else:
  13.             param.requires_grad=True
复制代码
模子推理过程:将Lora分支参数合并到原始模子参数中(相加)
  1. if __name__=='__main__':
  2.     # 加载模型
  3.     model=torch.load('model.pt')
  4.     USE_LORA=True
  5.     if USE_LORA:   # 使用LoRA推理
  6.         # 把Linear层替换为Lora
  7.         for name,layer in model.named_modules():
  8.             name_cols=name.split('.')
  9.             # 过滤出cross attention使用的linear权重
  10.             filter_names=['w_q','w_k','w_v']
  11.             if any(n in name_cols for n in filter_names) and isinstance(layer,nn.Linear):
  12.                 inject_lora(model,name,layer)
  13.         # lora权重的加载
  14.         try:
  15.             restore_lora_state=torch.load('lora.pt')
  16.             model.load_state_dict(restore_lora_state,strict=False)
  17.         except:
  18.             pass
  19.         model=model.to(DEVICE)
  20.         # lora权重合并到主模型(把LoRA权重加到原始模型权重中)
  21.         for name,layer in model.named_modules():
  22.             name_cols=name.split('.')
  23.             if isinstance(layer,LoraLayer):   # 找到模型中所有的 LoraLayer 层
  24.                 children=name_cols[:-1]
  25.                 cur_layer=model
  26.                 for child in children:
  27.                     cur_layer=getattr(cur_layer,child)    # cur_layer = cross attention对象(包含修改过的wq, wk, wv)
  28.                 lora_weight=(layer.lora_a@layer.lora_b)*layer.alpha/layer.r   # 计算得到lora分支权重
  29.                 before_weight=layer.raw_linear.weight.clone()   # 原始模型权重W
  30.                 layer.raw_linear.weight=nn.Parameter(layer.raw_linear.weight.add(lora_weight.T)).to(DEVICE)    # 把Lora参数加到base model的linear weight上
  31.                 setattr(cur_layer,name_cols[-1],layer.raw_linear)   # 使用新的合并分支替换原来的两分支Lora结构
复制代码
4、ControlNet(彻底)


        将神经网络快的差别权重,分别复制到“锁定”副本(locked copy)和“可练习”副本(trainable copy)中。按制定规则集成原图特征并生成新的内容,不会导致生成图和原图看起来毫无关系。
5、其他



免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4