马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
1.自编码器模子填空:
- import torch
- import matplotlib.pyplot as plt
- from torchvision import datasets, transforms
- from torch import nn, optim
- from torch.nn import functional as F
- from tqdm import tqdm
- import os
-
- # os.chdir(os.path.dirname(__file__))
-
- '模型结构'
-
- #损失函数
- #交叉熵,衡量各个像素原始数据与重构数据的误差
- #均方误差可作为交叉熵替代使用.衡量各个像素原始数据与重构数据的误差
-
- '超参数及构造模型'
- #模型参数
- #压缩后的特征维度
- #encoder和decoder中间层的维度
- #原始图片和生成图片的维度
-
- #训练参数
- #训练时期
- #每步训练样本数
- #学习率
- device =torch.device('cuda' if torch.cuda.is_available() else 'cpu')#训练设备
-
- #确定模型,导入已训练模型(如有)
- modelname = 'ae.pth'
- #模型初始化
- #优化器
- try:
- model.load_state_dict(torch.load(modelname))
- print('[INFO] Load Model complete')
- except:
- pass
-
- '训练模型'
- #准备mnist数据集 (数据会下载到py文件所在的data文件夹下)
- train_loader = torch.utils.data.DataLoader(
- datasets.MNIST('./', train=True, download=True,
- transform=transforms.ToTensor()),
- batch_size=batch_size, shuffle=True)
- test_loader = torch.utils.data.DataLoader(
- datasets.MNIST('./', train=False, transform=transforms.ToTensor()),
- batch_size=batch_size, shuffle=False)
- #此方法获取的数据各像素值范围0-1
-
- #训练及测试
- loss_history = {'train':[],'eval':[]}
- for epoch in range(epochs):
- #训练
- #每个epoch重置损失,设置进度条
- train_loss = 0
- train_nsample = 0
- t = tqdm(train_loader,desc = f'[train]epoch:{epoch}')
- for imgs, lbls in t: #imgs:(bs,28,28)
- #获取数据
- #imgs:(bs,28*28)
- #模型运算
-
- #计算损失
- # 重构与原始数据的差距(也可使用loss_MSE)
- #反向传播、参数优化,重置
-
-
-
- #计算平均损失,设置进度条
-
-
- t.set_postfix({'loss':train_loss/train_nsample})
- #每个epoch记录总损失
- loss_history['train'].append(train_loss/train_nsample)
-
- #测试
- #每个epoch重置损失,设置进度条
- test_loss = 0
- test_nsample = 0
- e = tqdm(test_loader,desc = f'[eval]epoch:{epoch}')
- for imgs, label in e:
- #获取数据
-
- #模型运算
-
- #计算损失
-
- #计算平均损失,设置进度条
-
-
- e.set_postfix({'loss':test_loss/test_nsample})
- #每个epoch记录总损失
- loss_history['eval'].append(test_loss/test_nsample)
-
-
- #展示效果
- #将测试步骤中的数据、重构数据绘图
- concat = torch.cat((imgs[0].view(28, 28),
- re_imgs[0].view( 28, 28)), 1)
- plt.matshow(concat.cpu().detach().numpy())
- plt.show()
-
- #显示每个epoch的loss变化
- plt.plot(range(epoch+1),loss_history['train'])
- plt.plot(range(epoch+1),loss_history['eval'])
- plt.show()
- #存储模型
- torch.save(model.state_dict(),modelname)
-
- '调用模型'
- #对数据集
- dataset = datasets.MNIST('./', train=False, transform=transforms.ToTensor())
- #取一组手写数据(正常数据)
- raw = dataset[0][0].view(1,-1) #raw: bs,28,28->bs,28*28
- #对手写数据(正常数据)重构
- re_raw = model(raw.to(device))
- #取一组随机数据(异常数据)
- rand = torch.randn_like(raw)
- #对随机数据(异常数据)重构
- re_rand = model(rand.to(device))
-
- #定义一个衡量标准,按像素平均所有原始数据和重构数据的误差
- f = lambda x,y: abs(x-y).mean()
- #正常数据 原始数据与重构数据差异
- print('正常数据:',f(re_raw.to("cpu"),raw))
- #异常数据 原始数据与重构数据差异
- print('异常数据:',f(re_rand.to("cpu"),rand))
-
- #正常数据,原始数据与重构数据作图
- plt.matshow(raw.view(28,28).detach().cpu().numpy())
- plt.show()
- plt.matshow(re_raw.view(28,28).detach().cpu().numpy())
- plt.show()
- #异常数据,原始数据与重构数据作图
- plt.matshow(rand.view(28,28).detach().cpu().numpy())
- plt.show()
- plt.matshow(re_rand.view(28,28).detach().cpu().numpy())
- plt.show()
复制代码 参考:
手写系列——AE网络、VAE网络和Condition VAE网络-CSDN博客
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |