Pytorch | 从零构建ResNet对CIFAR10举行分类

打印 上一主题 下一主题

主题 906|帖子 906|积分 2718

前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10举行分类:
Pytorch | 从零构建AlexNet对CIFAR10举行分类
Pytorch | 从零构建Vgg对CIFAR10举行分类
Pytorch | 从零构建GoogleNet对CIFAR10举行分类
这篇文章我们来构建ResNet.
CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)网络整理的用于图像辨认研究的常用数据集,根本信息如下:


  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的种别,每个种别有6,000张图像。通常将其中50,000张作为练习集,用于模型的练习;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:全部图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时可以或许相对快速地举行练习和推理,但也增加了图像分类的难度。
  • 种别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的种别,这些种别都是现实天下中常见的物体,具有一定的代表性。
下面是一些示例样本:

ResNet

ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,它在ILSVRC 2015图像辨认挑战赛中取得了优异成绩,在图像分类、目标检测、语义分割等盘算机视觉任务中具有广泛应用。以下是对ResNet的详细介绍:
焦点思想



  • 解决梯度消失和退化问题:随着神经网络层数的增加,会出现梯度消失或梯度爆炸问题,导致模型难以练习。同时,还会出现网络退化现象,即增加网络层数后,正确率反而降落。ResNet的焦点思想是引入残差毗连(Residual Connection),通过跨层的shortcut毗连,将输入直接转达到背面的层,使得背面的层可以学习到输入的残差,从而缓解了梯度消失和网络退化问题。
网络结构



  • 根本残差块:ResNet的根本组成单元是残差块(Residual Block)。一个典范的残差块包含两个3×3卷积层,中心有一个ReLU激活函数,而且在第二个卷积层之后也有一个ReLU激活函数。输入通过一个shortcut毗连直接与残差块的输出相加,形成残差学习。
  • 不同层数的架构:ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。其中,数字表现网络中卷积层和全毗连层的总层数。层数越深,模型的表现本领越强,但盘算成本也越高。
创新点



  • 瓶颈结构:在ResNet-50及更深的网络中,接纳了瓶颈结构(Bottleneck)的残差块。这种结构先利用1×1卷积层举行降维,然后利用3×3卷积层举行特征提取,末了再利用1×1卷积层举行升维,这样可以在减少盘算量的同时增加网络的深度和宽度,进步模型的性能。
  • 全局平均池化:在网络的末了一层,ResNet接纳了全局平均池化(Global Average Pooling)取代传统的全毗连层举行分类。全局平均池化可以将每个特征图的空间维度压缩为一个值,得到一个固定长度的特征向量,然后直接输入到分类器中举行分类。
长处



  • 练习深度网络更容易:残差毗连使得梯度可以或许更有用地在网络中传播,大大降低了练习深度网络的难度,使得可以乐成练习上百层甚至上千层的网络。
  • 性能精彩:在各种图像辨认任务中,ResNet都取得了非常精彩的性能,相比之前的网络结构,具有更高的正确率和更好的泛化本领。
  • 模型可扩展性强:可以方便地通过增加残差块的数量来扩展网络的深度,以顺应不同的任务和数据集需求。
应用



  • 图像分类:ResNet在图像分类任务中取得了巨大乐成,如在ImageNet数据集上到达了很高的正确率,成为了图像分类范畴的主流模型之一。
  • 目标检测:与其他目标检测算法结合,如Faster R-CNN、YOLO等,通过提取图像的特征,进步目标检测的正确率和召回率。
  • 语义分割:用于对图像举行像素级的分类,将图像中的每个像素分配到不同的种别中,在城市景观分割、医学图像分割等范畴有广泛应用。
ResNet结构代码详解

