深度学习pytorch实战六:ResNet50网络图像分类篇——自建花数据集图像分类(5类) ...

打印 上一主题 下一主题

主题 785|帖子 785|积分 2355

1.数据集简介、训练集与测试集划分
2.模型相关知识
3.model.py——定义ResNet50网络模型
4.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数
5.predict.py——加载训练好的网络参数后,用自己找的图像进行分类测试
一、数据集简介

1.自建数据文件夹
首先确定本次分类的种类,通过爬虫、官网数据集和自己拍摄的照片获取5类图像。新建一个文件夹data,里面包含5个子文件夹,文件夹名取种类的英文名,每个文件夹中的照片数量最好大致相同,在五百张以上。例如我选了雏菊、蒲公英、玫瑰、向日葵、郁金香5类,每种类型有600~900张图像,如下图。
花数据集下载链接https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

2.划分训练集与测试集(验证集)
以下是划分数据集的代码,放在数据集所在的同一目录下运行,复制后修改文件夹路径即可。
  1. import os
  2. from shutil import copy
  3. import random
  4. def mkfile(file):
  5.     if not os.path.exists(file):
  6.         os.makedirs(file)
  7. # 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
  8. file_path = 'data/flower_photos'
  9. flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]
  10. # 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
  11. mkfile('flower_data/train')
  12. for cla in flower_class:
  13.     mkfile('flower_data/train/' + cla)
  14. # 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
  15. mkfile('flower_data/val')
  16. for cla in flower_class:
  17.     mkfile('flower_data/val/' + cla)
  18. # 划分比例,训练集 : 验证集 = 9 : 1
  19. split_rate = 0.1
  20. # 遍历3种花的全部图像并按比例分成训练集和验证集
  21. for cla in flower_class:
  22.     cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
  23.     images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
  24.     num = len(images)
  25.     eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
  26.     for index, image in enumerate(images):
  27.         # eval_index 中保存验证集val的图像名称
  28.         if image in eval_index:
  29.             image_path = cla_path + image
  30.             new_path = 'flower_data/val/' + cla
  31.             copy(image_path, new_path)  # 将选中的图像复制到新路径
  32.         # 其余的图像保存在训练集train中
  33.         else:
  34.             image_path = cla_path + image
  35.             new_path = 'flower_data/train/' + cla
  36.             copy(image_path, new_path)
  37.         print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
  38.     print()
  39. print("processing done!")
复制代码
二、模型相关知识

之前有文章介绍过该模型,如果不清楚可以点击下面的链接跳转学习。
深度学习卷积神经网络CNN之ResNet模型网络详解说明(超详细理论篇)

