首先要明确一点,我们在编写模型、训练和使用模型的时候通常都是分开的,所以应该把Module的编写以及train方法和test方法分开编写。
调用GPU进行训练:在网络模型、数据、损失函数对象后面都调用.cuda()方法,如loss_fn = loss_fn.cuda()
【代码示例】完成完整CIFAR10模型的训练
按照官网给出的模型结构进行构建:
# model.py
from torch import nn


class myModule(nn.Module):
    """Small CNN that maps 3-channel 32x32 images to 10 class scores.

    Each 5x5 convolution uses stride 1 and padding 2, so it keeps the
    spatial size; each of the three 2x2 max-pools halves it
    (32 -> 16 -> 8 -> 4).  The flattened feature vector therefore has
    64 * 4 * 4 elements, which is what the first linear layer expects.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        )

    def forward(self, ingput):
        """Return the raw class scores (logits) for a batch of images.

        Args:
            ingput: batch tensor of shape (N, 3, 32, 32).  The parameter
                name is kept as originally spelled so keyword callers
                keep working.

        Returns:
            Tensor of shape (N, 10).
        """
        output = self.model(ingput)
        return output
导入自己创建的模型,实例化一个模型对象之后,导入CIFAR10数据集进行训练:
# train.py
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# The network is defined in model.py; import it by name instead of `import *`.
from model import myModule

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download CIFAR10 (if not already present) and convert images to tensors.
train_data = torchvision.datasets.CIFAR10(
    root="dataset/CIFAR10",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root="dataset/CIFAR10",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Dataset sizes (the test size is also the denominator of the accuracy).
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# Batch the data.  The training set is reshuffled every epoch so SGD does
# not see the samples in the same fixed order each time.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64)

# Model and loss (cross entropy), both moved to the selected device.
mymodule = myModule().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)

# Optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(mymodule.parameters(), lr=learning_rate)

# Bookkeeping
total_train_step = 0  # number of optimizer steps taken so far
total_test_step = 0   # number of evaluation passes done so far
epoch = 10            # number of passes over the training set

# TensorBoard log directory
writer = SummaryWriter("logs_train")

for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # ---- training ----
    mymodule.train()
    for imgs, targets in train_dataloader:
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = mymodule(imgs)
        loss = loss_fn(outputs, targets)

        # Clear old gradients, back-propagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1  # one step per batch
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # ---- evaluation ----
    mymodule.eval()
    total_test_loss = 0  # sum of the per-batch losses
    total_accuracy = 0   # number of correctly classified test images
    with torch.no_grad():
        for imgs, targets in test_dataloader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = mymodule(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # .item() turns the count into a Python int, so the running
            # total (and the printed accuracy) is a plain number rather
            # than a tensor living on the device.
            total_accuracy += (outputs.argmax(1) == targets).sum().item()

    test_accuracy = total_accuracy / test_data_size
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(test_accuracy))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", test_accuracy, total_test_step)
    total_test_step += 1

    # Save a checkpoint after every epoch.  This pickles the whole module
    # object (not just its state_dict), so loading it later requires
    # model.py to be importable.
    torch.save(mymodule, "mymodule{}.pth".format(i))
    print("模型已保存")

writer.close()
- # test.py
import torch
import torchvision
from PIL import Image
from torch import nn

image_path = "imgs/airplane.png"
image = Image.open(image_path)
print(image)
# Force 3-channel RGB: the network's first convolution takes 3 input
# channels, while an image file may have a different mode (e.g. RGBA).
image = image.convert('RGB')

# Resize to the 32x32 input size used for training and convert to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)

# The checkpoint stores the whole pickled module, which refers to its class
# by the module it was defined in.  The unpickler imports that module itself,
# so the file defining the network must be importable from here; re-declaring
# a look-alike class in this script has no effect and is therefore omitted.
#
# weights_only=False is required for whole-module checkpoints on PyTorch
# releases where torch.load defaults to weights_only=True.
# SECURITY: this unpickles arbitrary objects -- only load files you trust.
model = torch.load(
    "mymodule9.pth",
    map_location=torch.device('cpu'),
    weights_only=False,
)
print(model)

# Add the batch dimension expected by the network: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))

model.eval()
with torch.no_grad():
    output = model(image)
print(output)
# Index of the highest score = predicted class.
print(output.argmax(1))
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。