结构代码

  1. import torch
  2. import torch.nn as nn
  3. class BasicBlock(nn.Module):
  4.     expansion = 1
  5.     def __init__(self, in_channels, out_channels, stride=1):
  6.         super(BasicBlock, self).__init__()
  7.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  8.         self.bn1 = nn.BatchNorm2d(out_channels)
  9.         self.relu = nn.ReLU(inplace=True)
  10.         self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)
  11.         self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)
  12.         self.shortcut = nn.Sequential()
  13.         
  14.         if stride != 1 or in_channels != out_channels * BasicBlock.expansion:
  15.             self.shortcut = nn.Sequential(
  16.                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
  17.                 nn.BatchNorm2d(out_channels * BasicBlock.expansion)
  18.             )
  19.     def forward(self, x):
  20.         out = self.conv1(x)
  21.         out = self.bn1(out)
  22.         out = self.relu(out)
  23.         out = self.conv2(out)
  24.         out = self.bn2(out)
  25.         out += self.shortcut(x)
  26.         out = self.relu(out)
  27.         return out
  28. class ResNet(nn.Module):
  29.     def __init__(self, block, num_blocks, num_classes):
  30.         super(ResNet, self).__init__()
  31.         self.in_channels = 64
  32.         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
  33.         self.bn1 = nn.BatchNorm2d(64)
  34.         self.relu = nn.ReLU(inplace=True)
  35.         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
  36.         self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)
  37.         self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)
  38.         self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)
  39.         self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)
  40.         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  41.         self.fc = nn.Linear(512 * block.expansion, num_classes)
  42.     def _make_layer(self, block, out_channels, num_blocks, stride=1):
  43.         strides = [stride] + [1] * (num_blocks - 1)
  44.         layers = []
  45.         for stride in strides:
  46.             layers.append(block(self.in_channels, out_channels, stride))
  47.             self.in_channels = out_channels * block.expansion
  48.         return nn.Sequential(*layers)
  49.    
  50.     def forward(self, x):
  51.         out = self.conv1(x)
  52.         out = self.bn1(out)
  53.         out = self.relu(out)
  54.         out = self.maxpool(out)
  55.         out = self.layer1(out)
  56.         out = self.layer2(out)
  57.         out = self.layer3(out)
  58.         out = self.layer4(out)
  59.         out = self.avgpool(out)
  60.         out = out.view(out.size(0), -1)
  61.         out = self.fc(out)
  62.         return out
  63.    
  64. # ResNet18, ResNet34
  65. def ResNet18(num_classes):
  66.     return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
  67. def ResNet34(num_classes):
  68.     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
  69. # ResNet50, ResNet101, ResNet152 需要 BottleNeck
  70. class Bottleneck(nn.Module):
  71.     expansion = 4
  72.     def __init__(self, in_channels, out_channels, stride=1):
  73.         super(Bottleneck, self).__init__()
  74.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
  75.         self.bn1= nn.BatchNorm2d(out_channels)
  76.         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  77.         self.bn2 = nn.BatchNorm2d(out_channels)
  78.         self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
  79.         self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
  80.         self.relu = nn.ReLU(inplace=True)
  81.         self.shortcut = nn.Sequential()
  82.         if stride != 1 or in_channels != out_channels * self.expansion:
  83.             self.shortcut = nn.Sequential(
  84.                 nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
  85.                 nn.BatchNorm2d(out_channels * self.expansion)
  86.             )
  87.     def forward(self, x):
  88.         out = self.conv1(x)
  89.         out = self.bn1(out)
  90.         out = self.relu(out)
  91.         out = self.conv2(out)
  92.         out = self.bn2(out)
  93.         out = self.relu(out)
  94.         out = self.conv3(out)
  95.         out = self.bn3(out)
  96.         out += self.shortcut(x)
  97.         out = self.relu(out)
  98.         return out
  99. def ResNet50(num_classes):
  100.     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
  101. def ResNet101(num_classes):
  102.     return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
  103. def ResNet152(num_classes):
  104.     return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
复制代码
代码详解

以下是对上述提供的PyTorch代码的详细表明,这段代码实现了经典的ResNet(残差网络)系列模型,包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等不同深度的网络架构:
BasicBlock 类

  1. class BasicBlock(nn.Module):
  2.     expansion = 1
  3.     def __init__(self, in_channels, out_channels, stride=1):
  4.         super(BasicBlock, self).__init__()
  5.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  6.         self.bn1 = nn.BatchNorm2d(out_channels)
  7.         self.relu = nn.ReLU(inplace=True)
  8.         self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)
  9.         self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)
  10.         self.shortcut = nn.Sequential()
  11.         
  12.         if stride!= 1 or in_channels!= out_channels * BasicBlock.expansion:
  13.             self.shortcut = nn.Sequential(
  14.                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
  15.                 nn.BatchNorm2d(out_channels * BasicBlock.expansion)
  16.             )
