ToB企服应用市场:ToB评测及商务社交产业平台

标题: 经典神经网络(10)PixelCNN模子、Gated PixelCNN模子及其在MNIST数据集上的 [打印本页]

作者: 羊蹓狼    时间: 2024-6-26 09:25
标题: 经典神经网络(10)PixelCNN模子、Gated PixelCNN模子及其在MNIST数据集上的
经典神经网络(10)PixelCNN模子、Gated PixelCNN模子及其在MNIST数据集上的应用

1 PixelCNN


1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大要思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只思量单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

但是只输出一个像素的概率分布,这样练习服从太低了。


但是在天生图像(采样)时,照旧要一个像素一个像素的天生(如下所示)

  1. # 假设颜色取值范围为[0, 7],下面为概率分布
  2. prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])
  3. # 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
  4. # torch.multinomial会从input这个概率分布中,取num_samples个值
  5. pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])
复制代码

我们现在已经知道了练习及采样的大要过程。但是,我们现在照旧有一个疑问,如何保证练习时候,每个像素都忽略后续像素的信息?
PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。


我们来分析下这样计划的长处:





总结如下:

1.1.2 PixelCNN的网络架构



1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模子

实现PixelCNN,最紧张的是实现掩码卷积。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader
  5. import torchvision
  6. from torchvision.transforms import ToTensor
  7. import time
  8. import einops
  9. import cv2
  10. import numpy as np
  11. import os
  12. class MaskConv2d(nn.Module):
  13.     """
  14.         掩码卷积的实现思路:
  15.             在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积
  16.     """
  17.     def __init__(self, conv_type, *args, **kwags):
  18.         super().__init__()
  19.         assert conv_type in ('A', 'B')
  20.         self.conv = nn.Conv2d(*args, **kwags)
  21.         H, W = self.conv.weight.shape[-2:]
  22.         # 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码
  23.         mask = torch.zeros((H, W), dtype=torch.float32)
  24.         mask[0:H // 2] = 1
  25.         mask[H // 2, 0:W // 2] = 1
  26.         if conv_type == 'B':
  27.             mask[H // 2, W // 2] = 1
  28.         # 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作
  29.         mask = mask.reshape((1, 1, H, W))
  30.         # register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中
  31.         # 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。
  32.         # 第三个参数表示被注册的变量是否要加入state_dict中以保存下来
  33.         self.register_buffer(name='mask', tensor=mask, persistent=False)
  34.     def forward(self, x):
  35.         self.conv.weight.data *= self.mask
  36.         conv_res = self.conv(x)
  37.         return conv_res
复制代码
有了最焦点的掩码卷积,我们来根据论文中的模子布局图把模子搭起来


  1. class ResidualBlock(nn.Module):
  2.     """
  3.         残差块ResidualBlock
  4.     """
  5.     def __init__(self, h, bn=True):
  6.         super().__init__()
  7.         self.relu = nn.ReLU()
  8.         self.conv1 = nn.Conv2d(2 * h, h, 1)
  9.         self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
  10.         self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
  11.         self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
  12.         self.conv3 = nn.Conv2d(h, 2 * h, 1)
  13.         self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
  14.     def forward(self, x):
  15.         # 1、ReLU + 1×1 Conv + bn
  16.         y = self.relu(x)
  17.         y = self.conv1(y)
  18.         y = self.bn1(y)
  19.         # 2、ReLU + 3×3 Conv(mask B) + bn
  20.         y = self.relu(y)
  21.         y = self.conv2(y)
  22.         y = self.bn2(y)
  23.         # 3、ReLU + 1×1 Conv + bn
  24.         y = self.relu(y)
  25.         y = self.conv3(y)
  26.         y = self.bn3(y)
  27.         # 4、残差连接
  28.         y = y + x
  29.         return y
复制代码

  1. class PixelCNN(nn.Module):
  2.     def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
  3.         super().__init__()
  4.         self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
  5.         self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
  6.         self.residual_blocks = nn.ModuleList()
  7.         for _ in range(n_blocks):
  8.             self.residual_blocks.append(ResidualBlock(h, bn))
  9.         self.relu = nn.ReLU()
  10.         self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
  11.         self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
  12.         self.out = nn.Conv2d(linear_dim, color_level, 1)
  13.     def forward(self, x):
  14.         # 1、7 × 7 conv(mask A)
  15.         x = self.conv1(x)
  16.         x = self.bn1(x)
  17.         # 2、Multiple residual blocks
  18.         for block in self.residual_blocks:
  19.             x = block(x)
  20.         x = self.relu(x)
  21.         # 3、1 × 1 conv
  22.         x = self.linear1(x)
  23.         x = self.relu(x)
  24.         x = self.linear2(x)
  25.         x = self.out(x)
  26.         return x
复制代码
1.1.3.2 数据集及练习

准备好了模子代码,我们可以编写练习脚本了:

  1. def get_dataloader(batch_size: int):
  2.     dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
  3.                                          train=True,
  4.                                          transform=ToTensor()
  5.                                          )
  6.     return DataLoader(dataset, batch_size=batch_size, shuffle=True)
  7. def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
  8.     """训练过程"""
  9.     dataloader = get_dataloader(batch_size)
  10.     model = model.to(device)
  11.     optimizer = torch.optim.Adam(model.parameters(), 1e-3)
  12.     loss_fn = nn.CrossEntropyLoss()
  13.     tic = time.time()
  14.     for e in range(n_epochs):
  15.         total_loss = 0
  16.         for x, _ in dataloader:
  17.             current_batch_size = x.shape[0]
  18.             x = x.to(device)
  19.             # 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签
  20.             y = torch.ceil(x * (color_level - 1)).long()
  21.             y = y.squeeze(1)
  22.             predict_y = model(x)
  23.             loss = loss_fn(predict_y, y)
  24.             optimizer.zero_grad()
  25.             loss.backward()
  26.             optimizer.step()
  27.             total_loss += loss.item() * current_batch_size
  28.         total_loss /= len(dataloader.dataset)
  29.         toc = time.time()
  30.         torch.save(model.state_dict(), model_path)
  31.         print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
  32. if __name__ == '__main__':
  33.     os.makedirs('work_dirs', exist_ok=True)
  34.     device = 'cuda' if torch.cuda.is_available() else 'cpu'
  35.     # 需要注意的是:MNIST数据集的大部分像素都是0和255
  36.     color_level = 8  # or 256
  37.     # 1、创建PixelCNN模型
  38.     model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)
  39.     # 2、模型训练
  40.     model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'
  41.     train(model, device, model_path)
  42.     # 3、采样
  43.     sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
复制代码
1.1.3.3 采样


  1. def sample(model, device, model_path, output_path, n_sample=1):
  2.     """
  3.         把x初始化成一个0张量。
  4.         循环遍历每一个像素,输入x,把预测出的下一个像素填入x
  5.     """
  6.     model.eval()
  7.     model.load_state_dict(torch.load(model_path))
  8.     model = model.to(device)
  9.     C, H, W = get_img_shape()  # (1, 28, 28)
  10.     x = torch.zeros((n_sample, C, H, W)).to(device)
  11.     with torch.no_grad():
  12.         for i in range(H):
  13.             for j in range(W):
  14.                 # 我们先获取模型的输出,再用softmax转换成概率分布
  15.                 output = model(x)
  16.                 prob_dist = F.softmax(output[:, :, i, j], -1)
  17.                 # 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值
  18.                 # 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]
  19.                 pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
  20.                 # 最后把新像素填入到生成图像中
  21.                 x[:, :, i, j] = pixel
  22.     # 乘255变成一个用8位字节表示的图像
  23.     imgs = x * 255
  24.     imgs = imgs.clamp(0, 255)
  25.     imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))
  26.     imgs = imgs.detach().cpu().numpy().astype(np.uint8)
  27.     cv2.imwrite(output_path, imgs)
