J4学习打卡

[复制链接]
发表于 2025-12-30 05:44:26 | 显示全部楼层 |阅读模式

  • 🍨 本文为🔗365天深度学习训练营 中的学习纪录博客
  • 🍖 原作者:K同砚啊
  DPN(ResNet与DenseNet团结)

  1. import torch
  2. import torch.nn as nn
  3. import torchvision.transforms as transforms
  4. import torchvision
  5. from torchvision import transforms, datasets
  6. import os, PIL, pathlib, warnings
  7. warnings.filterwarnings("ignore")
  8. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  9. print(device)
  10. import os,PIL,random,pathlib
  11. data_dir_str = r'C:\Users\11054\Desktop\kLearning\J1_learning\bird_photos'
  12. data_dir = pathlib.Path(data_dir_str)
  13. print("data_dir:", data_dir, "\n")
  14. data_paths = list(data_dir.glob('*'))
  15. classNames = [str(path).split('/')[-1] for path in data_paths]
  16. print('classNames:', classNames , '\n')
  17. train_transforms = transforms.Compose([
  18.     transforms.Resize([224, 224]),  # resize输入图片
  19.     transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换成tensor
  20.     transforms.Normalize(
  21.         mean=[0.485, 0.456, 0.406],
  22.         std=[0.229, 0.224, 0.225])  # 从数据集中随机抽样计算得到
  23. ])
  24. total_data = datasets.ImageFolder(data_dir_str, transform=train_transforms)
  25. print(total_data)
  26. print(total_data.class_to_idx)
  27. train_size = int(0.8 * len(total_data))
  28. test_size = len(total_data) - train_size
  29. train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
  30. print(train_dataset, test_dataset)
  31. batch_size = 4
  32. train_dl = torch.utils.data.DataLoader(train_dataset,
  33.                                       batch_size=batch_size,
  34.                                       shuffle=True,
  35.                                       num_workers=1,
  36.                                       pin_memory=False)
  37. test_dl = torch.utils.data.DataLoader(test_dataset,
  38.                                       batch_size=batch_size,
  39.                                       shuffle=True,
  40.                                       num_workers=1,
  41.                                       pin_memory=False)
  42. for X, y in test_dl:
  43.     print("Shape of X [N, C, H, W]:", X.shape)
  44.     print("Shape of y:", y.shape, y.dtype)
  45.     break
  46. import torch
  47. import torch.nn as nn
  48. class Block(nn.Module):
  49.     """
  50.     param : in_channel--输入通道数
  51.             mid_channel -- 中间经历的通道数
  52.             out_channel -- ResNet部分使用的通道数(sum操作,这部分输出仍然是out_channel个通道)
  53.             dense_channel -- DenseNet部分使用的通道数(concat操作,这部分输出是2*dense_channel个通道)
  54.             groups -- conv2中的分组卷积参数
  55.             is_shortcut -- ResNet前是否进行shortcut操作
  56.     """
  57.     def __init__(self, in_channel, mid_channel, out_channel, dense_channel, stride, groups, is_shortcut=False):
  58.         super(Block, self).__init__()
  59.         self.is_shortcut = is_shortcut
  60.         self.out_channel = out_channel
  61.         self.conv1 = nn.Sequential(
  62.             nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
  63.             nn.BatchNorm2d(mid_channel),
  64.             nn.ReLU()
  65.         )
  66.         self.conv2 = nn.Sequential(
  67.             nn.Conv2d(mid_channel, mid_channel, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False),
  68.             nn.BatchNorm2d(mid_channel),
  69.             nn.ReLU()
  70.         )
  71.         self.conv3 = nn.Sequential(
  72.             nn.Conv2d(mid_channel, out_channel+dense_channel, kernel_size=1, bias=False),
  73.             nn.BatchNorm2d(out_channel+dense_channel)
  74.         )
  75.         if self.is_shortcut:
  76.             self.shortcut = nn.Sequential(
  77.             nn.Conv2d(in_channel, out_channel+dense_channel, kernel_size=3, padding=1, stride=stride, bias=False),
  78.             nn.BatchNorm2d(out_channel+dense_channel)
  79.         )
  80.         self.relu = nn.ReLU(inplace=True)
  81.     def forward(self, x):
  82.         a = x
  83.         x = self.conv1(x)
  84.         x = self.conv2(x)
  85.         x = self.conv3(x)
  86.         if self.is_shortcut:
  87.             a = self.shortcut(a)
  88.         # a[:, :self.out_channel, :, :]+x[:, :self.out_channel, :, :]是使用ResNet的方法,即采用sum的方式将特征图进行求和,通道数不变,都是out_channel个通道
  89.         # a[:, self.out_channel:, :, :], x[:, self.out_channel:, :, :]]是使用DenseNet的方法,即采用concat的方式将特征图在channel维度上直接进行叠加,通道数加倍,即2*dense_channel
  90.         # 注意最终是将out_channel个通道的特征(ResNet方式)与2*dense_channel个通道特征(DenseNet方式)进行叠加,因此最终通道数为out_channel+2*dense_channel
  91.         x = torch.cat([a[:, :self.out_channel, :, :]+x[:, :self.out_channel, :, :], a[:, self.out_channel:, :, :], x[:, self.out_channel:, :, :]], dim=1)
  92.         x = self.relu(x)
  93.         return x
  94. # DPN搭建
  95. class DPN(nn.Module):
  96.     def __init__(self, cfg):
  97.         super(DPN, self).__init__()
  98.         self.group = cfg['group']
  99.         self.in_channel = cfg['in_channel']
  100.         mid_channels = cfg['mid_channels']
  101.         out_channels = cfg['out_channels']
  102.         dense_channels = cfg['dense_channels']
  103.         num = cfg['num']
  104.         self.conv1 = nn.Sequential(
  105.             nn.Conv2d(3, self.in_channel, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
  106.             nn.BatchNorm2d(self.in_channel),
  107.             nn.ReLU(),
  108.             nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
  109.         )
  110.         self.conv2 = self._make_layers(mid_channels[0], out_channels[0], dense_channels[0], num[0], stride=1)
  111.         self.conv3 = self._make_layers(mid_channels[1], out_channels[1], dense_channels[1], num[1], stride=2)
  112.         self.conv4 = self._make_layers(mid_channels[2], out_channels[2], dense_channels[2], num[2], stride=2)
  113.         self.conv5 = self._make_layers(mid_channels[3], out_channels[3], dense_channels[3], num[3], stride=2)
  114.         self.pool = nn.AdaptiveAvgPool2d((1,1))
  115.         self.fc = nn.Linear(cfg['out_channels'][3] + (num[3] + 1) * cfg['dense_channels'][3], cfg['classes']) # fc层需要计算
  116.     def _make_layers(self, mid_channel, out_channel, dense_channel, num, stride):
  117.         layers = []
  118.         # is_shortcut=True表示进行shortcut操作,则将浅层的特征进行一次卷积后与进行第三次卷积的特征图相加(ResNet方式)和concat(DeseNet方式)操作
  119.         # 第一次使用Block可以满足浅层特征的利用,后续重复的Block则不需要线层特征,因此后续的Block的is_shortcut=False(默认值)
  120.         layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=stride, groups=self.group, is_shortcut=True))
  121.         self.in_channel = out_channel + dense_channel*2
  122.         for i in range(1, num):
  123.             layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=1, groups=self.group))
  124.              # 由于Block包含DenseNet在叠加特征图,所以第一次是2倍dense_channel,后面每次都会多出1倍dense_channel
  125.             self.in_channel +=  dense_channel
  126.         return nn.Sequential(*layers)
  127.     def forward(self, x):
  128.         x = self.conv1(x)
  129.         x = self.conv2(x)
  130.         x = self.conv3(x)
  131.         x = self.conv4(x)
  132.         x = self.conv5(x)
  133.         x = self.pool(x)
  134.         x = torch.flatten(x, start_dim=1)
  135.         x = self.fc(x)
  136.         return x
  137. def DPN92(n_class=4):
  138.     cfg = {
  139.         "group" : 32,
  140.         "in_channel" : 64,
  141.         "mid_channels" : (96, 192, 384, 768),
  142.         "out_channels" : (256, 512, 1024, 2048),
  143.         "dense_channels" : (16, 32, 24, 128),
  144.         "num" : (3, 4, 20, 3),
  145.         "classes" : (n_class)
  146.     }
  147.     return DPN(cfg)
  148. def DPN98(n_class=4):
  149.     cfg = {
  150.         "group" : 40,
  151.         "in_channel" : 96,
  152.         "mid_channels" : (160, 320, 640, 1280),
  153.         "out_channels" : (256, 512, 1024, 2048),
  154.         "dense_channels" : (16, 32, 32, 128),
  155.         "num" : (3, 6, 20, 3),
  156.         "classes" : (n_class)
  157.     }
  158.     return DPN(cfg)
  159. model = DPN92().to(device)
  160. import torchsummary as summary
  161. summary.summary(model, (3, 224, 224))