复制代码


  • 类定义与属性

    • 定义了一个名为BasicBlock的类,继承自nn.Module,这是PyTorch中定义神经网络模块的基类。
    • expansion属性被设置为1,用于表现该根本块在通道维度上的扩展倍数,在BasicBlock中通道数不会举行额外的扩展(后续的Bottleneck块会有不同的扩展倍数)。

  • 初始化方法__init__

    • 首先调用父类nn.Module的初始化方法super(BasicBlock, self).__init__(),确保模块正确初始化。
    • 定义了两个卷积层conv1和conv2:

      • conv1:输入通道数为in_channels,输出通道数为out_channels,卷积核巨细为3×3,步长为stride,添补为1,而且不利用偏置(bias=False),这是遵循ResNet论文中的实现方式,通常配合后续的BatchNorm利用。
      • conv2:输入通道数为out_channels,输出通道数为out_channels * BasicBlock.expansion(现实就是out_channels,因为expansion为1),卷积核巨细同样是3×3,添补为1,无偏置。

    • 定义了两个BatchNorm2d层bn1和bn2,分别对应两个卷积层之后,用于对卷积后的特征举行归一化处理,有助于加快练习和进步模型的稳固性。
    • 定义了一个ReLU激活函数relu,而且设置inplace=True,表现直接在原张量上举行激活操作,节流内存空间(但要注意利用不妥大概导致梯度盘算问题,如前面提到的错误情况)。
    • 定义了shortcut,初始化为一个空的nn.Sequential序列。当输入和输出的通道数不划一或者步长不为1时(意味着尺寸或通道数有厘革),会重新构建shortcut,使其包含一个1×1卷积层(用于调解通道数)和一个BatchNorm2d层,以包管shortcut毗连的特征维度能与主分支的输出特征维度相匹配,便于后续举行相加操作。

  1.     def forward(self, x):
  2.         out = self.conv1(x)
  3.         out = self.bn1(out)
  4.         out = self.relu(out)
  5.         out = self.conv2(out)
  6.         out = self.bn2(out)
  7.         out += self.shortcut(x)
  8.         out = self.relu(out)
  9.         return out
复制代码


  • 前向传播方法forward

    • 首先将输入x经过conv1卷积、bn1归一化后,再通过relu激活函数得到中心特征。
    • 接着将该中心特征再经过conv2卷积和bn2归一化。
    • 然后将主分支得到的特征out与shortcut分支(直接毗连输入x经过调解后的特征)举行逐元素相加,实现残差毗连的操作。
    • 末了再经过一次relu激活函数后返回结果,作为该根本块的输出。

ResNet 类

  1. class ResNet(nn.Module):
  2.     def __init__(self, block, num_blocks, num_classes):
  3.         super(ResNet, self).__init__()
  4.         self.in_channels = 64
  5.         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
  6.         self.bn1 = nn.BatchNorm2d(64)
  7.         self.relu = nn.ReLU(inplace=True)
  8.         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
  9.         self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)
  10.         self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)
  11.         self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)
  12.         self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)
  13.         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  14.         self.fc = nn.Linear(512 * block.expansion, num_classes)
复制代码


  • 类定义与属性

    • 定义了ResNet类,同样继承自nn.Module,用于构建完整的ResNet网络架构。
    • 初始化了一个属性in_channels为64,用于记录当前层的输入通道数,后续会动态更新。
    • 定义了网络的起始层,包括一个3×3卷积层conv1(输入通道为3,对应彩色图像的RGB三个通道,输出通道为64),一个BatchNorm2d层bn1用于归一化,一个ReLU激活函数relu,以及一个最大池化层maxpool(其参数设置按照通例的ResNet结构设置)。
    • 分别定义了layer1、layer2、layer3、layer4这四层网络结构,它们通过调用_make_layer方法来构建,每层的输出通道数以及重复的块数量由传入的参数决定,而且随着层数加深,步长会相应改变(从第二层开始步长为2,用于逐步减小特征图尺寸)。
    • 定义了一个自顺应平均池化层avgpool,它能将输入的特征图尺寸自顺应地变为(1, 1)巨细,无论输入特征图的尺寸本来是多少,便于后续全毗连层处理。末了定义了一个全毗连层fc,用于将池化后的特征映射到指定的种别数num_classes上举行分类。

  1.     def _make_layer(self, block, out_channels, num_blocks, stride=1):
  2.         strides = [stride] + [1] * (num_blocks - 1)
  3.         layers = []
  4.         for stride in strides:
  5.             layers.append(block(self.in_channels, out_channels, stride))
  6.             self.in_channels = out_channels * block.expansion
  7.         return nn.Sequential(*layers)