复制代码
1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响


更详细地,我们规定一个子像素只由它之前的子像素决定,天生图像时,我们一个子像素一个子像素地天生。


如下图所示,由于现在要猜测三个颜色通道,网络的输出应该是一个[256x3, H, W]外形的张量


图像变为多通道后,A类卷积和B类卷积的定义也必要做出一些调解。我们不仅要思量像素在空间上的束缚,还要思量一个像素内子像素间的束缚。为此,我们要用不同的策略实现束缚。为了方便描述,我们设卷积核组的外形为[o, i, h, w],此中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。




2 Gated PixelCNN

2.1 Gated PixelCNN简述









2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模子


  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader
  5. import torchvision
  6. from torchvision.transforms import ToTensor
  7. import time
  8. import einops
  9. import cv2
  10. import numpy as np
  11. import os
  12. class VerticalMaskConv2d(nn.Module):
  13.     """
  14.         垂直卷积
  15.     """
  16.     def __init__(self, *args, **kwags):
  17.         super().__init__()
  18.         self.conv = nn.Conv2d(*args, **kwags)
  19.         H, W = self.conv.weight.shape[-2:]
  20.         mask = torch.zeros((H, W), dtype=torch.float32)
  21.         mask[0:H // 2 + 1] = 1
  22.         mask = mask.reshape((1, 1, H, W))
  23.         self.register_buffer('mask', mask, False)
  24.     def forward(self, x):
  25.         self.conv.weight.data *= self.mask
  26.         conv_res = self.conv(x)
  27.         return conv_res
  28. class HorizontalMaskConv2d(nn.Module):
  29.     """
  30.         水平卷积
  31.     """
  32.     def __init__(self, conv_type, *args, **kwags):
  33.         super().__init__()
  34.         assert conv_type in ('A', 'B')
  35.         self.conv = nn.Conv2d(*args, **kwags)
  36.         H, W = self.conv.weight.shape[-2:]
  37.         mask = torch.zeros((H, W), dtype=torch.float32)
  38.         mask[H // 2, 0:W // 2] = 1
  39.         if conv_type == 'B':
  40.             mask[H // 2, W // 2] = 1
  41.         mask = mask.reshape((1, 1, H, W))
  42.         self.register_buffer('mask', mask, False)
  43.     def forward(self, x):
  44.         self.conv.weight.data *= self.mask
  45.         conv_res = self.conv(x)
  46.         return conv_res
复制代码
  1. # 垂直卷积
  2. tensor([[[[1., 1., 1.],
  3.           [1., 1., 1.],
  4.           [0., 0., 0.]]]])
  5. # A类水平卷积
  6. tensor([[[[0., 0., 0.],
  7.           [1., 0., 0.],
  8.           [0., 0., 0.]]]])
  9. # B类水平卷积
  10. tensor([[[[0., 0., 0.],
  11.           [1., 1., 0.],
  12.           [0., 0., 0.]]]])
复制代码





  1. class GatedBlock(nn.Module):
  2.     def __init__(self, conv_type, in_channels, p, bn=True):
  3.         super().__init__()
  4.         self.conv_type = conv_type
  5.         self.p = p
  6.         self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
  7.         self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
  8.         self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)
  9.         self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
  10.         self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
  11.                                            1)
  12.         self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
  13.         self.h_output_conv = nn.Conv2d(p, p, 1)
  14.         self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()
  15.     def forward(self, v_input, h_input):
  16.         # v代表垂直卷积部分的结果
  17.         v = self.v_conv(v_input)
  18.         v = self.bn1(v)
  19.         # Note: 重点代码
  20.         # 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位
  21.         # 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)
  22.         v_to_h = v[:, :, 0:-1]
  23.         v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
  24.         # 和h相加前,先经过 1×1 conv
  25.         v_to_h = self.v_to_h_conv(v_to_h)
  26.         v_to_h = self.bn2(v_to_h)
  27.         # 分为两份,经过tanh 和 sigmoid
  28.         v1, v2 = v[:, :self.p], v[:, self.p:]
  29.         v1 = torch.tanh(v1)
  30.         v2 = torch.sigmoid(v2)
  31.         v = v1 * v2
  32.         # h代表水平卷积部分的结果
  33.         h = self.h_conv(h_input)
  34.         h = self.bn3(h)
  35.         h = h + v_to_h
  36.         # 分为两份,经过tanh 和 sigmoid
  37.         h1, h2 = h[:, :self.p], h[:, self.p:]
  38.         h1 = torch.tanh(h1)
  39.         h2 = torch.sigmoid(h2)
  40.         h = h1 * h2
  41.         h = self.h_output_conv(h)
  42.         h = self.bn4(h)
  43.         # 在网络的第一层,每个数据是不能看到自己的。
  44.         # 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。
  45.         if self.conv_type == 'B':
  46.             h = h + h_input
  47.         return v, h
复制代码

  1. class GatedPixelCNN(nn.Module):
  2.     def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
  3.         super().__init__()
  4.         self.block1 = GatedBlock('A', 1, p, bn)
  5.         self.blocks = nn.ModuleList()
  6.         for _ in range(n_blocks):
  7.             self.blocks.append(GatedBlock('B', p, p, bn))
  8.         self.relu = nn.ReLU()
  9.         self.linear1 = nn.Conv2d(p, linear_dim, 1)
  10.         self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
  11.         self.out = nn.Conv2d(linear_dim, color_level, 1)
  12.     def forward(self, x):
  13.         v, h = self.block1(x, x)
  14.         for block in self.blocks:
  15.             v, h = block(v, h)
  16.         x = self.relu(h)
  17.         x = self.linear1(x)
  18.         x = self.relu(x)
  19.         x = self.linear2(x)
  20.         x = self.out(x)
  21.         return x
复制代码
2.2.2 数据集、练习及采样


  1. def get_dataloader(batch_size: int):
  2.     dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
  3.                                          train=True,
  4.                                          transform=ToTensor()
  5.                                          )
  6.     return DataLoader(dataset, batch_size=batch_size, shuffle=True)
  7. def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
  8.     """训练过程"""
  9.     dataloader = get_dataloader(batch_size)
  10.     model = model.to(device)
  11.     optimizer = torch.optim.Adam(model.parameters(), 1e-3)
  12.     loss_fn = nn.CrossEntropyLoss()
  13.     tic = time.time()
  14.     for e in range(n_epochs):
  15.         total_loss = 0
  16.         for x, _ in dataloader:
  17.             current_batch_size = x.shape[0]
  18.             x = x.to(device)
  19.             # 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的
  20.             y = torch.ceil(x * (color_level - 1)).long()
  21.             y = y.squeeze(1)
  22.             predict_y = model(x)
  23.             loss = loss_fn(predict_y, y)
  24.             optimizer.zero_grad()
  25.             loss.backward()
  26.             optimizer.step()
  27.             total_loss += loss.item() * current_batch_size
  28.         total_loss /= len(dataloader.dataset)
  29.         toc = time.time()
  30.         torch.save(model.state_dict(), model_path)
  31.         print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
  32. def get_img_shape():
  33.     return (1, 28, 28)
  34. def sample(model, device, model_path, output_path, n_sample=1):
  35.     """
  36.         把x初始化成一个0张量。
  37.         循环遍历每一个像素,输入x,把预测出的下一个像素填入x
  38.     """
  39.     model.eval()
  40.     model.load_state_dict(torch.load(model_path))
  41.     model = model.to(device)
  42.     C, H, W = get_img_shape()  # (1, 28, 28)
  43.     x = torch.zeros((n_sample, C, H, W)).to(device)
  44.     with torch.no_grad():
  45.         for i in range(H):
  46.             for j in range(W):
  47.                 # 我们先获取模型的输出,再用softmax转换成概率分布
  48.                 output = model(x)
  49.                 prob_dist = F.softmax(output[:, :, i, j], -1)
  50.                 # 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值
  51.                 # 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)
  52.                 pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
  53.                 # 最后把新像素填入生成图像
  54.                 x[:, :, i, j] = pixel
  55.     imgs = x * 255
  56.     imgs = imgs.clamp(0, 255)
  57.     imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))
  58.     imgs = imgs.detach().cpu().numpy().astype(np.uint8)
  59.     cv2.imwrite(output_path, imgs)
  60. if __name__ == '__main__':
  61.     os.makedirs('work_dirs', exist_ok=True)
  62.     device = 'cuda' if torch.cuda.is_available() else 'cpu'
  63.     color_level = 8  # or 256
  64.     # 1、创建GatedPixelCNN模型
  65.     model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)
  66.     # 2、模型训练
  67.     model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'
  68.     train(model, device, model_path, batch_size=1)
  69.     # 3、采样
  70.     sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4