复制代码
  1. cuda
  2. data_dir: C:\Users\11054\Desktop\kLearning\J1_learning\bird_photos
  3. classNames: ['C:\\Users\\11054\\Desktop\\kLearning\\J1_learning\\bird_photos\\Bananaquit', 'C:\\Users\\11054\\Desktop\\kLearning\\J1_learning\\bird_photos\\Black Skimmer', 'C:\\Users\\11054\\Desktop\\kLearning\\J1_learning\\bird_photos\\Black Throated Bushtiti', 'C:\\Users\\11054\\Desktop\\kLearning\\J1_learning\\bird_photos\\Cockatoo']
  4. Dataset ImageFolder
  5.     Number of datapoints: 565
  6.     Root location: C:\Users\11054\Desktop\kLearning\J1_learning\bird_photos
  7.     StandardTransform
  8. Transform: Compose(
  9.                Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
  10.                ToTensor()
  11.                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  12.            )
  13. {'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}
  14. <torch.utils.data.dataset.Subset object at 0x0000023C6A0BD880> <torch.utils.data.dataset.Subset object at 0x0000023C6D151280>
  15. Shape of X [N, C, H, W]: torch.Size([4, 3, 224, 224])
  16. Shape of y: torch.Size([4]) torch.int64
  17. ----------------------------------------------------------------
  18.         Layer (type)               Output Shape         Param #
  19. ================================================================
  20.             Conv2d-1         [-1, 64, 112, 112]           9,408
  21.        BatchNorm2d-2         [-1, 64, 112, 112]             128
  22.               ReLU-3         [-1, 64, 112, 112]               0
  23.          MaxPool2d-4           [-1, 64, 55, 55]               0
  24.             Conv2d-5           [-1, 96, 55, 55]           6,144
  25.        BatchNorm2d-6           [-1, 96, 55, 55]             192
  26.               ReLU-7           [-1, 96, 55, 55]               0
  27.             Conv2d-8           [-1, 96, 55, 55]           2,592
  28.        BatchNorm2d-9           [-1, 96, 55, 55]             192
  29.              ReLU-10           [-1, 96, 55, 55]               0
  30.            Conv2d-11          [-1, 272, 55, 55]          26,112
  31.       BatchNorm2d-12          [-1, 272, 55, 55]             544
  32.            Conv2d-13          [-1, 272, 55, 55]         156,672
  33.       BatchNorm2d-14          [-1, 272, 55, 55]             544
  34.              ReLU-15          [-1, 288, 55, 55]               0
  35.             Block-16          [-1, 288, 55, 55]               0
  36.            Conv2d-17           [-1, 96, 55, 55]          27,648
  37.       BatchNorm2d-18           [-1, 96, 55, 55]             192
  38.              ReLU-19           [-1, 96, 55, 55]               0
  39.            Conv2d-20           [-1, 96, 55, 55]           2,592
  40.       BatchNorm2d-21           [-1, 96, 55, 55]             192
  41.              ReLU-22           [-1, 96, 55, 55]               0
  42.            Conv2d-23          [-1, 272, 55, 55]          26,112
  43.       BatchNorm2d-24          [-1, 272, 55, 55]             544
  44.              ReLU-25          [-1, 304, 55, 55]               0
  45.             Block-26          [-1, 304, 55, 55]               0
  46.            Conv2d-27           [-1, 96, 55, 55]          29,184
  47.       BatchNorm2d-28           [-1, 96, 55, 55]             192
  48.              ReLU-29           [-1, 96, 55, 55]               0
  49.            Conv2d-30           [-1, 96, 55, 55]           2,592
  50.       BatchNorm2d-31           [-1, 96, 55, 55]             192
  51.              ReLU-32           [-1, 96, 55, 55]               0
  52.            Conv2d-33          [-1, 272, 55, 55]          26,112
  53.       BatchNorm2d-34          [-1, 272, 55, 55]             544
  54.              ReLU-35          [-1, 320, 55, 55]               0
  55.             Block-36          [-1, 320, 55, 55]               0
  56.            Conv2d-37          [-1, 192, 55, 55]          61,440
  57.       BatchNorm2d-38          [-1, 192, 55, 55]             384
  58.              ReLU-39          [-1, 192, 55, 55]               0
  59.            Conv2d-40          [-1, 192, 28, 28]          10,368
  60.       BatchNorm2d-41          [-1, 192, 28, 28]             384
  61.              ReLU-42          [-1, 192, 28, 28]               0
  62.            Conv2d-43          [-1, 544, 28, 28]         104,448
  63.       BatchNorm2d-44          [-1, 544, 28, 28]           1,088
  64.            Conv2d-45          [-1, 544, 28, 28]       1,566,720
  65.       BatchNorm2d-46          [-1, 544, 28, 28]           1,088
  66.              ReLU-47          [-1, 576, 28, 28]               0
  67.             Block-48          [-1, 576, 28, 28]               0
  68.            Conv2d-49          [-1, 192, 28, 28]         110,592
  69.       BatchNorm2d-50          [-1, 192, 28, 28]             384
  70.              ReLU-51          [-1, 192, 28, 28]               0
  71.            Conv2d-52          [-1, 192, 28, 28]          10,368
  72.       BatchNorm2d-53          [-1, 192, 28, 28]             384
  73.              ReLU-54          [-1, 192, 28, 28]               0
  74.            Conv2d-55          [-1, 544, 28, 28]         104,448
  75.       BatchNorm2d-56          [-1, 544, 28, 28]           1,088
  76.              ReLU-57          [-1, 608, 28, 28]               0
  77.             Block-58          [-1, 608, 28, 28]               0
  78.            Conv2d-59          [-1, 192, 28, 28]         116,736
  79.       BatchNorm2d-60          [-1, 192, 28, 28]             384
  80.              ReLU-61          [-1, 192, 28, 28]               0
  81.            Conv2d-62          [-1, 192, 28, 28]          10,368
  82.       BatchNorm2d-63          [-1, 192, 28, 28]             384
  83.              ReLU-64          [-1, 192, 28, 28]               0
  84.            Conv2d-65          [-1, 544, 28, 28]         104,448
  85.       BatchNorm2d-66          [-1, 544, 28, 28]           1,088
  86.              ReLU-67          [-1, 640, 28, 28]               0
  87.             Block-68          [-1, 640, 28, 28]               0
  88.            Conv2d-69          [-1, 192, 28, 28]         122,880
  89.       BatchNorm2d-70          [-1, 192, 28, 28]             384
  90.              ReLU-71          [-1, 192, 28, 28]               0
  91.            Conv2d-72          [-1, 192, 28, 28]          10,368
  92.       BatchNorm2d-73          [-1, 192, 28, 28]             384
  93.              ReLU-74          [-1, 192, 28, 28]               0
  94.            Conv2d-75          [-1, 544, 28, 28]         104,448
  95.       BatchNorm2d-76          [-1, 544, 28, 28]           1,088
  96.              ReLU-77          [-1, 672, 28, 28]               0
  97.             Block-78          [-1, 672, 28, 28]               0
  98.            Conv2d-79          [-1, 384, 28, 28]         258,048
  99.       BatchNorm2d-80          [-1, 384, 28, 28]             768
  100.              ReLU-81          [-1, 384, 28, 28]               0
  101.            Conv2d-82          [-1, 384, 14, 14]          41,472
  102.       BatchNorm2d-83          [-1, 384, 14, 14]             768
  103.              ReLU-84          [-1, 384, 14, 14]               0
  104.            Conv2d-85         [-1, 1048, 14, 14]         402,432
  105.       BatchNorm2d-86         [-1, 1048, 14, 14]           2,096
  106.            Conv2d-87         [-1, 1048, 14, 14]       6,338,304
  107.       BatchNorm2d-88         [-1, 1048, 14, 14]           2,096
  108.              ReLU-89         [-1, 1072, 14, 14]               0
  109.             Block-90         [-1, 1072, 14, 14]               0
  110.            Conv2d-91          [-1, 384, 14, 14]         411,648
  111.       BatchNorm2d-92          [-1, 384, 14, 14]             768
  112.              ReLU-93          [-1, 384, 14, 14]               0
  113.            Conv2d-94          [-1, 384, 14, 14]          41,472
  114.       BatchNorm2d-95          [-1, 384, 14, 14]             768
  115.              ReLU-96          [-1, 384, 14, 14]               0
  116.            Conv2d-97         [-1, 1048, 14, 14]         402,432
  117.       BatchNorm2d-98         [-1, 1048, 14, 14]           2,096
  118.              ReLU-99         [-1, 1096, 14, 14]               0
  119.            Block-100         [-1, 1096, 14, 14]               0
  120.           Conv2d-101          [-1, 384, 14, 14]         420,864
  121.      BatchNorm2d-102          [-1, 384, 14, 14]             768
  122.             ReLU-103          [-1, 384, 14, 14]               0
  123.           Conv2d-104          [-1, 384, 14, 14]          41,472
  124.      BatchNorm2d-105          [-1, 384, 14, 14]             768
  125.             ReLU-106          [-1, 384, 14, 14]               0
  126.           Conv2d-107         [-1, 1048, 14, 14]         402,432
  127.      BatchNorm2d-108         [-1, 1048, 14, 14]           2,096
  128.             ReLU-109         [-1, 1120, 14, 14]               0
  129.            Block-110         [-1, 1120, 14, 14]               0
  130.           Conv2d-111          [-1, 384, 14, 14]         430,080
  131.      BatchNorm2d-112          [-1, 384, 14, 14]             768
  132.             ReLU-113          [-1, 384, 14, 14]               0
  133.           Conv2d-114          [-1, 384, 14, 14]          41,472
  134.      BatchNorm2d-115          [-1, 384, 14, 14]             768
  135.             ReLU-116          [-1, 384, 14, 14]               0
  136.           Conv2d-117         [-1, 1048, 14, 14]         402,432
  137.      BatchNorm2d-118         [-1, 1048, 14, 14]           2,096
  138.             ReLU-119         [-1, 1144, 14, 14]               0
  139.            Block-120         [-1, 1144, 14, 14]               0
  140.           Conv2d-121          [-1, 384, 14, 14]         439,296
  141.      BatchNorm2d-122          [-1, 384, 14, 14]             768
  142.             ReLU-123          [-1, 384, 14, 14]               0
  143.           Conv2d-124          [-1, 384, 14, 14]          41,472
  144.      BatchNorm2d-125          [-1, 384, 14, 14]             768
  145.             ReLU-126          [-1, 384, 14, 14]               0
  146.           Conv2d-127         [-1, 1048, 14, 14]         402,432
  147.      BatchNorm2d-128         [-1, 1048, 14, 14]           2,096
  148.             ReLU-129         [-1, 1168, 14, 14]               0
  149.            Block-130         [-1, 1168, 14, 14]               0
  150.           Conv2d-131          [-1, 384, 14, 14]         448,512
  151.      BatchNorm2d-132          [-1, 384, 14, 14]             768
  152.             ReLU-133          [-1, 384, 14, 14]               0
  153.           Conv2d-134          [-1, 384, 14, 14]          41,472
  154.      BatchNorm2d-135          [-1, 384, 14, 14]             768
  155.             ReLU-136          [-1, 384, 14, 14]               0
  156.           Conv2d-137         [-1, 1048, 14, 14]         402,432
  157.      BatchNorm2d-138         [-1, 1048, 14, 14]           2,096
  158.             ReLU-139         [-1, 1192, 14, 14]               0
  159.            Block-140         [-1, 1192, 14, 14]               0
  160.           Conv2d-141          [-1, 384, 14, 14]         457,728
  161.      BatchNorm2d-142          [-1, 384, 14, 14]             768
  162.             ReLU-143          [-1, 384, 14, 14]               0
  163.           Conv2d-144          [-1, 384, 14, 14]          41,472
  164.      BatchNorm2d-145          [-1, 384, 14, 14]             768
  165.             ReLU-146          [-1, 384, 14, 14]               0
  166.           Conv2d-147         [-1, 1048, 14, 14]         402,432
  167.      BatchNorm2d-148         [-1, 1048, 14, 14]           2,096
  168.             ReLU-149         [-1, 1216, 14, 14]               0
  169.            Block-150         [-1, 1216, 14, 14]               0
  170.           Conv2d-151          [-1, 384, 14, 14]         466,944
  171.      BatchNorm2d-152          [-1, 384, 14, 14]             768
  172.             ReLU-153          [-1, 384, 14, 14]               0
  173.           Conv2d-154          [-1, 384, 14, 14]          41,472
  174.      BatchNorm2d-155          [-1, 384, 14, 14]             768
  175.             ReLU-156          [-1, 384, 14, 14]               0
  176.           Conv2d-157         [-1, 1048, 14, 14]         402,432
  177.      BatchNorm2d-158         [-1, 1048, 14, 14]           2,096
  178.             ReLU-159         [-1, 1240, 14, 14]               0
  179.            Block-160         [-1, 1240, 14, 14]               0
  180.           Conv2d-161          [-1, 384, 14, 14]         476,160
  181.      BatchNorm2d-162          [-1, 384, 14, 14]             768
  182.             ReLU-163          [-1, 384, 14, 14]               0
  183.           Conv2d-164          [-1, 384, 14, 14]          41,472
  184.      BatchNorm2d-165          [-1, 384, 14, 14]             768
  185.             ReLU-166          [-1, 384, 14, 14]               0
  186.           Conv2d-167         [-1, 1048, 14, 14]         402,432
  187.      BatchNorm2d-168         [-1, 1048, 14, 14]           2,096
  188.             ReLU-169         [-1, 1264, 14, 14]               0
  189.            Block-170         [-1, 1264, 14, 14]               0
  190.           Conv2d-171          [-1, 384, 14, 14]         485,376
  191.      BatchNorm2d-172          [-1, 384, 14, 14]             768
  192.             ReLU-173          [-1, 384, 14, 14]               0
  193.           Conv2d-174          [-1, 384, 14, 14]          41,472
  194.      BatchNorm2d-175          [-1, 384, 14, 14]             768
  195.             ReLU-176          [-1, 384, 14, 14]               0
  196.           Conv2d-177         [-1, 1048, 14, 14]         402,432
  197.      BatchNorm2d-178         [-1, 1048, 14, 14]           2,096
  198.             ReLU-179         [-1, 1288, 14, 14]               0
  199.            Block-180         [-1, 1288, 14, 14]               0
  200.           Conv2d-181          [-1, 384, 14, 14]         494,592
  201.      BatchNorm2d-182          [-1, 384, 14, 14]             768
  202.             ReLU-183          [-1, 384, 14, 14]               0
  203.           Conv2d-184          [-1, 384, 14, 14]          41,472
  204.      BatchNorm2d-185          [-1, 384, 14, 14]             768
  205.             ReLU-186          [-1, 384, 14, 14]               0
  206.           Conv2d-187         [-1, 1048, 14, 14]         402,432
  207.      BatchNorm2d-188         [-1, 1048, 14, 14]           2,096
  208.             ReLU-189         [-1, 1312, 14, 14]               0
  209.            Block-190         [-1, 1312, 14, 14]               0
  210.           Conv2d-191          [-1, 384, 14, 14]         503,808
  211.      BatchNorm2d-192          [-1, 384, 14, 14]             768
  212.             ReLU-193          [-1, 384, 14, 14]               0
  213.           Conv2d-194          [-1, 384, 14, 14]          41,472
  214.      BatchNorm2d-195          [-1, 384, 14, 14]             768
  215.             ReLU-196          [-1, 384, 14, 14]               0
  216.           Conv2d-197         [-1, 1048, 14, 14]         402,432
  217.      BatchNorm2d-198         [-1, 1048, 14, 14]           2,096
  218.             ReLU-199         [-1, 1336, 14, 14]               0
  219.            Block-200         [-1, 1336, 14, 14]               0
  220.           Conv2d-201          [-1, 384, 14, 14]         513,024
  221.      BatchNorm2d-202          [-1, 384, 14, 14]             768
  222.             ReLU-203          [-1, 384, 14, 14]               0
  223.           Conv2d-204          [-1, 384, 14, 14]          41,472
  224.      BatchNorm2d-205          [-1, 384, 14, 14]             768
  225.             ReLU-206          [-1, 384, 14, 14]               0
  226.           Conv2d-207         [-1, 1048, 14, 14]         402,432
  227.      BatchNorm2d-208         [-1, 1048, 14, 14]           2,096
  228.             ReLU-209         [-1, 1360, 14, 14]               0
  229.            Block-210         [-1, 1360, 14, 14]               0
  230.           Conv2d-211          [-1, 384, 14, 14]         522,240
  231.      BatchNorm2d-212          [-1, 384, 14, 14]             768
  232.             ReLU-213          [-1, 384, 14, 14]               0
  233.           Conv2d-214          [-1, 384, 14, 14]          41,472
  234.      BatchNorm2d-215          [-1, 384, 14, 14]             768
  235.             ReLU-216          [-1, 384, 14, 14]               0
  236.           Conv2d-217         [-1, 1048, 14, 14]         402,432
  237.      BatchNorm2d-218         [-1, 1048, 14, 14]           2,096
  238.             ReLU-219         [-1, 1384, 14, 14]               0
  239.            Block-220         [-1, 1384, 14, 14]               0
  240.           Conv2d-221          [-1, 384, 14, 14]         531,456
  241.      BatchNorm2d-222          [-1, 384, 14, 14]             768
  242.             ReLU-223          [-1, 384, 14, 14]               0
  243.           Conv2d-224          [-1, 384, 14, 14]          41,472
  244.      BatchNorm2d-225          [-1, 384, 14, 14]             768
  245.             ReLU-226          [-1, 384, 14, 14]               0
  246.           Conv2d-227         [-1, 1048, 14, 14]         402,432
  247.      BatchNorm2d-228         [-1, 1048, 14, 14]           2,096
  248.             ReLU-229         [-1, 1408, 14, 14]               0
  249.            Block-230         [-1, 1408, 14, 14]               0
  250.           Conv2d-231          [-1, 384, 14, 14]         540,672
  251.      BatchNorm2d-232          [-1, 384, 14, 14]             768
  252.             ReLU-233          [-1, 384, 14, 14]               0
  253.           Conv2d-234          [-1, 384, 14, 14]          41,472
  254.      BatchNorm2d-235          [-1, 384, 14, 14]             768
  255.             ReLU-236          [-1, 384, 14, 14]               0
  256.           Conv2d-237         [-1, 1048, 14, 14]         402,432
  257.      BatchNorm2d-238         [-1, 1048, 14, 14]           2,096
  258.             ReLU-239         [-1, 1432, 14, 14]               0
  259.            Block-240         [-1, 1432, 14, 14]               0
  260.           Conv2d-241          [-1, 384, 14, 14]         549,888
  261.      BatchNorm2d-242          [-1, 384, 14, 14]             768
  262.             ReLU-243          [-1, 384, 14, 14]               0
  263.           Conv2d-244          [-1, 384, 14, 14]          41,472
  264.      BatchNorm2d-245          [-1, 384, 14, 14]             768
  265.             ReLU-246          [-1, 384, 14, 14]               0
  266.           Conv2d-247         [-1, 1048, 14, 14]         402,432
  267.      BatchNorm2d-248         [-1, 1048, 14, 14]           2,096
  268.             ReLU-249         [-1, 1456, 14, 14]               0
  269.            Block-250         [-1, 1456, 14, 14]               0
  270.           Conv2d-251          [-1, 384, 14, 14]         559,104
  271.      BatchNorm2d-252          [-1, 384, 14, 14]             768
  272.             ReLU-253          [-1, 384, 14, 14]               0
  273.           Conv2d-254          [-1, 384, 14, 14]          41,472
  274.      BatchNorm2d-255          [-1, 384, 14, 14]             768
  275.             ReLU-256          [-1, 384, 14, 14]               0
  276.           Conv2d-257         [-1, 1048, 14, 14]         402,432
  277.      BatchNorm2d-258         [-1, 1048, 14, 14]           2,096
  278.             ReLU-259         [-1, 1480, 14, 14]               0
  279.            Block-260         [-1, 1480, 14, 14]               0
  280.           Conv2d-261          [-1, 384, 14, 14]         568,320
  281.      BatchNorm2d-262          [-1, 384, 14, 14]             768
  282.             ReLU-263          [-1, 384, 14, 14]               0
  283.           Conv2d-264          [-1, 384, 14, 14]          41,472
  284.      BatchNorm2d-265          [-1, 384, 14, 14]             768
  285.             ReLU-266          [-1, 384, 14, 14]               0
  286.           Conv2d-267         [-1, 1048, 14, 14]         402,432
  287.      BatchNorm2d-268         [-1, 1048, 14, 14]           2,096
  288.             ReLU-269         [-1, 1504, 14, 14]               0
  289.            Block-270         [-1, 1504, 14, 14]               0
  290.           Conv2d-271          [-1, 384, 14, 14]         577,536
  291.      BatchNorm2d-272          [-1, 384, 14, 14]             768
  292.             ReLU-273          [-1, 384, 14, 14]               0
  293.           Conv2d-274          [-1, 384, 14, 14]          41,472
  294.      BatchNorm2d-275          [-1, 384, 14, 14]             768
  295.             ReLU-276          [-1, 384, 14, 14]               0
  296.           Conv2d-277         [-1, 1048, 14, 14]         402,432
  297.      BatchNorm2d-278         [-1, 1048, 14, 14]           2,096
  298.             ReLU-279         [-1, 1528, 14, 14]               0
  299.            Block-280         [-1, 1528, 14, 14]               0
  300.           Conv2d-281          [-1, 768, 14, 14]       1,173,504
  301.      BatchNorm2d-282          [-1, 768, 14, 14]           1,536
  302.             ReLU-283          [-1, 768, 14, 14]               0
  303.           Conv2d-284            [-1, 768, 7, 7]         165,888
  304.      BatchNorm2d-285            [-1, 768, 7, 7]           1,536
  305.             ReLU-286            [-1, 768, 7, 7]               0
  306.           Conv2d-287           [-1, 2176, 7, 7]       1,671,168
  307.      BatchNorm2d-288           [-1, 2176, 7, 7]           4,352
  308.           Conv2d-289           [-1, 2176, 7, 7]      29,924,352
  309.      BatchNorm2d-290           [-1, 2176, 7, 7]           4,352
  310.             ReLU-291           [-1, 2304, 7, 7]               0
  311.            Block-292           [-1, 2304, 7, 7]               0
  312.           Conv2d-293            [-1, 768, 7, 7]       1,769,472
  313.      BatchNorm2d-294            [-1, 768, 7, 7]           1,536
  314.             ReLU-295            [-1, 768, 7, 7]               0
  315.           Conv2d-296            [-1, 768, 7, 7]         165,888
  316.      BatchNorm2d-297            [-1, 768, 7, 7]           1,536
  317.             ReLU-298            [-1, 768, 7, 7]               0
  318.           Conv2d-299           [-1, 2176, 7, 7]       1,671,168
  319.      BatchNorm2d-300           [-1, 2176, 7, 7]           4,352
  320.             ReLU-301           [-1, 2432, 7, 7]               0
  321.            Block-302           [-1, 2432, 7, 7]               0
  322.           Conv2d-303            [-1, 768, 7, 7]       1,867,776
  323.      BatchNorm2d-304            [-1, 768, 7, 7]           1,536
  324.             ReLU-305            [-1, 768, 7, 7]               0
  325.           Conv2d-306            [-1, 768, 7, 7]         165,888
  326.      BatchNorm2d-307            [-1, 768, 7, 7]           1,536
  327.             ReLU-308            [-1, 768, 7, 7]               0
  328.           Conv2d-309           [-1, 2176, 7, 7]       1,671,168
  329.      BatchNorm2d-310           [-1, 2176, 7, 7]           4,352
  330.             ReLU-311           [-1, 2560, 7, 7]               0
  331.            Block-312           [-1, 2560, 7, 7]               0
  332. AdaptiveAvgPool2d-313           [-1, 2560, 1, 1]               0
  333.           Linear-314                    [-1, 4]          10,244
  334. ================================================================
  335. Total params: 67,994,324
  336. Trainable params: 67,994,324
  337. Non-trainable params: 0
  338. ----------------------------------------------------------------
  339. Input size (MB): 0.57
  340. Forward/backward pass size (MB): 489.24
  341. Params size (MB): 259.38
  342. Estimated Total Size (MB): 749.20
  343. ----------------------------------------------------------------
复制代码
  1. # 训练循环
  2. def train(dataloader, model, loss_fn, optimizer):
  3.     size = len(dataloader.dataset)  # 训练集的大小
  4.     num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)
  5.     train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
  6.     for X, y in dataloader:  # 获取图片及其标签
  7.         X, y = X.to(device), y.to(device)
  8.         # 计算预测误差
  9.         pred = model(X)          # 网络输出
  10.         loss = loss_fn(pred, y)  # 计算网络输出pred和真实值y之间的差距,y为真实值,计算二者差值即为损失
  11.         # 反向传播
  12.         optimizer.zero_grad()  # grad属性归零
  13.         loss.backward()        # 反向传播
  14.         optimizer.step()       # 每一步自动更新
  15.         # 记录acc与loss
  16.         train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
  17.         train_loss += loss.item()
  18.     train_acc  /= size
  19.     train_loss /= num_batches
  20.     return train_acc, train_loss