复制代码


  • _make_layer方法

    • 这个方法用于构建ResNet中的每一层网络结构(由多个根本块组成)。
    • 首先根据传入的stride和num_blocks生成一个步长列表strides,例如,如果传入stride=2和num_blocks=3,那么strides会是[2, 1, 1],意味着第一个根本块大概会改变特征图的尺寸和通道数,背面的根本块保持步长为1。
    • 然后循环遍历strides列表,每次创建一个指定的block(可以是BasicBlock或者后续定义的Bottleneck块),并传入当前的输入通道数、输出通道数以及对应的步长,将创建好的块添加到layers列表中。同时,更新self.in_channels为当前块输出的通道数(思量了块的扩展倍数)。
    • 末了将layers列表中的全部块组合成一个nn.Sequential序列并返回,形成一层完整的网络结构。

  1.     def forward(self, x):
  2.         out = self.conv1(x)
  3.         out = self.bn1(out)
  4.         out = self.relu(out)
  5.         out = self.maxpool(out)
  6.         out = self.layer1(out)
  7.         out = self.layer2(out)
  8.         out = self.layer3(out)
  9.         out = self.layer4(out)
  10.         out = self.avgpool(out)
  11.         out = out.view(out.size(0), -1)
  12.         out = self.fc(out)
  13.         return out
复制代码


  • 前向传播方法forward

    • 首先将输入x依次经过网络起始层的卷积、归一化、激活和池化操作,得到开端的特征表现。
    • 然后将该特征依次通过layer1、layer2、layer3、layer4这四层网络结构,不停提取和融合特征,每一层都会进一步加深特征的抽象程度而且改变特征图的尺寸和通道数。
    • 接着经过自顺应平均池化层avgpool,将特征图变为(1, 1)巨细的特征向量。
    • 通过out.view(out.size(0), -1)操作将特征向量展平为一维向量,使其能输入到全毗连层fc中。
    • 末了将全毗连层的输出作为整个网络的最终输出,返回分类结果。

ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数

  1. # ResNet18, ResNet34
  2. def ResNet18(num_classes):
  3.     return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
  4. def ResNet34(num_classes):
  5.     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
复制代码


  • 这两个函数分别用于创建ResNet-18和ResNet-34网络模型。它们通过调用ResNet类的构造函数,传入BasicBlock作为构建块范例,以及对应不同层数的重复块数量列表(如ResNet-18中每层分别重复2个根本块),还有指定的种别数num_classes,最终返回构建好的相应深度的ResNet模型实例,用于图像分类等任务。
  1. # ResNet50, ResNet101, ResNet152 需要 BottleNeck
  2. class Bottleneck(nn.Module):
  3.     expansion = 4
  4.     def __init__(self, in_channels, out_channels, stride=1):
  5.         super(Bottleneck, self).__init__()
  6.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
  7.         self.bn1= nn.BatchNorm2d(out_channels)
  8.         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  9.         self.bn2 = nn.BatchNorm2d(out_channels)
  10.         self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
  11.         self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
  12.         self.relu = nn.ReLU(inplace=True)
  13.         self.shortcut = nn.Sequential()
  14.         if stride!= 1 or in_channels!= out_channels * self.expansion:
  15.             self.shortcut = nn.Sequential(
  16.                 nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
  17.                 nn.BatchNorm2d(out_channels * self.expansion)
  18.             )
复制代码


  • Bottleneck类定义与初始化

    • 定义了Bottleneck类,同样继承自nn.Module,用于构建更深层的ResNet网络(如ResNet-50及以上)中的根本块。
    • expansion属性被设置为4,意味着该块在经过一系列操作后,输出通道数会是输入通道数的4倍,通过这种方式在增加网络深度的同时控制盘算量。
    • 在初始化方法中,定义了三个卷积层conv1、conv2、conv3,分别是1×1卷积用于降维、3×3卷积举行主要的特征提取、1×1卷积用于升维,而且每个卷积层后都有对应的BatchNorm2d层举行归一化,还有ReLU激活函数用于激活中心特征。
    • 同样定义了shortcut,其构建逻辑和BasicBlock中类似,根据输入输出通道数和步长情况来决定是否需要构建包含1×1卷积和BatchNorm2d层的调解结构,以包管残差毗连的维度匹配。

  1.     def forward(self, x):
  2.         out = self.conv1(x)
  3.         out = self.bn1(out)
  4.         out = self.relu(out)
  5.         out = self.conv2(out)
  6.         out = self.bn2(out)
  7.         out = self.relu(out)
  8.         out = self.conv3(out)
  9.         out = self.bn3(out)
  10.         out += self.shortcut(x)
  11.         out = self.relu(out)
  12.         return out
