利用pytorch对加噪堆叠自编码器在MNIST数据集举行训练和验证 ...

打印 上一主题 下一主题

主题 1834|帖子 1834|积分 5502

实现背景:

最近在复现关于使用深度学习提取特征来举行聚类的论文,其中使用到了加噪堆叠自编码器,详细实现细节请参考论文:Improved Deep Embedded Clustering with Local Structure Preservation
其中加噪堆叠自编码器涉及两个过程:
预训练:预训练过程对原始数据加噪,贪心式地对每一层encoder和decoder举行训练,其中训练新的AE时冻结前面训练好的AE。详见:堆栈自编码器 Stacked AutoEncoder-CSDN博客
微调:在预训练完成之后使用所有AE和原始数据对整体模型举行微调。
我在网上找到了一个SAE的示范样例:python-pytorch 利用pytorch对堆叠自编码器举行训练和验证_pytoch把训练和验证写一起的代码-CSDN博客
但是这篇博客的数据集很小,如果应用到MNIST数据集时显存很容易溢出,因此我在原始的基础上举行了改进,直接上代码:
初始化数据集:

  1. import torch
  2. from torchvision import datasets, transforms
  3. from torch.utils.data import Dataset, random_split
  4. # 定义数据预处理
  5. transform = transforms.Compose([
  6.     transforms.ToTensor(),
  7.     transforms.Normalize((0.1307,), (0.3081,))
  8. ])
  9. # 下载并加载 MNIST 训练数据集
  10. original_dataset = datasets.MNIST(root='./data', train=True,
  11.                                   download=False, transform=transform)
  12. class NoLabelDataset(Dataset):
  13.     def __init__(self, original_dataset):
  14.         self.original_dataset = original_dataset
  15.     def __getitem__(self, index):
  16.         image, _ = self.original_dataset[index]
  17.         return image
  18.     def __len__(self):
  19.         return len(self.original_dataset)
  20. # 创建不包含标签的数据集
  21. no_label_dataset = NoLabelDataset(original_dataset)
  22. # 划分训练集和验证集
  23. train_size = int(0.8 * len(no_label_dataset))
  24. val_size = len(no_label_dataset) - train_size
  25. train_dataset, val_dataset = random_split(no_label_dataset, [train_size, val_size])
  26. # 创建数据加载器
  27. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
  28. val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
  29. print(f"训练集样本数量: {len(train_dataset)}")
  30. print(f"验证集样本数量: {len(val_dataset)}")   
  31. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
复制代码
 定义模型和训练函数:

  1. import torch.nn as nn
  2. class Autoencoder(nn.Module):
  3.     def __init__(self, input_size, hidden_size):
  4.         super(Autoencoder, self).__init__()
  5.         self.encoder = nn.Sequential(
  6.             nn.Conv2d(input_size, hidden_size, kernel_size=3, stride=1, padding=1),  # 输入通道1,输出通道16
  7.             nn.ReLU())
  8.         self.decoder = nn.Sequential(
  9.             nn.ConvTranspose2d(hidden_size, input_size, kernel_size=3, stride=1, padding=1),
  10.             nn.ReLU())
  11.     def forward(self, x):
  12.         x = self.encoder(x)
  13.         x = self.decoder(x)
  14.         return x
  15. def train_ae(models, train_loader, val_loader, num_epochs, criterion, optimizer, noise_factor, finetune):
  16.     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  17.     for epoch in range(num_epochs):
  18.         # Training
  19.         models[-1].train()
  20.         train_loss = 0
  21.         for batch_data in train_loader:
  22.             optimizer.zero_grad()
  23.             if len(models) != 1:
  24.                 batch_data = batch_data.to(device)
  25.                 for model in models[:-1]:
  26.                     with torch.no_grad():
  27.                         batch_data = model.encoder(batch_data)
  28.                 batch_data = batch_data.detach()
  29.             if finetune == True:
  30.                 batch_data = batch_data.to(device)
  31.                 outputs = models[-1](batch_data)
  32.                 loss = criterion(outputs, batch_data)
  33.             else:
  34.                 noisy_image = batch_data + noise_factor * torch.randn_like(batch_data)
  35.                 noisy_image = torch.clamp(noisy_image, 0., 1.).to(device)
  36.                 outputs = models[-1](noisy_image)
  37.                 batch_data = batch_data.to(device)
  38.                 loss = criterion(outputs, batch_data)
  39.             loss.backward()
  40.             optimizer.step()
  41.             train_loss += loss.item()
  42.         
  43.         train_loss /= len(train_loader)
  44.         print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")
  45.         # Validation
  46.         models[-1].eval()
  47.         val_loss = 0
  48.         with torch.no_grad():
  49.             for batch_data in val_loader:
  50.                 if len(models) != 1:
  51.                     batch_data = batch_data.to(device)
  52.                     for model in models[:-1]:
  53.                         batch_data = model.encoder(batch_data)
  54.                     batch_data = batch_data.detach()
  55.                 if finetune == True:
  56.                     batch_data = batch_data.to(device)
  57.                     outputs = models[-1](batch_data)
  58.                     loss = criterion(outputs, batch_data)
  59.                 else:
  60.                     noisy_image = batch_data + noise_factor * torch.randn_like(batch_data)
  61.                     noisy_image = torch.clamp(noisy_image, 0., 1.).to(device)
  62.                     outputs = models[-1](noisy_image)
  63.                     batch_data = batch_data.to(device)
  64.                     loss = criterion(outputs, batch_data)
  65.                 val_loss += loss.item()
  66.         val_loss /= len(val_loader)
  67.         print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
