【胶囊网络】完美复现Hinton论文99.23%

饭宝  论坛元老 | 2025-3-16 20:53:52 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1026|帖子 1026|积分 3078

我之前在2024-07-15的时候实现过一版胶囊网络,但是当时无论我怎么训练,都没办法达到Hinton论文里的99.23%(MNIST扩展数据集上):

前两天心血来潮,又认真读了一下Hinton论文,严酷按照论文要求进行复现:终极达到了Hinton里的性能上限,同时训练速度也比以往那版要快大概五六倍。
1. 导入数据集

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. # 定义数据预处理
  7. transform = transforms.Compose([
  8.     transforms.ToTensor(),
  9.     transforms.Normalize((0.1307,), (0.3081,))
  10. ])
  11. # 下载并加载MNIST训练数据集
  12. trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  13. testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  14. def show_image(image, label):
  15.     plt.imshow(image, cmap='gray')
  16.     plt.title(f'Label: {label}')
  17.     plt.show()
  18. # 显示一个训练样本
  19. show_image(trainset[0][0][0], trainset[0][1])
复制代码

2. 模子代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from tqdm import tqdm
  5. def squash(s):
  6.     norm = torch.norm(s, dim=-1, keepdim=True)
  7.     s_squared_norm = norm ** 2
  8.     return (s_squared_norm / (1 + s_squared_norm)) * (s / norm)
  9. def routing(u_hat, num_iteratiobns):
  10.     # u_hat: (B,N,M,D)
  11.     batch_size, num_capsules_i, num_capsules_j, n_dim = u_hat.size()
  12.     b_ij = torch.zeros(batch_size, num_capsules_i, num_capsules_j, 1).to(u_hat.device) # (B,N,M,1)
  13.     for _ in range(num_iteratiobns):
  14.         c_ij = F.softmax(b_ij, dim=1) # (B,N,M,D)
  15.         s_j = torch.sum(c_ij * u_hat, dim=1, keepdim=True) # (B,1,M,D)
  16.         v_j = squash(s_j) # (B,1,M,D)
  17.         b_ij += torch.sum(u_hat * v_j, dim=-1, keepdim=True) # (B,N,M,1)
  18.         
  19.     return v_j.squeeze(1) # (B,M,D)
  20. class PrimaryCaps(nn.Module):
  21.     def __init__(self, in_channels=256, out_channels=32, capsule_dim=8, kernel_size=9, stride=2):
  22.         super(PrimaryCaps, self).__init__()
  23.         self.capsule_dim = capsule_dim
  24.         self.out_channels = out_channels
  25.         # 使用卷积层来生成初级胶囊的输入(激活初始胶囊向量)
  26.         self.conv2 = nn.Conv2d(in_channels, out_channels * capsule_dim, kernel_size=kernel_size, stride=stride)
  27.         
  28.     def forward(self, x):
  29.         # x: (B, C, H, W)
  30.         B = x.size(0)
  31.         # 进行卷积操作,并将输出调整为胶囊向量的形式,然后整合所有胶囊向量
  32.         x = self.conv2(x).permute(0, 2, 3, 1).reshape(B, -1, self.capsule_dim).contiguous() # (B, N, D)
  33.         # 对每个胶囊的输出向量应用squash函数
  34.         x = squash(x)
  35.         return x
  36. class DigitCaps(nn.Module):
  37.     def __init__(self, num_capsules=10, num_route_nodes=1152, in_channels=8, out_channels=16, num_iterations=3):
  38.         """
  39.         :param in_channels: 输入胶囊的维度
  40.         :param out_channels: 输出胶囊的维度
  41.         :param num_capsules: 输出的胶囊数量,对应数字类别数(通常为 10)
  42.         :param num_iterations: 动态路由的迭代次数
  43.         :param W: 权重矩阵,用于将输入胶囊映射到输出胶囊
  44.         """
  45.         super(DigitCaps, self).__init__()
  46.         self.num_capsules = num_capsules
  47.         self.num_iterations = num_iterations
  48.         self.W = nn.Parameter(torch.randn(1, num_route_nodes, num_capsules, in_channels, out_channels))
  49.         # 也可以所以胶囊共享一个Wj,性能并不会比前者差多少。
  50.         # self.W = nn.Parameter(torch.randn(1, 1, num_capsules, in_channels, out_channels))
  51.    
  52.     def forward(self, x):
  53.         # x: (B, N, D1)
  54.         # 计算预测向量
  55.         x = x.unsqueeze(-2).unsqueeze(-2) # (B, N, 1, 1, D1)
  56.         u_hat = torch.matmul(x, self.W).squeeze(-2) # (B, N, M, D2)
  57.         # 进行动态路由
  58.         v = routing(u_hat, self.num_iterations) # (B, M, D2)
  59.         # 返回输出胶囊向量的长度
  60.         v = torch.norm(v, dim=-1) # (B, M)
  61.         return v
  62. class CapesuleNet(nn.Module):
  63.     def __init__(self, num_classes=10):
  64.         super(CapesuleNet, self).__init__()
  65.         self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1) # 灰度图只有一个原始维度
  66.         self.primary_capsules = PrimaryCaps()
  67.         self.digit_capsules = DigitCaps(num_capsules=num_classes)
  68.     def forward(self, x):
  69.         x = F.relu(self.conv1(x))
  70.         x = self.primary_capsules(x)
  71.         x = self.digit_capsules(x)
  72.         return x
  73.    
  74. class MarginLoss(nn.Module):
  75.     def __init__(self, m_plus=0.9, m_minus=0.1, lambd=0.5):
  76.         super(MarginLoss, self).__init__()
  77.         self.m_plus = m_plus
  78.         self.m_minus = m_minus
  79.         self.lambd = lambd
  80.     def forward(self, v, target):
  81.         """
  82.         v: 形状为 (batch_size, num_classes),v_k表示第k个数字胶囊的实例化向量的长度
  83.         target: 形状为 (batch_size, num_classes),one-hot编码的目标标签
  84.         """
  85.         target = torch.eye(num_classes)[target.detach().cpu()].to(device) # one-hot编码
  86.         left = torch.clamp(self.m_plus - v, min=0) ** 2
  87.         right = torch.clamp(v - self.m_minus, min=0) ** 2
  88.         loss = target * left + self.lambd * (1 - target) * right
  89.         return torch.mean(torch.sum(loss, dim=1))