复制代码


  • Bottleneck块的前向传播方法

    • 前向传播过程与BasicBlock类似,只是中心经过了三个卷积层及对应的归一化和激活操作,末了同样是将主分支特征与shortcut分支特征相加后再经过ReLU激活函数输出,实现残差学习。

  1. def ResNet50(num_classes):
  2.     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
  3. def ResNet101(num_classes):
  4.     return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
  5. def ResNet152(num_classes):
  6.     return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
复制代码


  • 这几个函数分别用于创建ResNet-50、ResNet-101和ResNet-152网络模型,它们与创建ResNet-18、ResNet-34的函数类似,只是传入的构建块范例变为Bottleneck,以及对应不同层数的重复Bottleneck块数量列表,还有指定的种别数num_classes,最终返回相应深度的ResNet模型实例,用于更复杂的图像分类等任务,这些更深层的网络结构在处理大规模图像数据集时往往能取得更好的性能表现。
练习过程和测试结果

练习过程损失函数厘革曲线:

练习过程正确率厘革曲线:

测试结果:

代码汇总

项目github所在
项目结构:
  1. |--data
  2. |--models
  3.         |--__init__.py
  4.         |-resnet.py
  5.         |--...
  6. |--results
  7. |--weights
  8. |--train.py
  9. |--test.py
复制代码
resnet.py

  1. import torch
  2. import torch.nn as nn
  3. class BasicBlock(nn.Module):
  4.     expansion = 1
  5.     def __init__(self, in_channels, out_channels, stride=1):
  6.         super(BasicBlock, self).__init__()
  7.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  8.         self.bn1 = nn.BatchNorm2d(out_channels)
  9.         self.relu = nn.ReLU(inplace=True)
  10.         self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)
  11.         self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)
  12.         self.shortcut = nn.Sequential()
  13.         
  14.         if stride != 1 or in_channels != out_channels * BasicBlock.expansion:
  15.             self.shortcut = nn.Sequential(
  16.                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
  17.                 nn.BatchNorm2d(out_channels * BasicBlock.expansion)
  18.             )
  19.     def forward(self, x):
  20.         out = self.conv1(x)
  21.         out = self.bn1(out)
  22.         out = self.relu(out)
  23.         out = self.conv2(out)
  24.         out = self.bn2(out)
  25.         out += self.shortcut(x)
  26.         out = self.relu(out)
  27.         return out
  28. class ResNet(nn.Module):
  29.     def __init__(self, block, num_blocks, num_classes):
  30.         super(ResNet, self).__init__()
  31.         self.in_channels = 64
  32.         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
  33.         self.bn1 = nn.BatchNorm2d(64)
  34.         self.relu = nn.ReLU(inplace=True)
  35.         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
  36.         self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)
  37.         self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)
  38.         self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)
  39.         self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)
  40.         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  41.         self.fc = nn.Linear(512 * block.expansion, num_classes)
  42.     def _make_layer(self, block, out_channels, num_blocks, stride=1):
  43.         strides = [stride] + [1] * (num_blocks - 1)
  44.         layers = []
  45.         for stride in strides:
  46.             layers.append(block(self.in_channels, out_channels, stride))
  47.             self.in_channels = out_channels * block.expansion
  48.         return nn.Sequential(*layers)
  49.    
  50.     def forward(self, x):
  51.         out = self.conv1(x)
  52.         out = self.bn1(out)
  53.         out = self.relu(out)
  54.         out = self.maxpool(out)
  55.         out = self.layer1(out)
  56.         out = self.layer2(out)
  57.         out = self.layer3(out)
  58.         out = self.layer4(out)
  59.         out = self.avgpool(out)
  60.         out = out.view(out.size(0), -1)
  61.         out = self.fc(out)
  62.         return out
  63.    
  64. # ResNet18, ResNet34
  65. def ResNet18(num_classes):
  66.     return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
  67. def ResNet34(num_classes):
  68.     return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
  69. # ResNet50, ResNet101, ResNet152 需要 BottleNeck
  70. class Bottleneck(nn.Module):
  71.     expansion = 4
  72.     def __init__(self, in_channels, out_channels, stride=1):
  73.         super(Bottleneck, self).__init__()
  74.         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
  75.         self.bn1= nn.BatchNorm2d(out_channels)
  76.         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
  77.         self.bn2 = nn.BatchNorm2d(out_channels)
  78.         self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
  79.         self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
  80.         self.relu = nn.ReLU(inplace=True)
  81.         self.shortcut = nn.Sequential()
  82.         if stride != 1 or in_channels != out_channels * self.expansion:
  83.             self.shortcut = nn.Sequential(
  84.                 nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
  85.                 nn.BatchNorm2d(out_channels * self.expansion)
  86.             )
  87.     def forward(self, x):
  88.         out = self.conv1(x)
  89.         out = self.bn1(out)
  90.         out = self.relu(out)
  91.         out = self.conv2(out)
  92.         out = self.bn2(out)
  93.         out = self.relu(out)
  94.         out = self.conv3(out)
  95.         out = self.bn3(out)
  96.         out += self.shortcut(x)
  97.         out = self.relu(out)
  98.         return out
  99. def ResNet50(num_classes):
  100.     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
  101. def ResNet101(num_classes):
  102.     return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
  103. def ResNet152(num_classes):
  104.     return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
