经典神经网络(10)PixelCNN模子、Gated PixelCNN模子及其在MNIST数据集上的 ...

打印 上一主题 下一主题

主题 548|帖子 548|积分 1644

经典神经网络(10)PixelCNN模子、Gated PixelCNN模子及其在MNIST数据集上的应用

1 PixelCNN



  • PixelCNN是DeepMind团队在论文Pixel Recurrent Neural Networks (16.01)提出的一种天生模子,现实上这篇论文共提出了两种架构:PixelRNNPixelCNN,两者的紧张区别是前者用LSTM来建模,而PixelCNN是基于CNN的,相比RNN,CNN盘算更高效,我们这里只讨论PixelCNN。
  • PixelCNN借用了NLP里的方法来天生图像。对于天然图像,每个像素值的取值范围为0~255,共256个离散值。PixelCNN模子会根据前i - 1个像素输出第i个像素的概率分布。
  • 练习时,和多分类任务一样,要根据第i个像素的真值和猜测的概率分布求交错熵损失函数
  • 采样时(图像天生时),会根据前i - 1个像素直接从猜测的概率分布(多项分布)里采样出第i个像素。
1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大要思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只思量单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

但是只输出一个像素的概率分布,这样练习服从太低了。


  • 在练习时,我们可以输入一幅图像,同时让模子输出图像每一点像素的概率分布(如下图所示),这样就能通过每个像素的真值和模子猜测的概率分布求交错熵损失函数,举行并行练习。
  • 我们能这么做的原因是:在练习时,整幅练习图像是已知的,因此我们可以在一次前向流传后得到图像每一处的概率分布。
  • 当然,我们必要找到每个像素都忽略后续像素的信息的方法,即论文中提出的掩码卷积机制,我们后面再讲。

但是在天生图像(采样)时,照旧要一个像素一个像素的天生(如下所示)


  • 在采样时,我们会先根据前i - 1个像素输出第i个像素的概率分布。
  • 然后,我们会从第i个像素的概率分布中举行采样(如下面代码所示)
  1. # 假设颜色取值范围为[0, 7],下面为概率分布
  2. prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])
  3. # 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
  4. # torch.multinomial会从input这个概率分布中,取num_samples个值
  5. pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])
复制代码

我们现在已经知道了练习及采样的大要过程。但是,我们现在照旧有一个疑问,如何保证练习时候,每个像素都忽略后续像素的信息?
PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。


  • 详细来说,PixelCNN使用了两类掩码卷积:

    • 我们把两类掩码卷积分别称为「A类」和「B类」。
    • 二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部门不产生贡献。
    • A类和B类的唯一区别在于:卷积核的中心像素是否产生贡献。
    • CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积。


我们来分析下这样计划的长处:


  • 对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够通报到中心像素上)

    • 经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。
    • 再经过两次B类掩码卷积后,中心像素能够看到左上角大部门像素的信息(如下图所示,我们发现照旧会看漏少部门的信息,后面的Gated PixelCNN对此举行了改进)。
    • 这满足PixelCNN的束缚。




  • 如果不停使用A类掩码卷积,每次卷积后中心像素都会看漏一些信息,最终就会导致看漏很多信息



  • 如果第一层就使用B类卷积,中心像素照旧能看到自己位置的输入信息。这打破了PixelCNN的束缚。
总结如下:


  • 逐像素猜测只依靠于前面的像素,因此在选择卷积核时要举行掩码操作制止看到未来的值,因此,在第一层猜测时可接纳掩码卷积A
  • 由于CNN的逐像素猜测是多层卷积,所以当第一层结束后,图像缺失部门已经有了猜测值,因此在举行下一次/层卷积操作时可以利用当前像素的猜测值,因此接纳下列掩码卷积B
  • 必要留意的是,这里只思量了单通道,如果扩展到RGB三个通道时,该如何举行mask呢?