复制代码
 模型训练以及微调:

  1. batch_size = 16
  2. noise_factor = 0.4
  3. ae1 = Autoencoder(input_size=1, hidden_size=16).to(device)
  4. optimizer = torch.optim.Adam(ae1.parameters(), lr=0.001)
  5. criterion = nn.MSELoss()
  6. train_ae([ae1], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = False)
  7. ae2 = Autoencoder(input_size=16, hidden_size=64).to(device)
  8. optimizer = torch.optim.Adam(ae2.parameters(), lr=0.001)
  9. train_ae([ae1, ae2], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = False)
  10. ae3 = Autoencoder(input_size=64, hidden_size=128).to(device)
  11. optimizer = torch.optim.Adam(ae3.parameters(), lr=0.001)
  12. train_ae([ae1, ae2, ae3], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = False)
  13. class StackedAutoencoder(nn.Module):
  14.     def __init__(self, ae1, ae2, ae3):
  15.         super(StackedAutoencoder, self).__init__()
  16.         self.encoder = nn.Sequential(ae1.encoder, ae2.encoder, ae3.encoder)
  17.         self.decoder = nn.Sequential(ae3.decoder, ae2.decoder, ae1.decoder)
  18.     def forward(self, x):
  19.         x = self.encoder(x)
  20.         x = self.decoder(x)
  21.         return x
  22. sae = StackedAutoencoder(ae1, ae2, ae3)
  23. optimizer = torch.optim.Adam(ae1.parameters(), lr=0.001)
  24. criterion = nn.MSELoss()
  25. train_ae([sae], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = True)
复制代码
 结果可视化:

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. dataiter = iter(val_loader)
  4. image = next(dataiter)[1]
  5. print(image.shape)
  6. image = image.to(device)
  7. # 通过自编码器模型进行前向传播
  8. with torch.no_grad():
  9.     output = sae(image)
  10. noise_factor = 0.4
  11. noisy_image = image + noise_factor * torch.randn_like(image)
  12. noisy_image = torch.clamp(noisy_image, 0., 1.).cpu().numpy()
  13. # 将张量转换为 numpy 数组以便可视化
  14. image = image.cpu().numpy()
  15. output = output.cpu().numpy()
  16. # 定义一个函数来显示图片
  17. def imshow(img):
  18.     img = img * 0.3081 + 0.1307  # 反归一化
  19.     npimg = img.squeeze()  # 去除单维度
  20.     plt.imshow(npimg, cmap='gray')
  21. # 可视化输入和输出图片
  22. plt.figure(figsize=(10, 5))
  23. # 显示输入图像
  24. plt.subplot(1, 3, 1)
  25. imshow(torch.from_numpy(image))
  26. plt.title('Input Image')
  27. plt.axis('off')
  28. # 显示加噪图像
  29. plt.subplot(1, 3, 2)
  30. imshow(torch.from_numpy(noisy_image))
  31. plt.title('Noisy Image')
  32. plt.axis('off')
  33. # 显示输出图像
  34. plt.subplot(1, 3, 3)
  35. imshow(torch.from_numpy(output))
  36. plt.title('Output Image')
  37. plt.axis('off')
  38. plt.savefig("预训练图.svg", dpi=300,format="svg")
复制代码
 下面附上训练好的可视化图:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

北冰洋以北

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表