复制代码
  1. def test(dataloader, model, loss_fn):
  2.     size = len(dataloader.dataset)  # 训练集的大小
  3.     num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)
  4.     test_loss, test_acc = 0, 0  # 初始化测试损失和正确率
  5.     # 当不进行训练时,停止梯度更新,节省计算内存消耗
  6.    # with torch.no_grad():
  7.     for imgs, target in dataloader:  # 获取图片及其标签
  8.         with torch.no_grad():
  9.             imgs, target = imgs.to(device), target.to(device)
  10.             # 计算误差
  11.             tartget_pred = model(imgs)          # 网络输出
  12.             loss = loss_fn(tartget_pred, target)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
  13.             # 记录acc与loss
  14.             test_loss += loss.item()
  15.             test_acc  += (tartget_pred.argmax(1) == target).type(torch.float).sum().item()
  16.     test_acc  /= size
  17.     test_loss /= num_batches
  18.     return test_acc, test_loss
复制代码
  1. import copy
  2. optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
  3. loss_fn = nn.CrossEntropyLoss() #创建损失函数
  4. epochs = 40
  5. train_loss = []
  6. train_acc = []
  7. test_loss = []
  8. test_acc = []
  9. best_acc = 0 #设置一个最佳准确率,作为最佳模型的判别指标
  10. if hasattr(torch.cuda, 'empty_cache'):
  11.     torch.cuda.empty_cache()
  12. for epoch in range(epochs):
  13.     model.train()
  14.     epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
  15.     #scheduler.step() #更新学习率(调用官方动态学习率接口时使用)
  16.     model.eval()
  17.     epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
  18.     #保存最佳模型到best_model
  19.     if epoch_test_acc > best_acc:
  20.         best_acc = epoch_test_acc
  21.         best_model = copy.deepcopy(model)
  22.     train_acc.append(epoch_train_acc)
  23.     train_loss.append(epoch_train_loss)
  24.     test_acc.append(epoch_test_acc)
  25.     test_loss.append(epoch_test_loss)
  26.     #获取当前的学习率
  27.     lr = optimizer.state_dict()['param_groups'][0]['lr']
  28.     template = ('Epoch: {:2d}. Train_acc: {:.1f}%, Train_loss: {:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr: {:.2E}')
  29.     print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))
  30. PATH = r'C:\Users\11054\Desktop\kLearning\J4_learning\J3_best_model.pth'
  31. torch.save(model.state_dict(), PATH)
  32. print('Done')