1.1.2 PixelCNN的网络架构



  • 利用两类掩码卷积,PixelCNN满足了每个像素只能担当之前像素的信息这一束缚。
  • 我们可以用任意一种CNN架构来实现PixelCNN。
  • 下图赤色框所示部门是PixelCNN的网络布局,此中,第一个7x7卷积层用了A类掩码卷积,之后全部3x3卷积都是B类掩码卷积。

1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模子

实现PixelCNN,最紧张的是实现掩码卷积。


  • 掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向流传的时候,先让卷积核组乘mask,再做平凡的卷积。
  • 由于输入输出都是单通道图像,我们只必要在卷积核的h, w两个维度设置掩码。
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader
  5. import torchvision
  6. from torchvision.transforms import ToTensor
  7. import time
  8. import einops
  9. import cv2
  10. import numpy as np
  11. import os
  12. class MaskConv2d(nn.Module):
  13.     """
  14.         掩码卷积的实现思路:
  15.             在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积
  16.     """
  17.     def __init__(self, conv_type, *args, **kwags):
  18.         super().__init__()
  19.         assert conv_type in ('A', 'B')
  20.         self.conv = nn.Conv2d(*args, **kwags)
  21.         H, W = self.conv.weight.shape[-2:]
  22.         # 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码
  23.         mask = torch.zeros((H, W), dtype=torch.float32)
  24.         mask[0:H // 2] = 1
  25.         mask[H // 2, 0:W // 2] = 1
  26.         if conv_type == 'B':
  27.             mask[H // 2, W // 2] = 1
  28.         # 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作
  29.         mask = mask.reshape((1, 1, H, W))
  30.         # register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中
  31.         # 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。
  32.         # 第三个参数表示被注册的变量是否要加入state_dict中以保存下来
  33.         self.register_buffer(name='mask', tensor=mask, persistent=False)
  34.     def forward(self, x):
  35.         self.conv.weight.data *= self.mask
  36.         conv_res = self.conv(x)
  37.         return conv_res
复制代码
有了最焦点的掩码卷积,我们来根据论文中的模子布局图把模子搭起来



  • 我们先实现残差块上图右部门的ResidualBlock,这里添加归一化
  1. class ResidualBlock(nn.Module):
  2.     """
  3.         残差块ResidualBlock
  4.     """
  5.     def __init__(self, h, bn=True):
  6.         super().__init__()
  7.         self.relu = nn.ReLU()
  8.         self.conv1 = nn.Conv2d(2 * h, h, 1)
  9.         self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
  10.         self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
  11.         self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
  12.         self.conv3 = nn.Conv2d(h, 2 * h, 1)
  13.         self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
  14.     def forward(self, x):
  15.         # 1、ReLU + 1×1 Conv + bn
  16.         y = self.relu(x)
  17.         y = self.conv1(y)
  18.         y = self.bn1(y)
  19.         # 2、ReLU + 3×3 Conv(mask B) + bn
  20.         y = self.relu(y)
  21.         y = self.conv2(y)
  22.         y = self.bn2(y)
  23.         # 3、ReLU + 1×1 Conv + bn
  24.         y = self.relu(y)
  25.         y = self.conv3(y)
  26.         y = self.bn3(y)
  27.         # 4、残差连接
  28.         y = y + x
  29.         return y
复制代码


  • 有了全部这些根本模块后,我们就可以拼出最终的PixelCNN了。
  • 留意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数目,只必要修改softmax输出的通道数color_level。
  1. class PixelCNN(nn.Module):
  2.     def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
  3.         super().__init__()
  4.         self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
  5.         self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
  6.         self.residual_blocks = nn.ModuleList()
  7.         for _ in range(n_blocks):
  8.             self.residual_blocks.append(ResidualBlock(h, bn))
  9.         self.relu = nn.ReLU()
  10.         self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
  11.         self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
  12.         self.out = nn.Conv2d(linear_dim, color_level, 1)
  13.     def forward(self, x):
  14.         # 1、7 × 7 conv(mask A)
  15.         x = self.conv1(x)
  16.         x = self.bn1(x)
  17.         # 2、Multiple residual blocks
  18.         for block in self.residual_blocks:
  19.             x = block(x)
  20.         x = self.relu(x)
  21.         # 3、1 × 1 conv
  22.         x = self.linear1(x)
  23.         x = self.relu(x)
  24.         x = self.linear2(x)
  25.         x = self.out(x)
  26.         return x
复制代码
1.1.3.2 数据集及练习

准备好了模子代码,我们可以编写练习脚本了:


  • PixelCNN有15个残差块,中心特性的通道数为128,输出火线性层的通道数为32
  1. def get_dataloader(batch_size: int):
  2.     dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
  3.                                          train=True,
  4.                                          transform=ToTensor()
  5.                                          )
  6.     return DataLoader(dataset, batch_size=batch_size, shuffle=True)
  7. def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
  8.     """训练过程"""
  9.     dataloader = get_dataloader(batch_size)
  10.     model = model.to(device)
  11.     optimizer = torch.optim.Adam(model.parameters(), 1e-3)
  12.     loss_fn = nn.CrossEntropyLoss()
  13.     tic = time.time()
  14.     for e in range(n_epochs):
  15.         total_loss = 0
  16.         for x, _ in dataloader:
  17.             current_batch_size = x.shape[0]
  18.             x = x.to(device)
  19.             # 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签
  20.             y = torch.ceil(x * (color_level - 1)).long()
  21.             y = y.squeeze(1)
  22.             predict_y = model(x)
  23.             loss = loss_fn(predict_y, y)
  24.             optimizer.zero_grad()
  25.             loss.backward()
  26.             optimizer.step()
  27.             total_loss += loss.item() * current_batch_size
  28.         total_loss /= len(dataloader.dataset)
  29.         toc = time.time()
  30.         torch.save(model.state_dict(), model_path)
  31.         print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
  32. if __name__ == '__main__':
  33.     os.makedirs('work_dirs', exist_ok=True)
  34.     device = 'cuda' if torch.cuda.is_available() else 'cpu'
  35.     # 需要注意的是:MNIST数据集的大部分像素都是0和255
  36.     color_level = 8  # or 256
  37.     # 1、创建PixelCNN模型
  38.     model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)
  39.     # 2、模型训练
  40.     model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'
  41.     train(model, device, model_path)
  42.     # 3、采样
  43.     sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
复制代码
1.1.3.3 采样



  • 在采样时,我们把x初始化成一个0张量。
  • 之后,循环遍历每一个像素,输入x,把猜测出的下一个像素填入x.
  1. def sample(model, device, model_path, output_path, n_sample=1):
  2.     """
  3.         把x初始化成一个0张量。
  4.         循环遍历每一个像素,输入x,把预测出的下一个像素填入x
  5.     """
  6.     model.eval()
  7.     model.load_state_dict(torch.load(model_path))
  8.     model = model.to(device)
  9.     C, H, W = get_img_shape()  # (1, 28, 28)
  10.     x = torch.zeros((n_sample, C, H, W)).to(device)
  11.     with torch.no_grad():
  12.         for i in range(H):
  13.             for j in range(W):
  14.                 # 我们先获取模型的输出,再用softmax转换成概率分布
  15.                 output = model(x)
  16.                 prob_dist = F.softmax(output[:, :, i, j], -1)
  17.                 # 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值
  18.                 # 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]
  19.                 pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
  20.                 # 最后把新像素填入到生成图像中
  21.                 x[:, :, i, j] = pixel
  22.     # 乘255变成一个用8位字节表示的图像
  23.     imgs = x * 255
  24.     imgs = imgs.clamp(0, 255)
  25.     imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))
  26.     imgs = imgs.detach().cpu().numpy().astype(np.uint8)
  27.     cv2.imwrite(output_path, imgs)
