【PyTorch】(基础七)---- 完整训练流程

打印 上一主题 下一主题

主题 1898|帖子 1898|积分 5694

起首要明确一点,我们在编写模型、训练和使用模型的时间通常都是分开的,所以应该把Module的编写以及train方法和test方法分开编写。
调用gpu进行训练:在网络模型,数据,丧失函数对象背面都使用.cuda()方法,如loss_fn = loss_fn.cuda()
【代码示例】完成完整CIFAR10模型的训练
按照官网给出的模型结构进行构建:

  1. # model.py
  2. class myModule(nn.Module):
  3.     def __init__(self):
  4.         super().__init__()
  5.         self.model = nn.Sequential(
  6.             nn.Conv2d(3, 32, 5, 1, 2),
  7.             nn.MaxPool2d(2),
  8.             nn.Conv2d(32, 32, 5, 1, 2),
  9.             nn.MaxPool2d(2),
  10.             nn.Conv2d(32, 64, 5, 1, 2),
  11.             nn.MaxPool2d(2),
  12.             nn.Flatten(),
  13.             nn.Linear(64*4*4, 64),
  14.             nn.Linear(64, 10)
  15.         )
  16.     def forward(self, ingput):
  17.         output = self.model(ingput)
  18.         return output
复制代码
导入自己创建的模型,实例化一个模型对象之后,导入CIFAR10数据集进行训练
  1. # train.py
  2. import torchvision
  3. from torch.utils.tensorboard import SummaryWriter
  4. from module import *
  5. from torch import nn
  6. from torch.utils.data import DataLoader
  7. # 使用Dataset来下载数据集
  8. train_data = torchvision.datasets.CIFAR10(root="dataset/CIFAR10", train=True, transform=torchvision.transforms.ToTensor(),
  9.                                           download=True)
  10. test_data = torchvision.datasets.CIFAR10(root="dataset/CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
  11.                                          download=True)
  12. # 数据集长度
  13. train_data_size = len(train_data)
  14. test_data_size = len(test_data)
  15. print("训练数据集的长度为:{}".format(train_data_size))
  16. print("测试数据集的长度为:{}".format(test_data_size))
  17. # 利用 DataLoader 来加载数据集
  18. train_dataloader = DataLoader(train_data, batch_size=64)
  19. test_dataloader = DataLoader(test_data, batch_size=64)
  20. # 创建网络模型,实例化自定义的模型
  21. mymodule = myModule()
  22. if torch.cuda.is_available():
  23.     mymodule = mymodule.cuda()
  24. # 定义损失函数为交叉熵损失函数
  25. loss_fn = nn.CrossEntropyLoss()
  26. if torch.cuda.is_available():
  27.     loss_fn = loss_fn.cuda()
  28. # 优化器
  29. learning_rate = 0.01
  30. optimizer = torch.optim.SGD(mymodule.parameters(), lr=learning_rate)
  31. # 设置训练网络的一些参数
  32. # 记录训练的次数
  33. total_train_step = 0
  34. # 记录测试的次数
  35. total_test_step = 0
  36. # 训练的轮数
  37. epoch = 10
  38. # tensorboard配置日志目录
  39. writer = SummaryWriter("logs_train")
  40. for i in range(epoch):
  41.     print("-------第 {} 轮训练开始-------".format(i+1))
  42.     # 训练步骤开始
  43.     mymodule.train()
  44.     for data in train_dataloader:
  45.         imgs, targets = data
  46.         if torch.cuda.is_available():
  47.             imgs = imgs.cuda()
  48.             targets = targets.cuda()
  49.         outputs = mymodule(imgs)
  50.         loss = loss_fn(outputs, targets)
  51.         # 优化器优化模型
  52.         optimizer.zero_grad()
  53.         loss.backward()
  54.         optimizer.step()
  55.         total_train_step = total_train_step + 1   # 每读取一次图片+1
  56.         if total_train_step % 100 == 0:
  57.             print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
  58.             writer.add_scalar("train_loss", loss.item(), total_train_step)
  59.     # 测试步骤开始
  60.     mymodule.eval()
  61.     total_test_loss = 0    # 损失函数值
  62.     total_accuracy = 0  # 准确率
  63.     with torch.no_grad():
  64.         for data in test_dataloader:
  65.             imgs, targets = data
  66.             if torch.cuda.is_available():
  67.                 imgs = imgs.cuda()
  68.                 targets = targets.cuda()
  69.             outputs = mymodule(imgs)
  70.             loss = loss_fn(outputs, targets)
  71.             total_test_loss = total_test_loss + loss.item()
  72.             accuracy = (outputs.argmax(1) == targets).sum()
  73.             total_accuracy = total_accuracy + accuracy
  74.     print("整体测试集上的Loss: {}".format(total_test_loss))
  75.     print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
  76.     writer.add_scalar("test_loss", total_test_loss, total_test_step)
  77.     writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
  78.     total_test_step = total_test_step + 1
  79.     # 每轮都保存模型
  80.     torch.save(mymodule, "mymodule{}.pth".format(i))
  81.     print("模型已保存")
  82. writer.close()
复制代码
  1. # test.py
  2. import torch
  3. import torchvision
  4. from PIL import Image
  5. from torch import nn
  6. image_path = "imgs/airplane.png"
  7. image = Image.open(image_path)
  8. print(image)
  9. image = image.convert('RGB')
  10. transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
  11.                                             torchvision.transforms.ToTensor()])
  12. image = transform(image)
  13. print(image.shape)
  14. class Tudui(nn.Module):
  15.     def __init__(self):
  16.         super(Tudui, self).__init__()
  17.         self.model = nn.Sequential(
  18.             nn.Conv2d(3, 32, 5, 1, 2),
  19.             nn.MaxPool2d(2),
  20.             nn.Conv2d(32, 32, 5, 1, 2),
  21.             nn.MaxPool2d(2),
  22.             nn.Conv2d(32, 64, 5, 1, 2),
  23.             nn.MaxPool2d(2),
  24.             nn.Flatten(),
  25.             nn.Linear(64*4*4, 64),
  26.             nn.Linear(64, 10)
  27.         )
  28.     def forward(self, x):
  29.         x = self.model(x)
  30.         return x
  31. model = torch.load("mymodule9.pth", map_location=torch.device('cpu'))
  32. print(model)
  33. image = torch.reshape(image, (1, 3, 32, 32))
  34. model.eval()
  35. with torch.no_grad():
  36.     output = model(image)
  37. print(output)
  38. print(output.argmax(1))
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

杀鸡焉用牛刀

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表