复制代码
  1. Epoch:  1. Train_acc: 42.3%, Train_loss: 1.389, Test_acc:48.7%, Test_loss:2.097, Lr: 1.00E-04
  2. Epoch:  2. Train_acc: 63.5%, Train_loss: 0.933, Test_acc:53.1%, Test_loss:2.493, Lr: 1.00E-04
  3. Epoch:  3. Train_acc: 69.2%, Train_loss: 0.795, Test_acc:69.9%, Test_loss:0.845, Lr: 1.00E-04
  4. Epoch:  4. Train_acc: 72.3%, Train_loss: 0.702, Test_acc:69.0%, Test_loss:1.069, Lr: 1.00E-04
  5. Epoch:  5. Train_acc: 78.3%, Train_loss: 0.585, Test_acc:82.3%, Test_loss:0.656, Lr: 1.00E-04
  6. Epoch:  6. Train_acc: 82.5%, Train_loss: 0.474, Test_acc:79.6%, Test_loss:0.602, Lr: 1.00E-04
  7. Epoch:  7. Train_acc: 83.6%, Train_loss: 0.458, Test_acc:83.2%, Test_loss:0.699, Lr: 1.00E-04
  8. Epoch:  8. Train_acc: 86.9%, Train_loss: 0.368, Test_acc:85.8%, Test_loss:0.577, Lr: 1.00E-04
  9. Epoch:  9. Train_acc: 88.5%, Train_loss: 0.371, Test_acc:78.8%, Test_loss:0.574, Lr: 1.00E-04
  10. Epoch: 10. Train_acc: 87.6%, Train_loss: 0.345, Test_acc:87.6%, Test_loss:0.392, Lr: 1.00E-04
  11. Epoch: 11. Train_acc: 91.6%, Train_loss: 0.247, Test_acc:80.5%, Test_loss:0.443, Lr: 1.00E-04
  12. Epoch: 12. Train_acc: 91.2%, Train_loss: 0.310, Test_acc:86.7%, Test_loss:0.361, Lr: 1.00E-04
  13. Epoch: 13. Train_acc: 93.6%, Train_loss: 0.201, Test_acc:87.6%, Test_loss:0.336, Lr: 1.00E-04
  14. Epoch: 14. Train_acc: 89.2%, Train_loss: 0.322, Test_acc:84.1%, Test_loss:0.438, Lr: 1.00E-04
  15. Epoch: 15. Train_acc: 91.8%, Train_loss: 0.226, Test_acc:88.5%, Test_loss:0.343, Lr: 1.00E-04
  16. Epoch: 16. Train_acc: 94.5%, Train_loss: 0.146, Test_acc:87.6%, Test_loss:0.321, Lr: 1.00E-04
  17. Epoch: 17. Train_acc: 96.7%, Train_loss: 0.127, Test_acc:88.5%, Test_loss:0.436, Lr: 1.00E-04
  18. Epoch: 18. Train_acc: 96.5%, Train_loss: 0.096, Test_acc:92.0%, Test_loss:0.241, Lr: 1.00E-04
  19. Epoch: 19. Train_acc: 97.1%, Train_loss: 0.094, Test_acc:86.7%, Test_loss:0.430, Lr: 1.00E-04
  20. Epoch: 20. Train_acc: 95.8%, Train_loss: 0.134, Test_acc:59.3%, Test_loss:2.130, Lr: 1.00E-04
  21. Epoch: 21. Train_acc: 95.1%, Train_loss: 0.125, Test_acc:92.9%, Test_loss:0.230, Lr: 1.00E-04
  22. Epoch: 22. Train_acc: 95.4%, Train_loss: 0.144, Test_acc:87.6%, Test_loss:0.402, Lr: 1.00E-04
  23. Epoch: 23. Train_acc: 97.1%, Train_loss: 0.081, Test_acc:91.2%, Test_loss:0.282, Lr: 1.00E-04
  24. Epoch: 24. Train_acc: 98.5%, Train_loss: 0.049, Test_acc:92.9%, Test_loss:0.280, Lr: 1.00E-04
  25. Epoch: 25. Train_acc: 98.5%, Train_loss: 0.054, Test_acc:88.5%, Test_loss:0.413, Lr: 1.00E-04
  26. Epoch: 26. Train_acc: 97.8%, Train_loss: 0.072, Test_acc:87.6%, Test_loss:0.330, Lr: 1.00E-04
  27. Epoch: 27. Train_acc: 98.9%, Train_loss: 0.045, Test_acc:91.2%, Test_loss:0.244, Lr: 1.00E-04
  28. Epoch: 28. Train_acc: 94.9%, Train_loss: 0.156, Test_acc:77.0%, Test_loss:0.813, Lr: 1.00E-04
  29. Epoch: 29. Train_acc: 95.8%, Train_loss: 0.149, Test_acc:91.2%, Test_loss:0.372, Lr: 1.00E-04
  30. Epoch: 30. Train_acc: 97.8%, Train_loss: 0.068, Test_acc:89.4%, Test_loss:0.281, Lr: 1.00E-04
  31. Epoch: 31. Train_acc: 98.5%, Train_loss: 0.030, Test_acc:83.2%, Test_loss:0.529, Lr: 1.00E-04
  32. Epoch: 32. Train_acc: 98.5%, Train_loss: 0.054, Test_acc:91.2%, Test_loss:0.304, Lr: 1.00E-04
  33. Epoch: 33. Train_acc: 98.7%, Train_loss: 0.048, Test_acc:91.2%, Test_loss:0.311, Lr: 1.00E-04
  34. Epoch: 34. Train_acc: 94.2%, Train_loss: 0.179, Test_acc:93.8%, Test_loss:0.244, Lr: 1.00E-04
  35. Epoch: 35. Train_acc: 95.8%, Train_loss: 0.119, Test_acc:87.6%, Test_loss:0.426, Lr: 1.00E-04
  36. Epoch: 36. Train_acc: 98.2%, Train_loss: 0.064, Test_acc:88.5%, Test_loss:0.341, Lr: 1.00E-04
  37. Epoch: 37. Train_acc: 98.9%, Train_loss: 0.047, Test_acc:92.9%, Test_loss:0.223, Lr: 1.00E-04
  38. Epoch: 38. Train_acc: 98.9%, Train_loss: 0.038, Test_acc:87.6%, Test_loss:0.307, Lr: 1.00E-04
  39. Epoch: 39. Train_acc: 98.9%, Train_loss: 0.035, Test_acc:92.9%, Test_loss:0.188, Lr: 1.00E-04
  40. Epoch: 40. Train_acc: 99.3%, Train_loss: 0.025, Test_acc:93.8%, Test_loss:0.159, Lr: 1.00E-04
  41. Done
