马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
迩来看Stable Diffusion(SD),发现他是通过在latent space中做Denoising,然后将得到的结果丢到VAE的Decoder中,来天生图像。以是,决定本身将这个流程的代码跑一遍,有个印象。以是就从VAE开始,这个虽然之前已经看过,但是过一段时间后,又会有各种各样的错误,以是这里记录一下,方便日后查找。
因为这里主要是看SD的主要流程,和其中比力关键的地方,以是例子不能太难,就选择MNIST作为练习数据。当然,这篇blog作为一个记录性的文章,并不会涉及到其中的原理,只会对其中计划的关键地方进行先容,并给出实现。如果大家渴望看到相关的原理解说,请参考作者的paper和其他科普类的blog。
总的来说,VAE有一个Encoder和一个Decoder。其中Encoder对输入数据编码到latent vector,之后Decoder对latent vector做解码,进行数据重修。但是VAE的想法就是编码后的latent vector需要服从Stadard Gaussian Distribution。但是,如果直接在Encoder输出一个高斯分布,并进行采样后输出到Decoder会导致梯度无法回传。这里简朴表明一下为什么。想象一下,Encoder输出一个均值和一个方差,然后通过这两个参数构建一个Gaussian Distribution。接着从这里面采样,并输出到Decoder。那么整个盘算过程其实就断开了,因为Decoder的输入是采样得到的,而采样这个步骤不是一个数学盘算过程,也就无法进行梯度通报。
为相识决这个题目,VAE采用了重参数化的方法来使的盘算不绝开。这里说一下重参数化方法,以便背面理解。通过Encoder,得到了Mean和Variance之后,我们不再进行采样操作,而是通过它构造一个latent vector,而且让它服从高斯分布。具体的方法就是z=Mean + Variance * Gaussian Sample。至于如许为什么能够成立,请参考其他资料。
那么,VAE的Loss就包含两部分(这里是偏直观的理解,详细的数学推导有点复杂,请参考其他资料),一部分就是对于Encoder输出的限定,而另一部分就是Decoder的重修损失。其中Encoder的输出是分布的参数,以是用KL Divergence作为loss function,而Decoder是重修损失,用什么都可以,简朴一点可以直接用MSE。
表明清晰后下面就给出代码(pytorch):
第一步,导入必要的包
- import torch
- import torchvision
- from torch.utils.data import DataLoader
- from torch.nn import functional as F
复制代码 第二步,加载数据
- training_data = torchvision.datasets.MNIST(
- root = 'data',
- train = True ,
- transform = torchvision.transforms.ToTensor(),
- download = True
- )
- testing_data = torchvision.datasets.MNIST(
- root = 'data',
- train = False ,
- transform = torchvision.transforms.ToTensor(),
- download = True
- )
复制代码 第三步,实现一个VAE网络
- class VAE( torch.nn.Module ):
- def __init__(self):
- super(VAE , self).__init__();
- self.linear1 = torch.nn.Linear( 28 * 28 , 256 )
- self.linear2 = torch.nn.Linear( 256 , 64 )
- self.mean_ = torch.nn.Linear( 64 , 32 )
- self.log_var = torch.nn.Linear( 64 , 32 );
- self.linear3 = torch.nn.Linear( 32 , 64 )
- self.linear4 = torch.nn.Linear( 64, 256)
- self.linear5 = torch.nn.Linear( 256, 512)
- def reparameter( self , log_var , mean ):
- std = torch.exp( log_var * .5 );
- latent = mean + std * torch.randn_like( std );
- return latent;
- def encode( self , x ):
- latent = F.relu( self.linear1( x ) );
- latent = F.relu( self.linear2( x ) );
- log_var , mean = self.log_var_( latent ) , self.mean_( latent );
- latent = self.reparameter( log_var , mean );
- return mean , log_var , latent;
- def decode( self , latent ):
- result = F.relu( self.linear3( latent ) );
- result = F.relu( self.linear4( latent ) );
- result = F.sigmoid( self.linear5( latent ) );
- return result;
- def forward( self , x ):
- # encoder
- latent = self.encoder( x );
- # distribution parameters
- log_var , mean = self.log_var_layer( latent ) , self.mean_layer( latent );
- # reparameterization
- latent = self.reparameter( log_var , mean );
- # decode
- x = self.decoder( latent );
- return mean , log_var , x ;
- def criterion( self , log_var , mean , x , y ):
- re_loss = F.mes_loss( y , x , reduction = 'sum' )
- kl_loss = -.5 * ( 1 + log_var - mean.pow( 2 ) - log_var.exp() ).sum()
- return re_loss + kl_loss;
复制代码 这里定义了一个简朴的全连接网络(前两层是Encoder,背面三层是Deocder),在forward中,Encoder会输出Mean和LogVariance(方差取log,这里也可以直接输出方差,但是由于方差非负,以是需要一些额外操作);接着重参数化会输入均值和方差,并输出一个latent;之后输出到Decoder进行重修。criterion就是盘算VAE的误差,有两个,分别是重修误差(re_loss)和分布误差(kl_loss)。
第四步,练习VAE
- model = VAE();
- device = 'cuda'
- vae_optimizer = torch.optim.Adam( model.parameters() , lr = 1e-3 , weight_decay = .08 )
- training_dataloader = DataLoader( training_data , batch_size = 64 , shuffle = True )
- model = model.train().to( device );
- for _ in range( 10 ):
- for batch , ( x , Y ) in enumerate( dataloader ):
- x = x.flatten( 1 ).to( device );
- log_var , mean , y = model( x );
- loss = model.criterion( log_var , mean , x , y );
- optimizer.zero_grad();
- loss.backward();
- optimizer.step();
复制代码 第五步,天生图像
- model = model.eval().to( device );
- latent = torch.randn( 64, 32 ).to( device );
- image = model.decode( latent )
- image = image.reshape( 64 , 1 , 28 , 28 ).detach().cpu()
- torchvision.utils.save_image( image , 'sample.png' )
复制代码 这里先采样64个batch,latent 维度是32,接着用decode天生图像,并保存,就可以看到结果了。当然,我们这里有测试数据集,也可以通过测试数据集进行encode,之后进行重参数化,并进行decode查看结果。
上述代码是从我的代码中删除掉部分不必要的内容给出的,如果有题目标化,接待大家指出来。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
|