代码填空任务---自编码器模子

打印 上一主题 下一主题

主题 1023|帖子 1023|积分 3069

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
1.自编码器模子填空:
  1. import torch
  2. import matplotlib.pyplot as plt
  3. from torchvision import datasets, transforms
  4. from torch import nn, optim
  5. from torch.nn import functional as F
  6. from tqdm import tqdm
  7. import os
  8. # os.chdir(os.path.dirname(__file__))
  9. '模型结构'
  10. #损失函数
  11. #交叉熵,衡量各个像素原始数据与重构数据的误差
  12. #均方误差可作为交叉熵替代使用.衡量各个像素原始数据与重构数据的误差
  13. '超参数及构造模型'
  14. #模型参数
  15. #压缩后的特征维度
  16. #encoder和decoder中间层的维度
  17. #原始图片和生成图片的维度
  18. #训练参数
  19. #训练时期
  20. #每步训练样本数
  21. #学习率
  22. device =torch.device('cuda' if torch.cuda.is_available() else 'cpu')#训练设备
  23. #确定模型,导入已训练模型(如有)
  24. modelname = 'ae.pth'
  25. #模型初始化
  26. #优化器
  27. try:
  28.     model.load_state_dict(torch.load(modelname))
  29.     print('[INFO] Load Model complete')
  30. except:
  31.     pass
  32. '训练模型'
  33. #准备mnist数据集 (数据会下载到py文件所在的data文件夹下)
  34. train_loader = torch.utils.data.DataLoader(
  35.     datasets.MNIST('./', train=True, download=True,
  36.                    transform=transforms.ToTensor()),
  37.     batch_size=batch_size, shuffle=True)
  38. test_loader = torch.utils.data.DataLoader(
  39.     datasets.MNIST('./', train=False, transform=transforms.ToTensor()),
  40.     batch_size=batch_size, shuffle=False)
  41. #此方法获取的数据各像素值范围0-1
  42. #训练及测试
  43. loss_history = {'train':[],'eval':[]}
  44. for epoch in range(epochs):   
  45.     #训练
  46.     #每个epoch重置损失,设置进度条
  47.     train_loss = 0
  48.     train_nsample = 0
  49.     t = tqdm(train_loader,desc = f'[train]epoch:{epoch}')
  50.     for imgs, lbls in t: #imgs:(bs,28,28)
  51.         #获取数据
  52.         #imgs:(bs,28*28)
  53.         #模型运算     
  54.         
  55.         #计算损失
  56.         # 重构与原始数据的差距(也可使用loss_MSE)
  57.         #反向传播、参数优化,重置
  58.         
  59.         
  60.         
  61.         #计算平均损失,设置进度条
  62.         
  63.         
  64.         t.set_postfix({'loss':train_loss/train_nsample})
  65.     #每个epoch记录总损失
  66.     loss_history['train'].append(train_loss/train_nsample)
  67.     #测试
  68.     #每个epoch重置损失,设置进度条
  69.     test_loss = 0
  70.     test_nsample = 0
  71.     e = tqdm(test_loader,desc = f'[eval]epoch:{epoch}')
  72.     for imgs, label in e:
  73.         #获取数据
  74.         
  75.         #模型运算   
  76.         
  77.         #计算损失
  78.          
  79.         #计算平均损失,设置进度条
  80.         
  81.         
  82.         e.set_postfix({'loss':test_loss/test_nsample})
  83.     #每个epoch记录总损失   
  84.     loss_history['eval'].append(test_loss/test_nsample)
  85.     #展示效果   
  86.     #将测试步骤中的数据、重构数据绘图
  87.     concat = torch.cat((imgs[0].view(28, 28),
  88.             re_imgs[0].view( 28, 28)), 1)
  89.     plt.matshow(concat.cpu().detach().numpy())
  90.     plt.show()
  91.     #显示每个epoch的loss变化
  92.     plt.plot(range(epoch+1),loss_history['train'])
  93.     plt.plot(range(epoch+1),loss_history['eval'])
  94.     plt.show()
  95.     #存储模型
  96.     torch.save(model.state_dict(),modelname)
  97. '调用模型'
  98. #对数据集
  99. dataset = datasets.MNIST('./', train=False, transform=transforms.ToTensor())
  100. #取一组手写数据(正常数据)
  101. raw = dataset[0][0].view(1,-1) #raw: bs,28,28->bs,28*28
  102. #对手写数据(正常数据)重构
  103. re_raw = model(raw.to(device))
  104. #取一组随机数据(异常数据)
  105. rand = torch.randn_like(raw)
  106. #对随机数据(异常数据)重构
  107. re_rand = model(rand.to(device))
  108. #定义一个衡量标准,按像素平均所有原始数据和重构数据的误差
  109. f = lambda x,y: abs(x-y).mean()
  110. #正常数据 原始数据与重构数据差异
  111. print('正常数据:',f(re_raw.to("cpu"),raw))
  112. #异常数据 原始数据与重构数据差异
  113. print('异常数据:',f(re_rand.to("cpu"),rand))
  114. #正常数据,原始数据与重构数据作图
  115. plt.matshow(raw.view(28,28).detach().cpu().numpy())
  116. plt.show()
  117. plt.matshow(re_raw.view(28,28).detach().cpu().numpy())
  118. plt.show()
  119. #异常数据,原始数据与重构数据作图
  120. plt.matshow(rand.view(28,28).detach().cpu().numpy())
  121. plt.show()
  122. plt.matshow(re_rand.view(28,28).detach().cpu().numpy())
  123. plt.show()
复制代码
参考:
手写系列——AE网络、VAE网络和Condition VAE网络-CSDN博客

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

伤心客

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表