复制代码
  1. import matplotlib.pyplot as plt
  2. import warnings
  3. warnings.filterwarnings("ignore")               #忽略警告信息
  4. plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
  5. plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
  6. plt.rcParams['figure.dpi']         = 100        #分辨率
  7. epochs_range = range(epochs)
  8. plt.figure(figsize=(12, 3))
  9. plt.subplot(1, 2, 1)
  10. plt.plot(epochs_range, train_acc, label='Training Accuracy')
  11. plt.plot(epochs_range, test_acc, label='Test Accuracy')
  12. plt.legend(loc='lower right')
  13. plt.title('Training and Validation Accuracy')
  14. plt.subplot(1, 2, 2)
  15. plt.plot(epochs_range, train_loss, label='Training Loss')
  16. plt.plot(epochs_range, test_loss, label='Test Loss')
  17. plt.legend(loc='upper right')
  18. plt.title('Training and Validation Loss')
  19. plt.show()
复制代码


个人总结


  • 学习了ResNet与DenseNet团结网络->DPN网络
  • DPN网络的核心是双路径块(Dual Path Block),其结构如下:


  • 输入:输入特性图被分为两部分。
  • 分组卷积:一部分特性图通太过组卷积举行处置处罚,类似于ResNeXt中的操纵。
  • 麋集毗连:另一部分特性图通过麋集毗连举行处置处罚,类似于DenseNet中的操纵。
  • 特性融合:两部分特性图通过特定的融合方式(如相加或拼接)举行融合,形成终极的输出特性图。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!qidao123.com:ToB企服之家,中国第一个企服评测及软件市场,开放入驻,技术点评得现金

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表