复制代码
train.py

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. import torchvision.transforms as transforms
  6. from models import *
  7. import matplotlib.pyplot as plt
  8. import ssl
  9. ssl._create_default_https_context = ssl._create_unverified_context
  10. # 定义数据预处理操作
  11. transform = transforms.Compose(
  12.     [transforms.ToTensor(),
  13.      transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])
  14. # 加载CIFAR10训练集
  15. trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
  16.                                         download=False, transform=transform)
  17. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
  18.                                           shuffle=True, num_workers=2)
  19. # 定义设备(GPU优先,若可用)
  20. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  21. # 实例化模型
  22. model_name = 'ResNet18'
  23. if model_name == 'AlexNet':
  24.     model = AlexNet(num_classes=10).to(device)
  25. elif model_name == 'Vgg_A':
  26.     model = Vgg(cfg_vgg='A', num_classes=10).to(device)
  27. elif model_name == 'Vgg_A-LRN':
  28.     model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
  29. elif model_name == 'Vgg_B':
  30.     model = Vgg(cfg_vgg='B', num_classes=10).to(device)
  31. elif model_name == 'Vgg_C':
  32.     model = Vgg(cfg_vgg='C', num_classes=10).to(device)
  33. elif model_name == 'Vgg_D':
  34.     model = Vgg(cfg_vgg='D', num_classes=10).to(device)
  35. elif model_name == 'Vgg_E':
  36.     model = Vgg(cfg_vgg='E', num_classes=10).to(device)
  37. elif model_name == 'GoogleNet':
  38.     model = GoogleNet(num_classes=10).to(device)
  39. elif model_name == 'ResNet18':
  40.     model = ResNet18(num_classes=10).to(device)
  41. elif model_name == 'ResNet34':
  42.     model = ResNet34(num_classes=10).to(device)
  43. elif model_name == 'ResNet50':
  44.     model = ResNet50(num_classes=10).to(device)
  45. elif model_name == 'ResNet101':
  46.     model = ResNet101(num_classes=10).to(device)
  47. elif model_name == 'ResNet152':
  48.     model = ResNet152(num_classes=10).to(device)
  49. criterion = nn.CrossEntropyLoss()
  50. optimizer = optim.Adam(model.parameters(), lr=0.001)
  51. # 训练轮次
  52. epochs = 15
  53. def train(model, trainloader, criterion, optimizer, device):
  54.     model.train()
  55.     running_loss = 0.0
  56.     correct = 0
  57.     total = 0
  58.     for i, data in enumerate(trainloader, 0):
  59.         inputs, labels = data[0].to(device), data[1].to(device)
  60.         optimizer.zero_grad()
  61.         outputs = model(inputs)
  62.         loss = criterion(outputs, labels)
  63.         loss.backward()
  64.         optimizer.step()
  65.         running_loss += loss.item()
  66.         _, predicted = outputs.max(1)
  67.         total += labels.size(0)
  68.         correct += predicted.eq(labels).sum().item()
  69.     epoch_loss = running_loss / len(trainloader)
  70.     epoch_acc = 100. * correct / total
  71.     return epoch_loss, epoch_acc
  72. if __name__ == "__main__":
  73.     loss_history, acc_history = [], []
  74.     for epoch in range(epochs):
  75.         train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
  76.         print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
  77.         loss_history.append(train_loss)
  78.         acc_history.append(train_acc)
  79.         # 保存模型权重,每5轮次保存到weights文件夹下
  80.         if (epoch + 1) % 5 == 0:
  81.             torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')
  82.    
  83.     # 绘制损失曲线
  84.     plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
  85.     plt.xlabel('Epoch')
  86.     plt.ylabel('Loss')
  87.     plt.title('Training Loss Curve')
  88.     plt.legend()
  89.     plt.savefig(f'results\\{model_name}_train_loss_curve.png')
  90.     plt.close()
  91.     # 绘制准确率曲线
  92.     plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
  93.     plt.xlabel('Epoch')
  94.     plt.ylabel('Accuracy (%)')
  95.     plt.title('Training Accuracy Curve')
  96.     plt.legend()
  97.     plt.savefig(f'results\\{model_name}_train_acc_curve.png')
  98.     plt.close()
