Stable Diffusion实现——使用MNISTl实现VAE(踩坑履历)

打印 上一主题 下一主题

主题 1614|帖子 1614|积分 4842

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
        迩来看Stable Diffusion(SD),发现他是通过在latent space中做Denoising,然后将得到的结果丢到VAE的Decoder中,来天生图像。以是,决定本身将这个流程的代码跑一遍,有个印象。以是就从VAE开始,这个虽然之前已经看过,但是过一段时间后,又会有各种各样的错误,以是这里记录一下,方便日后查找。
        因为这里主要是看SD的主要流程,和其中比力关键的地方,以是例子不能太难,就选择MNIST作为练习数据。当然,这篇blog作为一个记录性的文章,并不会涉及到其中的原理,只会对其中计划的关键地方进行先容,并给出实现。如果大家渴望看到相关的原理解说,请参考作者的paper和其他科普类的blog。
        总的来说,VAE有一个Encoder和一个Decoder。其中Encoder对输入数据编码到latent vector,之后Decoder对latent vector做解码,进行数据重修。但是VAE的想法就是编码后的latent vector需要服从Stadard Gaussian Distribution。但是,如果直接在Encoder输出一个高斯分布,并进行采样后输出到Decoder会导致梯度无法回传。这里简朴表明一下为什么。想象一下,Encoder输出一个均值和一个方差,然后通过这两个参数构建一个Gaussian Distribution。接着从这里面采样,并输出到Decoder。那么整个盘算过程其实就断开了,因为Decoder的输入是采样得到的,而采样这个步骤不是一个数学盘算过程,也就无法进行梯度通报。
        为相识决这个题目,VAE采用了重参数化的方法来使的盘算不绝开。这里说一下重参数化方法,以便背面理解。通过Encoder,得到了MeanVariance之后,我们不再进行采样操作,而是通过它构造一个latent vector,而且让它服从高斯分布。具体的方法就是z=Mean + Variance * Gaussian Sample。至于如许为什么能够成立,请参考其他资料。
        那么,VAE的Loss就包含两部分(这里是偏直观的理解,详细的数学推导有点复杂,请参考其他资料),一部分就是对于Encoder输出的限定,而另一部分就是Decoder的重修损失。其中Encoder的输出是分布的参数,以是用KL Divergence作为loss function,而Decoder是重修损失,用什么都可以,简朴一点可以直接用MSE。
        表明清晰后下面就给出代码(pytorch):
        第一步,导入必要的包

  1. import torch
  2. import torchvision
  3. from torch.utils.data import DataLoader
  4. from torch.nn import functional as F
复制代码
        第二步,加载数据

  1. training_data = torchvision.datasets.MNIST(
  2.     root = 'data',
  3.     train = True ,
  4.     transform = torchvision.transforms.ToTensor(),
  5.     download = True
  6. )
  7. testing_data = torchvision.datasets.MNIST(
  8.     root = 'data',
  9.     train = False ,
  10.     transform = torchvision.transforms.ToTensor(),
  11.     download = True
  12. )
复制代码
        第三步,实现一个VAE网络

  1. class VAE( torch.nn.Module ):
  2.     def __init__(self):
  3.         super(VAE , self).__init__();
  4.         self.linear1 = torch.nn.Linear( 28 * 28 , 256 )
  5.         self.linear2 = torch.nn.Linear( 256 , 64 )
  6.         self.mean_ = torch.nn.Linear( 64 , 32 )
  7.         self.log_var = torch.nn.Linear( 64 , 32 );
  8.         self.linear3 = torch.nn.Linear( 32 , 64 )
  9.         self.linear4 = torch.nn.Linear( 64, 256)
  10.         self.linear5 = torch.nn.Linear( 256, 512)
  11.     def reparameter( self , log_var , mean ):
  12.         std = torch.exp( log_var * .5 );
  13.         latent = mean + std * torch.randn_like( std );
  14.         return latent;
  15.     def encode( self , x ):
  16.         latent = F.relu( self.linear1( x ) );
  17.         latent = F.relu( self.linear2( x ) );
  18.         log_var , mean = self.log_var_( latent ) , self.mean_( latent );
  19.         latent = self.reparameter( log_var , mean );
  20.         return mean , log_var , latent;
  21.     def decode( self , latent ):
  22.         result = F.relu( self.linear3( latent ) );
  23.         result = F.relu( self.linear4( latent ) );
  24.         result = F.sigmoid( self.linear5( latent ) );
  25.         return result;
  26.     def forward( self , x ):
  27.         # encoder
  28.         latent = self.encoder( x );
  29.         # distribution parameters
  30.         log_var , mean = self.log_var_layer( latent ) , self.mean_layer( latent );
  31.         # reparameterization
  32.         latent = self.reparameter( log_var , mean );
  33.         # decode
  34.         x = self.decoder( latent );
  35.         return mean , log_var , x ;
  36.     def criterion( self , log_var , mean , x , y ):
  37.         re_loss = F.mes_loss( y , x , reduction = 'sum' )
  38.         kl_loss = -.5 * ( 1 + log_var - mean.pow( 2 ) - log_var.exp() ).sum()
  39.         return re_loss + kl_loss;
复制代码
        这里定义了一个简朴的全连接网络(前两层是Encoder,背面三层是Deocder),在forward中,Encoder会输出Mean和LogVariance(方差取log,这里也可以直接输出方差,但是由于方差非负,以是需要一些额外操作);接着重参数化会输入均值和方差,并输出一个latent;之后输出到Decoder进行重修。criterion就是盘算VAE的误差,有两个,分别是重修误差(re_loss)和分布误差(kl_loss)。
        第四步,练习VAE

  1. model = VAE();
  2. device = 'cuda'
  3. vae_optimizer = torch.optim.Adam( model.parameters() , lr = 1e-3 , weight_decay = .08 )
  4. training_dataloader = DataLoader( training_data , batch_size = 64 , shuffle = True )
  5. model = model.train().to( device );
  6. for _ in range( 10 ):
  7.     for batch , ( x , Y ) in enumerate( dataloader ):
  8.         x = x.flatten( 1 ).to( device );
  9.         log_var , mean , y = model( x );
  10.         loss = model.criterion( log_var , mean , x , y );
  11.         optimizer.zero_grad();
  12.         loss.backward();
  13.         optimizer.step();
复制代码
        第五步,天生图像

  1. model = model.eval().to( device );
  2. latent = torch.randn( 64, 32 ).to( device );
  3. image = model.decode( latent )
  4. image = image.reshape( 64 , 1 , 28 , 28 ).detach().cpu()
  5. torchvision.utils.save_image( image , 'sample.png' )
复制代码
        这里先采样64个batch,latent 维度是32,接着用decode天生图像,并保存,就可以看到结果了。当然,我们这里有测试数据集,也可以通过测试数据集进行encode,之后进行重参数化,并进行decode查看结果。
        上述代码是从我的代码中删除掉部分不必要的内容给出的,如果有题目标化,接待大家指出来。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

飞不高

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表