Stable Diffusion实现——使用MNISTl实现VAE(踩坑履历)
迩来看Stable Diffusion(SD),发现他是通过在latent space中做Denoising,然后将得到的结果丢到VAE的Decoder中,来天生图像。以是,决定本身将这个流程的代码跑一遍,有个印象。以是就从VAE开始,这个虽然之前已经看过,但是过一段时间后,又会有各种各样的错误,以是这里记录一下,方便日后查找。因为这里主要是看SD的主要流程,和其中比力关键的地方,以是例子不能太难,就选择MNIST作为练习数据。当然,这篇blog作为一个记录性的文章,并不会涉及到其中的原理,只会对其中计划的关键地方进行先容,并给出实现。如果大家渴望看到相关的原理解说,请参考作者的paper和其他科普类的blog。
总的来说,VAE有一个Encoder和一个Decoder。其中Encoder对输入数据编码到latent vector,之后Decoder对latent vector做解码,进行数据重修。但是VAE的想法就是编码后的latent vector需要服从Stadard Gaussian Distribution。但是,如果直接在Encoder输出一个高斯分布,并进行采样后输出到Decoder会导致梯度无法回传。这里简朴表明一下为什么。想象一下,Encoder输出一个均值和一个方差,然后通过这两个参数构建一个Gaussian Distribution。接着从这里面采样,并输出到Decoder。那么整个盘算过程其实就断开了,因为Decoder的输入是采样得到的,而采样这个步骤不是一个数学盘算过程,也就无法进行梯度通报。
为相识决这个题目,VAE采用了重参数化的方法来使的盘算不绝开。这里说一下重参数化方法,以便背面理解。通过Encoder,得到了Mean和Variance之后,我们不再进行采样操作,而是通过它构造一个latent vector,而且让它服从高斯分布。具体的方法就是z=Mean + Variance * Gaussian Sample。至于如许为什么能够成立,请参考其他资料。
那么,VAE的Loss就包含两部分(这里是偏直观的理解,详细的数学推导有点复杂,请参考其他资料),一部分就是对于Encoder输出的限定,而另一部分就是Decoder的重修损失。其中Encoder的输出是分布的参数,以是用KL Divergence作为loss function,而Decoder是重修损失,用什么都可以,简朴一点可以直接用MSE。
表明清晰后下面就给出代码(pytorch):
第一步,导入必要的包
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.nn import functional as F 第二步,加载数据
training_data = torchvision.datasets.MNIST(
root = 'data',
train = True ,
transform = torchvision.transforms.ToTensor(),
download = True
)
testing_data = torchvision.datasets.MNIST(
root = 'data',
train = False ,
transform = torchvision.transforms.ToTensor(),
download = True
)
第三步,实现一个VAE网络
class VAE( torch.nn.Module ):
def __init__(self):
super(VAE , self).__init__();
self.linear1 = torch.nn.Linear( 28 * 28 , 256 )
self.linear2 = torch.nn.Linear( 256 , 64 )
self.mean_ = torch.nn.Linear( 64 , 32 )
self.log_var = torch.nn.Linear( 64 , 32 );
self.linear3 = torch.nn.Linear( 32 , 64 )
self.linear4 = torch.nn.Linear( 64, 256)
self.linear5 = torch.nn.Linear( 256, 512)
def reparameter( self , log_var , mean ):
std = torch.exp( log_var * .5 );
latent = mean + std * torch.randn_like( std );
return latent;
def encode( self , x ):
latent = F.relu( self.linear1( x ) );
latent = F.relu( self.linear2( x ) );
log_var , mean = self.log_var_( latent ) , self.mean_( latent );
latent = self.reparameter( log_var , mean );
return mean , log_var , latent;
def decode( self , latent ):
result = F.relu( self.linear3( latent ) );
result = F.relu( self.linear4( latent ) );
result = F.sigmoid( self.linear5( latent ) );
return result;
def forward( self , x ):
# encoder
latent = self.encoder( x );
# distribution parameters
log_var , mean = self.log_var_layer( latent ) , self.mean_layer( latent );
# reparameterization
latent = self.reparameter( log_var , mean );
# decode
x = self.decoder( latent );
return mean , log_var , x ;
def criterion( self , log_var , mean , x , y ):
re_loss = F.mes_loss( y , x , reduction = 'sum' )
kl_loss = -.5 * ( 1 + log_var - mean.pow( 2 ) - log_var.exp() ).sum()
return re_loss + kl_loss; 这里定义了一个简朴的全连接网络(前两层是Encoder,背面三层是Deocder),在forward中,Encoder会输出Mean和LogVariance(方差取log,这里也可以直接输出方差,但是由于方差非负,以是需要一些额外操作);接着重参数化会输入均值和方差,并输出一个latent;之后输出到Decoder进行重修。criterion就是盘算VAE的误差,有两个,分别是重修误差(re_loss)和分布误差(kl_loss)。
第四步,练习VAE
model = VAE();
device = 'cuda'
vae_optimizer = torch.optim.Adam( model.parameters() , lr = 1e-3 , weight_decay = .08 )
training_dataloader = DataLoader( training_data , batch_size = 64 , shuffle = True )
model = model.train().to( device );
for _ in range( 10 ):
for batch , ( x , Y ) in enumerate( dataloader ):
x = x.flatten( 1 ).to( device );
log_var , mean , y = model( x );
loss = model.criterion( log_var , mean , x , y );
optimizer.zero_grad();
loss.backward();
optimizer.step();
第五步,天生图像
model = model.eval().to( device );
latent = torch.randn( 64, 32 ).to( device );
image = model.decode( latent )
image = image.reshape( 64 , 1 , 28 , 28 ).detach().cpu()
torchvision.utils.save_image( image , 'sample.png' ) 这里先采样64个batch,latent 维度是32,接着用decode天生图像,并保存,就可以看到结果了。当然,我们这里有测试数据集,也可以通过测试数据集进行encode,之后进行重参数化,并进行decode查看结果。
上述代码是从我的代码中删除掉部分不必要的内容给出的,如果有题目标化,接待大家指出来。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]