复制代码
test.py

  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from models import *
  6. import ssl
  7. ssl._create_default_https_context = ssl._create_unverified_context
  8. # 定义数据预处理操作
  9. transform = transforms.Compose(
  10.     [transforms.ToTensor(),
  11.      transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])
  12. # 加载CIFAR10测试集
  13. testset = torchvision.datasets.CIFAR10(root='./data', train=False,
  14.                                        download=False, transform=transform)
  15. testloader = torch.utils.data.DataLoader(testset, batch_size=128,
  16.                                          shuffle=False, num_workers=2)
  17. # 定义设备(GPU优先,若可用)
  18. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  19. # 实例化模型
  20. model_name = 'ResNet18'
  21. if model_name == 'AlexNet':
  22.     model = AlexNet(num_classes=10).to(device)
  23. elif model_name == 'Vgg_A':
  24.     model = Vgg(cfg_vgg='A', num_classes=10).to(device)
  25. elif model_name == 'Vgg_A-LRN':
  26.     model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
  27. elif model_name == 'Vgg_B':
  28.     model = Vgg(cfg_vgg='B', num_classes=10).to(device)
  29. elif model_name == 'Vgg_C':
  30.     model = Vgg(cfg_vgg='C', num_classes=10).to(device)
  31. elif model_name == 'Vgg_D':
  32.     model = Vgg(cfg_vgg='D', num_classes=10).to(device)
  33. elif model_name == 'Vgg_E':
  34.     model = Vgg(cfg_vgg='E', num_classes=10).to(device)
  35. elif model_name == 'GoogleNet':
  36.     model = GoogleNet(num_classes=10).to(device)
  37. elif model_name == 'ResNet18':
  38.     model = ResNet18(num_classes=10).to(device)
  39. elif model_name == 'ResNet34':
  40.     model = ResNet34(num_classes=10).to(device)
  41. elif model_name == 'ResNet50':
  42.     model = ResNet50(num_classes=10).to(device)
  43. elif model_name == 'ResNet101':
  44.     model = ResNet101(num_classes=10).to(device)
  45. elif model_name == 'ResNet152':
  46.     model = ResNet152(num_classes=10).to(device)
  47. criterion = nn.CrossEntropyLoss()
  48. # 加载模型权重
  49. weights_path = f"weights/{model_name}_epoch_15.pth"  
  50. model.load_state_dict(torch.load(weights_path, map_location=device))
  51. def test(model, testloader, criterion, device):
  52.     model.eval()
  53.     running_loss = 0.0
  54.     correct = 0
  55.     total = 0
  56.     with torch.no_grad():
  57.         for data in testloader:
  58.             inputs, labels = data[0].to(device), data[1].to(device)
  59.             outputs = model(inputs)
  60.             loss = criterion(outputs, labels)
  61.             running_loss += loss.item()
  62.             _, predicted = outputs.max(1)
  63.             total += labels.size(0)
  64.             correct += predicted.eq(labels).sum().item()
  65.     epoch_loss = running_loss / len(testloader)
  66.     epoch_acc = 100. * correct / total
  67.     return epoch_loss, epoch_acc
  68. if __name__ == "__main__":
  69.     test_loss, test_acc = test(model, testloader, criterion, device)
  70.     print(f"================{model_name} Test================")
  71.     print(f"Load Model Weights From: {weights_path}")
  72.     print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

卖不甜枣

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表