复制代码
1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响


  • 此中赤色猜测不受蓝色和绿色通道的影响,只受上下文影响
  • 绿色赤色通道和上下文影响,但不受蓝色通道影响;
  • 蓝色通道受上下文、赤色通道、绿色通道影响

更详细地,我们规定一个子像素只由它之前的子像素决定,天生图像时,我们一个子像素一个子像素地天生。


  • 如下图所示,对于RGB图像,R子像素由它之前全部像素决定
  • G子像素由它的R子像素和之前全部像素决定,
  • B子像素由它的R、G子像素和它之前全部像素决定。

如下图所示,由于现在要猜测三个颜色通道,网络的输出应该是一个[256x3, H, W]外形的张量


  • 即每个像素输出三个概率分布,分别表现R、G、B取某种颜色的概率。
  • 同时,本质上来讲,网络是在并行地为每个像素盘算3组结果。因此,为了达到同样的性能,网络全部的特性图的通道数也要乘3。

图像变为多通道后,A类卷积和B类卷积的定义也必要做出一些调解。我们不仅要思量像素在空间上的束缚,还要思量一个像素内子像素间的束缚。为此,我们要用不同的策略实现束缚。为了方便描述,我们设卷积核组的外形为[o, i, h, w],此中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。


  • 对于通道间的束缚,我们要在o, i两个维度上设置掩码,如下图左边所示。

    • 设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3

      • 即o1 = 0/3, o2 = o/3*2/3, o3 = o*2/3
      • i1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i。
      • 序号1, 2, 3分别表现这组通道是在维护R, G, B的盘算。

    • 我们对输入通道组和输出通道组之间举行束缚。
    • 对于A类卷积,我们令o1看不到i1, i2, i3,o2看不到i2, i3,o3看不到i3;
    • 对于B类卷积,我们取消每个通道看不到自己的限定,即在A类卷积的根本上令o1看到i1,o2看到i2,o3看到i3。

  • 如下图右边所示,对于空间上的束缚,我们照旧和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。



  • 下面给出三维掩码示意图方便理解:

