I implemented a version of the capsule network back on 2024-07-15, but at the time, no matter how I trained it, I could not reach the 99.23% reported in Hinton's paper (on the expanded MNIST dataset).
A couple of days ago I re-read Hinton's paper carefully on a whim and redid the reproduction strictly following the paper's requirements: this time it finally reached the performance ceiling from the paper, and it also trains roughly five to six times faster than my earlier version.
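For readers who want to check the code below against the paper, these are the formulas it implements, written out as I read them in "Dynamic Routing Between Capsules" (squash non-linearity, prediction vectors, weighted sum, agreement update, and margin loss):

```latex
% Squash non-linearity: short vectors shrink toward 0, long vectors toward unit length
v_j = \frac{\lVert s_j \rVert^2}{1 + \lVert s_j \rVert^2} \cdot \frac{s_j}{\lVert s_j \rVert}

% Prediction vectors and the weighted sum that feeds output capsule j
\hat{u}_{j|i} = W_{ij} u_i, \qquad s_j = \sum_i c_{ij} \, \hat{u}_{j|i}

% Routing logits are updated by the agreement between prediction and output
b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j

% Margin loss for digit capsule k, with m^+ = 0.9, m^- = 0.1, \lambda = 0.5
L_k = T_k \max(0, m^+ - \lVert v_k \rVert)^2 + \lambda (1 - T_k) \max(0, \lVert v_k \rVert - m^-)^2
```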
1. Loading the dataset
```python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Define the data preprocessing pipeline
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download and load the MNIST training and test datasets
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

def show_image(image, label):
    plt.imshow(image, cmap='gray')
    plt.title(f'Label: {label}')
    plt.show()

# Display one training sample
show_image(trainset[0][0][0], trainset[0][1])
```
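A side note on Normalize((0.1307,), (0.3081,)): those two constants are simply the mean and standard deviation of the MNIST training pixels. If you want to verify them yourself, here is a quick sketch (the raw_trainset / pixels names are just for this illustration):

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Reload the training images with only ToTensor (pixels scaled to [0, 1], no normalization)
raw_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                          transform=transforms.ToTensor())

# Stack all 60,000 images and compute the global mean / std over every pixel
pixels = torch.stack([img for img, _ in raw_trainset])  # (60000, 1, 28, 28)
print(pixels.mean().item(), pixels.std().item())        # roughly 0.1307 and 0.3081
```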
2. Model code
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

def squash(s):
    # Squash non-linearity: shrinks short vectors toward 0 and long vectors toward unit length
    norm = torch.norm(s, dim=-1, keepdim=True)
    s_squared_norm = norm ** 2
    return (s_squared_norm / (1 + s_squared_norm)) * (s / (norm + 1e-8))  # eps avoids division by zero

def routing(u_hat, num_iterations):
    # u_hat: (B, N, M, D)
    batch_size, num_capsules_i, num_capsules_j, n_dim = u_hat.size()
    b_ij = torch.zeros(batch_size, num_capsules_i, num_capsules_j, 1).to(u_hat.device)  # (B, N, M, 1)
    for _ in range(num_iterations):
        c_ij = F.softmax(b_ij, dim=1)                                # (B, N, M, 1)
        s_j = torch.sum(c_ij * u_hat, dim=1, keepdim=True)           # (B, 1, M, D)
        v_j = squash(s_j)                                            # (B, 1, M, D)
        b_ij += torch.sum(u_hat * v_j, dim=-1, keepdim=True)         # (B, N, M, 1)

    return v_j.squeeze(1)  # (B, M, D)

class PrimaryCaps(nn.Module):
    def __init__(self, in_channels=256, out_channels=32, capsule_dim=8, kernel_size=9, stride=2):
        super(PrimaryCaps, self).__init__()
        self.capsule_dim = capsule_dim
        self.out_channels = out_channels
        # A convolution produces the primary capsule activations (the initial capsule vectors)
        self.conv2 = nn.Conv2d(in_channels, out_channels * capsule_dim, kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.size(0)
        # Convolve, reshape the feature map into capsule vectors, and flatten the spatial grid
        x = self.conv2(x).permute(0, 2, 3, 1).reshape(B, -1, self.capsule_dim).contiguous()  # (B, N, D)
        # Apply the squash function to each capsule's output vector
        x = squash(x)
        return x

class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_route_nodes=1152, in_channels=8, out_channels=16, num_iterations=3):
        """
        :param in_channels: dimension of the input capsules
        :param out_channels: dimension of the output capsules
        :param num_capsules: number of output capsules, i.e. the number of digit classes (usually 10)
        :param num_iterations: number of dynamic-routing iterations
        W: weight matrices that map input capsules to output capsules
        """
        super(DigitCaps, self).__init__()
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        self.W = nn.Parameter(torch.randn(1, num_route_nodes, num_capsules, in_channels, out_channels))
        # Alternatively, all input capsules can share a single W_j; performance is barely worse.
        # self.W = nn.Parameter(torch.randn(1, 1, num_capsules, in_channels, out_channels))

    def forward(self, x):
        # x: (B, N, D1)
        # Compute the prediction vectors
        x = x.unsqueeze(-2).unsqueeze(-2)              # (B, N, 1, 1, D1)
        u_hat = torch.matmul(x, self.W).squeeze(-2)    # (B, N, M, D2)
        # Run dynamic routing
        v = routing(u_hat, self.num_iterations)        # (B, M, D2)
        # Return the length of each output capsule vector
        v = torch.norm(v, dim=-1)                      # (B, M)
        return v

class CapsuleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(CapsuleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)  # grayscale input has a single channel
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps(num_capsules=num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.primary_capsules(x)
        x = self.digit_capsules(x)
        return x

class MarginLoss(nn.Module):
    def __init__(self, m_plus=0.9, m_minus=0.1, lambd=0.5):
        super(MarginLoss, self).__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambd = lambd

    def forward(self, v, target):
        """
        v: (batch_size, num_classes); v_k is the length of the k-th digit capsule's instantiation vector
        target: (batch_size,) integer class labels, converted to one-hot below
        """
        target = F.one_hot(target, num_classes=v.size(1)).float()  # one-hot encoding on the same device as v
        left = torch.clamp(self.m_plus - v, min=0) ** 2
        right = torch.clamp(v - self.m_minus, min=0) ** 2
        loss = target * left + self.lambd * (1 - target) * right
        return torch.mean(torch.sum(loss, dim=1))
```
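Before moving on to training, a quick shape sanity check helps confirm the capsule dimensions line up: a 28x28 input gives a 20x20x256 feature map after conv1, then 6x6x32 = 1152 primary capsules of dimension 8, and finally 10 digit capsules of dimension 16, returned as 10 lengths. A minimal sketch, assuming the classes above have already been defined:

```python
# Push one fake batch of MNIST-sized images through the network and check the output shape
dummy = torch.randn(4, 1, 28, 28)   # (B, C, H, W)
net = CapsuleNet()
with torch.no_grad():
    out = net(dummy)
print(out.shape)  # expected: torch.Size([4, 10]) -- one capsule length per digit class
```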
3. Training code
```python
num_epochs = 10
num_classes = 10
batch_size = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CapsuleNet().to(device)
margin_loss = MarginLoss()
# margin_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())

train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

@torch.no_grad()
def estimate_acc():
    model.eval()
    acc = {}
    loader = test_loader
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    acc.update({'val': (correct / total) * 100})
    model.train()
    return acc

for epoch in range(num_epochs):
    correct = 0
    total = 0
    with tqdm(total=len(train_loader), desc="epoch %d" % epoch) as pbar:
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = margin_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            predicted = torch.argmax(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update the progress bar
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Accuracy': f'{(correct / total) * 100:.4f}%(train)',
            })
            pbar.update(1)
    acc = estimate_acc()
    print(f"Accuracy: {acc['val']:.4f}%(val)")
```
As you can see, the network converges quickly and reaches the paper's performance ceiling of 99.23% by epoch 7!