三、model.py——定义ResNet50网络模型

  1. import torch.nn as nn
  2. import torch
  3. class BasicBlock(nn.Module):
  4.     expansion = 1
  5.     def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
  6.         super(BasicBlock, self).__init__()
  7.         self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
  8.                                kernel_size=3, stride=stride, padding=1, bias=False)
  9.         self.bn1 = nn.BatchNorm2d(out_channel)
  10.         self.relu = nn.ReLU()
  11.         self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
  12.                                kernel_size=3, stride=1, padding=1, bias=False)
  13.         self.bn2 = nn.BatchNorm2d(out_channel)
  14.         self.downsample = downsample
  15.     def forward(self, x):
  16.         identity = x
  17.         if self.downsample is not None:
  18.             identity = self.downsample(x)
  19.         out = self.conv1(x)
  20.         out = self.bn1(out)
  21.         out = self.relu(out)
  22.         out = self.conv2(out)
  23.         out = self.bn2(out)
  24.         out += identity
  25.         out = self.relu(out)
  26.         return out
  27. class Bottleneck(nn.Module):
  28.     """
  29.     注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
  30.     但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
  31.     这么做的好处是能够在top1上提升大概0.5%的准确率。
  32.     可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
  33.     """
  34.     expansion = 4
  35.     def __init__(self, in_channel, out_channel, stride=1, downsample=None,
  36.                  groups=1, width_per_group=64):
  37.         super(Bottleneck, self).__init__()
  38.         width = int(out_channel * (width_per_group / 64.)) * groups
  39.         self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
  40.                                kernel_size=1, stride=1, bias=False)  # squeeze channels
  41.         self.bn1 = nn.BatchNorm2d(width)
  42.         # -----------------------------------------
  43.         self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
  44.                                kernel_size=3, stride=stride, bias=False, padding=1)
  45.         self.bn2 = nn.BatchNorm2d(width)
  46.         # -----------------------------------------
  47.         self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
  48.                                kernel_size=1, stride=1, bias=False)  # unsqueeze channels
  49.         self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
  50.         self.relu = nn.ReLU(inplace=True)
  51.         self.downsample = downsample
  52.     def forward(self, x):
  53.         identity = x
  54.         if self.downsample is not None:
  55.             identity = self.downsample(x)
  56.         out = self.conv1(x)
  57.         out = self.bn1(out)
  58.         out = self.relu(out)
  59.         out = self.conv2(out)
  60.         out = self.bn2(out)
  61.         out = self.relu(out)
  62.         out = self.conv3(out)
  63.         out = self.bn3(out)
  64.         out += identity
  65.         out = self.relu(out)
  66.         return out
  67. class ResNet(nn.Module):
  68.     def __init__(self,
  69.                  block,
  70.                  blocks_num,
  71.                  num_classes=1000,
  72.                  include_top=True,
  73.                  groups=1,
  74.                  width_per_group=64):
  75.         super(ResNet, self).__init__()
  76.         self.include_top = include_top
  77.         self.in_channel = 64
  78.         self.groups = groups
  79.         self.width_per_group = width_per_group
  80.         self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
  81.                                padding=3, bias=False)
  82.         self.bn1 = nn.BatchNorm2d(self.in_channel)
  83.         self.relu = nn.ReLU(inplace=True)
  84.         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  85.         self.layer1 = self._make_layer(block, 64, blocks_num[0])
  86.         self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
  87.         self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
  88.         self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
  89.         if self.include_top:
  90.             self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
  91.             self.fc = nn.Linear(512 * block.expansion, num_classes)
  92.         for m in self.modules():
  93.             if isinstance(m, nn.Conv2d):
  94.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  95.     def _make_layer(self, block, channel, block_num, stride=1):
  96.         downsample = None
  97.         if stride != 1 or self.in_channel != channel * block.expansion:
  98.             downsample = nn.Sequential(
  99.                 nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
  100.                 nn.BatchNorm2d(channel * block.expansion))
  101.         layers = []
  102.         layers.append(block(self.in_channel,
  103.                             channel,
  104.                             downsample=downsample,
  105.                             stride=stride,
  106.                             groups=self.groups,
  107.                             width_per_group=self.width_per_group))
  108.         self.in_channel = channel * block.expansion
  109.         for _ in range(1, block_num):
  110.             layers.append(block(self.in_channel,
  111.                                 channel,
  112.                                 groups=self.groups,
  113.                                 width_per_group=self.width_per_group))
  114.         return nn.Sequential(*layers)
  115.     def forward(self, x):
  116.         x = self.conv1(x)
  117.         x = self.bn1(x)
  118.         x = self.relu(x)
  119.         x = self.maxpool(x)
  120.         x = self.layer1(x)
  121.         x = self.layer2(x)
  122.         x = self.layer3(x)
  123.         x = self.layer4(x)
  124.         if self.include_top:
  125.             x = self.avgpool(x)
  126.             x = torch.flatten(x, 1)
  127.             x = self.fc(x)
  128.         return x
  129. def resnet34(num_classes=1000, include_top=True):
  130.     # https://download.pytorch.org/models/resnet34-333f7ec4.pth
  131.     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
  132. def resnet50(num_classes=1000, include_top=True):
  133.     # https://download.pytorch.org/models/resnet50-19c8e357.pth
  134.     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
  135. def resnet101(num_classes=1000, include_top=True):
  136.     # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
  137.     return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
  138. def resnext50_32x4d(num_classes=1000, include_top=True):
  139.     # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
  140.     groups = 32
  141.     width_per_group = 4
  142.     return ResNet(Bottleneck, [3, 4, 6, 3],
  143.                   num_classes=num_classes,
  144.                   include_top=include_top,
  145.                   groups=groups,
  146.                   width_per_group=width_per_group)
  147. def resnext101_32x8d(num_classes=1000, include_top=True):
  148.     # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
  149.     groups = 32
  150.     width_per_group = 8
  151.     return ResNet(Bottleneck, [3, 4, 23, 3],
  152.                   num_classes=num_classes,
  153.                   include_top=include_top,
  154.                   groups=groups,
  155.                   width_per_group=width_per_group)
复制代码
四、train.py——加载数据集并训练ResNet50网络模型