2 Gated PixelCNN

2.1 Gated PixelCNN简述



  • 可以参考大神讲解:Gated PixelCNN (sergeiturukin.com)
  • PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,中心像素看不到右上角三个本应该能看到的像素。



  • 为此,PixelCNN论文的作者又发表了Conditional Image Generation with PixelCNN Decoders(16.06)。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。



  • 如下图所示,Gated PixelCNN使用了两种卷积,即垂直卷积和水平卷积,来分别维护一个像素上侧的信息和左侧的信息

    • 垂直卷积的结果只是一些临时量
    • 而水平卷积的结果最终会被网络输出
    • 使用这种新的掩码卷积机制后,每个像素能正确地收到之前全部像素的信息了。




  • Gated PixelCNN用下图的模块代替了原PixelCNN的平凡残差模块。
  • 模块的输入输出都是两个量,左边的量是垂直卷积中心结果,右边的量是最后用来盘算输出的量。
  • 垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。
  • 两条盘算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个外形雷同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。
  • 别的,模块右侧尚有一个残差连接。

2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模子



  • 首先,实现垂直卷积和水平卷积
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader
  5. import torchvision
  6. from torchvision.transforms import ToTensor
  7. import time
  8. import einops
  9. import cv2
  10. import numpy as np
  11. import os
  12. class VerticalMaskConv2d(nn.Module):
  13.     """
  14.         垂直卷积
  15.     """
  16.     def __init__(self, *args, **kwags):
  17.         super().__init__()
  18.         self.conv = nn.Conv2d(*args, **kwags)
  19.         H, W = self.conv.weight.shape[-2:]
  20.         mask = torch.zeros((H, W), dtype=torch.float32)
  21.         mask[0:H // 2 + 1] = 1
  22.         mask = mask.reshape((1, 1, H, W))
  23.         self.register_buffer('mask', mask, False)
  24.     def forward(self, x):
  25.         self.conv.weight.data *= self.mask
  26.         conv_res = self.conv(x)
  27.         return conv_res
  28. class HorizontalMaskConv2d(nn.Module):
  29.     """
  30.         水平卷积
  31.     """
  32.     def __init__(self, conv_type, *args, **kwags):
  33.         super().__init__()
  34.         assert conv_type in ('A', 'B')
  35.         self.conv = nn.Conv2d(*args, **kwags)
  36.         H, W = self.conv.weight.shape[-2:]
  37.         mask = torch.zeros((H, W), dtype=torch.float32)
  38.         mask[H // 2, 0:W // 2] = 1
  39.         if conv_type == 'B':
  40.             mask[H // 2, W // 2] = 1
  41.         mask = mask.reshape((1, 1, H, W))
  42.         self.register_buffer('mask', mask, False)
  43.     def forward(self, x):
  44.         self.conv.weight.data *= self.mask
  45.         conv_res = self.conv(x)
  46.         return conv_res
复制代码
  1. # 垂直卷积
  2. tensor([[[[1., 1., 1.],
  3.           [1., 1., 1.],
  4.           [0., 0., 0.]]]])
  5. # A类水平卷积
  6. tensor([[[[0., 0., 0.],
  7.           [1., 0., 0.],
  8.           [0., 0., 0.]]]])
  9. # B类水平卷积
  10. tensor([[[[0., 0., 0.],
  11.           [1., 1., 0.],
  12.           [0., 0., 0.]]]])
复制代码


  • 我们现在搭建Gated Block模块,这也是最难理解的一部门。
  • 可以参考的解释:https://segmentfault.com/a/1190000041189859?utm_source=sf-similar-article



    1. # 这里比较难理解,通过对图像进行零填充并裁剪图像底部,可以确保垂直和水平堆栈之间的因果关系
    2. v_to_h = v[:, :, 0:-1]
    3. v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
    4. # 注意到,v和i相加的位置只差了一个单位。
    5. # 为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。
    6. # 这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。
    复制代码



  • 维护两个v, h两个变量,分别表现垂直卷积部门的结果和水平卷积部门的结果。

    • v会经过一个垂直掩码卷积和一个门激活函数。
    • h会经过一个类似于残差块的布局,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。

  1. class GatedBlock(nn.Module):
  2.     def __init__(self, conv_type, in_channels, p, bn=True):
  3.         super().__init__()
  4.         self.conv_type = conv_type
  5.         self.p = p
  6.         self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
  7.         self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
  8.         self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)
  9.         self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
  10.         self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
  11.                                            1)
  12.         self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
  13.         self.h_output_conv = nn.Conv2d(p, p, 1)
  14.         self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()
  15.     def forward(self, v_input, h_input):
  16.         # v代表垂直卷积部分的结果
  17.         v = self.v_conv(v_input)
  18.         v = self.bn1(v)
  19.         # Note: 重点代码
  20.         # 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位
  21.         # 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)
  22.         v_to_h = v[:, :, 0:-1]
  23.         v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
  24.         # 和h相加前,先经过 1×1 conv
  25.         v_to_h = self.v_to_h_conv(v_to_h)
  26.         v_to_h = self.bn2(v_to_h)
  27.         # 分为两份,经过tanh 和 sigmoid
  28.         v1, v2 = v[:, :self.p], v[:, self.p:]
  29.         v1 = torch.tanh(v1)
  30.         v2 = torch.sigmoid(v2)
  31.         v = v1 * v2
  32.         # h代表水平卷积部分的结果
  33.         h = self.h_conv(h_input)
  34.         h = self.bn3(h)
  35.         h = h + v_to_h
  36.         # 分为两份,经过tanh 和 sigmoid
  37.         h1, h2 = h[:, :self.p], h[:, self.p:]
  38.         h1 = torch.tanh(h1)
  39.         h2 = torch.sigmoid(h2)
  40.         h = h1 * h2
  41.         h = self.h_output_conv(h)
  42.         h = self.bn4(h)
  43.         # 在网络的第一层,每个数据是不能看到自己的。
  44.         # 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。
  45.         if self.conv_type == 'B':
  46.             h = h + h_input
  47.         return v, h
复制代码


  • 最后,我们来用GatedBlock搭出Gated PixelCNN
  • Gated PixelCNN和PixelCNN的布局非常相似,只是把ResidualBlock更换成了GatedBlock而已。
  1. class GatedPixelCNN(nn.Module):
  2.     def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
  3.         super().__init__()
  4.         self.block1 = GatedBlock('A', 1, p, bn)
  5.         self.blocks = nn.ModuleList()
  6.         for _ in range(n_blocks):
  7.             self.blocks.append(GatedBlock('B', p, p, bn))
  8.         self.relu = nn.ReLU()
  9.         self.linear1 = nn.Conv2d(p, linear_dim, 1)
  10.         self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
  11.         self.out = nn.Conv2d(linear_dim, color_level, 1)
  12.     def forward(self, x):
  13.         v, h = self.block1(x, x)
  14.         for block in self.blocks:
  15.             v, h = block(v, h)
  16.         x = self.relu(h)
  17.         x = self.linear1(x)
  18.         x = self.relu(x)
  19.         x = self.linear2(x)
  20.         x = self.out(x)
  21.         return x
复制代码
2.2.2 数据集、练习及采样



  • 数据集、练习及采样和PixelCNN千篇一律,不再赘述。
  1. def get_dataloader(batch_size: int):
  2.     dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
  3.                                          train=True,
  4.                                          transform=ToTensor()
  5.                                          )
  6.     return DataLoader(dataset, batch_size=batch_size, shuffle=True)
  7. def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
  8.     """训练过程"""
  9.     dataloader = get_dataloader(batch_size)
  10.     model = model.to(device)
  11.     optimizer = torch.optim.Adam(model.parameters(), 1e-3)
  12.     loss_fn = nn.CrossEntropyLoss()
  13.     tic = time.time()
  14.     for e in range(n_epochs):
  15.         total_loss = 0
  16.         for x, _ in dataloader:
  17.             current_batch_size = x.shape[0]
  18.             x = x.to(device)
  19.             # 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的
  20.             y = torch.ceil(x * (color_level - 1)).long()
  21.             y = y.squeeze(1)
  22.             predict_y = model(x)
  23.             loss = loss_fn(predict_y, y)
  24.             optimizer.zero_grad()
  25.             loss.backward()
  26.             optimizer.step()
  27.             total_loss += loss.item() * current_batch_size
  28.         total_loss /= len(dataloader.dataset)
  29.         toc = time.time()
  30.         torch.save(model.state_dict(), model_path)
  31.         print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
  32. def get_img_shape():
  33.     return (1, 28, 28)
  34. def sample(model, device, model_path, output_path, n_sample=1):
  35.     """
  36.         把x初始化成一个0张量。
  37.         循环遍历每一个像素,输入x,把预测出的下一个像素填入x
  38.     """
  39.     model.eval()
  40.     model.load_state_dict(torch.load(model_path))
  41.     model = model.to(device)
  42.     C, H, W = get_img_shape()  # (1, 28, 28)
  43.     x = torch.zeros((n_sample, C, H, W)).to(device)
  44.     with torch.no_grad():
  45.         for i in range(H):
  46.             for j in range(W):
  47.                 # 我们先获取模型的输出,再用softmax转换成概率分布
  48.                 output = model(x)
  49.                 prob_dist = F.softmax(output[:, :, i, j], -1)
  50.                 # 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值
  51.                 # 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)
  52.                 pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
  53.                 # 最后把新像素填入生成图像
  54.                 x[:, :, i, j] = pixel
  55.     imgs = x * 255
  56.     imgs = imgs.clamp(0, 255)
  57.     imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))
  58.     imgs = imgs.detach().cpu().numpy().astype(np.uint8)
  59.     cv2.imwrite(output_path, imgs)
  60. if __name__ == '__main__':
  61.     os.makedirs('work_dirs', exist_ok=True)
  62.     device = 'cuda' if torch.cuda.is_available() else 'cpu'
  63.     color_level = 8  # or 256
  64.     # 1、创建GatedPixelCNN模型
  65.     model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)
  66.     # 2、模型训练
  67.     model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'
  68.     train(model, device, model_path, batch_size=1)
  69.     # 3、采样
  70.     sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

羊蹓狼

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表