复制代码
3. 训练代码

  1. num_epochs=10
  2. num_classes=10
  3. batch_size=128
  4. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  5. model = CapesuleNet().to(device)
  6. margin_loss = MarginLoss()
  7. # margin_loss = nn.CrossEntropyLoss()
  8. optimizer = torch.optim.AdamW(model.parameters())
  9. train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
  10. test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
  11. @torch.no_grad()
  12. def estimate_acc():
  13.     model.eval()
  14.     acc = {}
  15.     loader = test_loader
  16.     correct = 0
  17.     total = 0
  18.     for images, labels in loader:
  19.         images = images.to(device)
  20.         labels = labels.to(device)
  21.         outputs = model(images)
  22.         _, predicted = torch.max(outputs.data, 1)
  23.         total += labels.size(0)
  24.         correct += (predicted == labels).sum().item()
  25.     acc.update({'val': (correct / total) * 100})
  26.     model.train()
  27.     return acc
  28. for epoch in range(num_epochs):
  29.     correct = 0
  30.     total = 0
  31.     with tqdm(total=len(train_loader), desc="epoch %d" % epoch) as pbar:
  32.         for images, labels in train_loader:
  33.             images = images.to(device)
  34.             labels = labels.to(device)
  35.             optimizer.zero_grad()
  36.             outputs = model(images)
  37.             loss = margin_loss(outputs, labels)
  38.             loss.backward()
  39.             optimizer.step()
  40.             predicted = torch.argmax(outputs.data, dim=1)
  41.             total += labels.size(0)
  42.             correct += (predicted == labels).sum().item()
  43.             # 更新进度条
  44.             pbar.set_postfix({
  45.                 'Loss': f'{loss.item():.4f}',
  46.                 'Accuracy': f'{(correct / total) * 100:.4f}%(train)',
  47.             })
  48.             pbar.update(1)
  49.     acc = estimate_acc()
  50.     print(f"Accuracy: {acc['val']:.4f}%(val)")
复制代码

可以看到神经网络迅速地收敛,并在第7个epoch达到了论文里的性能上限99.23%!

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

饭宝

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表