batch_size = 16
epochs = 5
  1. import os
  2. import sys
  3. import json
  4. import torch
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. from torchvision import transforms, datasets
  8. from tqdm import tqdm
  9. from model import resnet50
  10. def main():
  11.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  12.     print("using {} device.".format(device))
  13.     data_transform = {
  14.         "train": transforms.Compose([transforms.RandomResizedCrop(224),
  15.                                      transforms.RandomHorizontalFlip(),
  16.                                      transforms.ToTensor(),
  17.                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
  18.         "val": transforms.Compose([transforms.Resize(256),
  19.                                    transforms.CenterCrop(224),
  20.                                    transforms.ToTensor(),
  21.                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
  22.     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
  23.     image_path = os.path.join(data_root, "zjdata", "flower_data")  # flower data set path
  24.     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
  25.     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
  26.                                          transform=data_transform["train"])
  27.     train_num = len(train_dataset)
  28.     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  29.     flower_list = train_dataset.class_to_idx
  30.     cla_dict = dict((val, key) for key, val in flower_list.items())
  31.     # write dict into json file
  32.     json_str = json.dumps(cla_dict, indent=4)
  33.     with open('class_indices.json', 'w') as json_file:
  34.         json_file.write(json_str)
  35.     batch_size = 16
  36.     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
  37.     print('Using {} dataloader workers every process'.format(nw))
  38.     train_loader = torch.utils.data.DataLoader(train_dataset,
  39.                                                batch_size=batch_size, shuffle=True,
  40.                                                num_workers=0)
  41.     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
  42.                                             transform=data_transform["val"])
  43.     val_num = len(validate_dataset)
  44.     validate_loader = torch.utils.data.DataLoader(validate_dataset,
  45.                                                   batch_size=batch_size, shuffle=False,
  46.                                                   num_workers=nw)
  47.     print("using {} images for training, {} images for validation.".format(train_num,
  48.                                                                            val_num))
  49.    
  50.     net = resnet50(num_classes=5, include_top=True)
  51.     net.to(device)
  52.     # define loss function
  53.     loss_function = nn.CrossEntropyLoss()
  54.     # construct an optimizer
  55.     params = [p for p in net.parameters() if p.requires_grad]
  56.     optimizer = optim.Adam(params, lr=0.1)
  57.     epochs = 5
  58.     best_acc = 0.0
  59.     save_path = './resNet50.pth'
  60.     train_steps = len(train_loader)
  61.     for epoch in range(epochs):
  62.         # train
  63.         net.train()
  64.         running_loss = 0.0
  65.         train_bar = tqdm(train_loader, file=sys.stdout)
  66.         for step, data in enumerate(train_bar):
  67.             images, labels = data
  68.             optimizer.zero_grad()
  69.             logits = net(images.to(device))
  70.             loss = loss_function(logits, labels.to(device))
  71.             loss.backward()
  72.             optimizer.step()
  73.             # print statistics
  74.             running_loss += loss.item()
  75.             train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
  76.                                                                      epochs,
  77.                                                                      loss)
  78.         # validate
  79.         net.eval()
  80.         acc = 0.0  # accumulate accurate number / epoch
  81.         with torch.no_grad():
  82.             val_bar = tqdm(validate_loader, file=sys.stdout)
  83.             for val_data in val_bar:
  84.                 val_images, val_labels = val_data
  85.                 outputs = net(val_images.to(device))
  86.                 # loss = loss_function(outputs, test_labels)
  87.                 predict_y = torch.max(outputs, dim=1)[1]
  88.                 acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
  89.                 val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
  90.                                                            epochs)
  91.         val_accurate = acc / val_num
  92.         print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
  93.               (epoch + 1, running_loss / train_steps, val_accurate))
  94.         if val_accurate > best_acc:
  95.             best_acc = val_accurate
  96.             torch.save(net.state_dict(), save_path)
  97.     print('Finished Training')
  98. if __name__ == '__main__':
  99.     main()
复制代码
训练中截图

五、predict.py——加载训练好的网络参数后,用自己找的图像进行分类测试

  1. import os
  2. import json
  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. import matplotlib.pyplot as plt
  7. from model import resnet34
  8. def main():
  9.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  10.     data_transform = transforms.Compose(
  11.         [transforms.Resize(256),
  12.          transforms.CenterCrop(224),
  13.          transforms.ToTensor(),
  14.          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
  15.     # load image
  16.     img_path = "./1.jpg"
  17.     assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
  18.     img = Image.open(img_path)
  19.     plt.imshow(img)
  20.     # [N, C, H, W]
  21.     img = data_transform(img)
  22.     # expand batch dimension
  23.     img = torch.unsqueeze(img, dim=0)
  24.     # read class_indict
  25.     json_path = './class_indices.json'
  26.     assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
  27.     with open(json_path, "r") as f:
  28.         class_indict = json.load(f)
  29.     # create model
  30.     model = resnet34(num_classes=5).to(device)
  31.     # load model weights
  32.     weights_path = "./resNet50.pth"
  33.     assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
  34.     model.load_state_dict(torch.load(weights_path, map_location=device))
  35.     # prediction
  36.     model.eval()
  37.     with torch.no_grad():
  38.         # predict class
  39.         output = torch.squeeze(model(img.to(device))).cpu()
  40.         predict = torch.softmax(output, dim=0)
  41.         predict_cla = torch.argmax(predict).numpy()
  42.     print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
  43.                                                  predict[predict_cla].numpy())
  44.     plt.title(print_res)
  45.     for i in range(len(predict)):
  46.         print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
  47.                                                   predict[i].numpy()))
  48.     plt.show()
  49. if __name__ == '__main__':
  50.     main()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

温锦文欧普厨电及净水器总代理

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表