手撕Diffusion系列 - 第十一期 - lora微调 - 基于Stable Diffusion(代码)
Stable Diffusion 原理图
Stable Diffusion的原明白释
Stable Diffusion的网络布局图如下图所示:
Stable Diffusion 的网络布局图 ## Stable Diffusion 和 Diffusion 的区别
- 改动1:利用 AE,VAE,VQVAE 等自编码器,进行了图像特性提取,利用正确提取特性后的图像作为自己原本在Diffusion中的图像。
- 改动2:在训练过程中,额外添加了一些引导信息,促使图像天生,往我们所希望的方向去走,这里添加信息的方式主要是利用交叉注意力机制(这里我看图应该是只用交叉注意力就行,但是我看视频博主用的代码以及参照的Stable-Diffusion Unet图上都是利用的Transoformer的编码器,也就是得到注意力值之后还得进行一个feedforward层)。
- **改动3:**利用 AE,VAE,VQVAE 等自编码器进行解码。(这个实质上和第一点是重复的)
- **注意:**本次的代码改动先只改动第二个,也就是添加引导信息,对于编码器用于减少计算量,本次改进先不到场(555~,由于视频博主没教),后续可能会进行添加(由于也比力简朴)。
Stable Diffusion 和Diffusion 的Unet对比
原本的Unet图像
Stable Diffusion的 Unet 图像
- 我们可以发现,两者之间的区别主要在于,在卷积完了之后添加了一个Transformer的模块,也就是其编码器将两个信息进行了融合,其他并没有改变。
- 所以主要区别在卷积后的那一部门,如下图。
卷积后的区别
- 这个ResnetBlock就是之前的卷积模块,作为右边的残差部门,所以这里写成 了ResnetBlock。
- 因此,如果我们将Tranformer模块融入到Restnet模块里面,并且保持其输入卷积的图像和transformer输出的图像形状一致的话,那么就其他部门完全不必要改变了,只不过里面多添加了一些引导信息(MNIST数据集中是label,但是也可以添加文本等等引导信息) 而已。
Lora 微调原理
LoRA 微调算法 - 初始示意图
- 算法过程:对于原先的参数不改变,通过右边添加一个参数矩阵来进行微调,也就是利用新的参数矩阵来微调拟合新范畴的参数和初始参数的差距。也就是ΔW。
理论:预训练大型语言模型在顺应新任务时具有较低的“内在维度” , 所以当对于一个预训练模型来说,原先的参数是有非常多的冗余的,因此我们可以利用低维空间(也就是降维)去表示目的参数和原先参数之间的隔断。因此ΔW是相对W来说维度非常小的,减少了非常多的参数量。
LoRA参数微调具体表现
- 由于要包管输入和输出的维度和原本的参数W一样,所以一般参数输入的维度还是类似的,但是中间的维度小很多,从而达到减少参数量的结果。比如原本是100x100的参数量,现在变为100x5(r)x2,减少了10倍。
o u t p u t = n e t ( x ) + t o r c h . m a t m u l ( x , t o r c h . m a t m u l ( l o r a a , l o r a b ) ∗ a l p h a ( 可能这里也会除以 r ) output=net(x)+torch.matmul(x,torch.matmul(lora_a,lora_b)*alpha(可能这里也会除以r) output=net(x)+torch.matmul(x,torch.matmul(loraa,lorab)∗alpha(可能这里也会除以r)
alpha或者alpha/r 是一个缩放因子,用于调解组合结果(原始模型输出加上低秩自顺应)的大小。这均衡了预训练模型的知识和新的特定于任务的顺应——默认环境下,alpha通常设置为 1。另请注意,虽然W A被初始化为小的随机权重,但WB被初始化为 0,因此训练开始时ΔW = WAxWB = 0 ,这意味着我们以原始权重开始训练。
Stable Diffusion 添加lora微调代码
Part1 添加lora.py文件 - 用于设置lora层以及替换
1. 引入相干库函数
- # 该模块主要是实现lora类,实现lora层的alpha和beta通路,把输入的x经过两条通路后的结果,进行联合输出。
- # 然后添加一个函数,主要是为了实现将原本的线性层换曾lora层。
- '''
- # Part1 引入相关的库函数
- '''
- import torch
- from torch import nn
- from config import *
复制代码 2. 界说LoraLayer的类
- '''
- # Part2 设计一个类,实现lora_layer
- '''
- class LoraLayer(nn.<
复制代码 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |