sdnet

[复制链接]
发表于 2025-12-29 13:25:43 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import math
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch._utils
  9. import torch.nn.functional as F
  10. import torch.nn.init as init
  11. import torch.optim as optim
  12. from Lib.config import config
  13. import random
  14. import scipy.io as scio
  15. from torch.utils.data import TensorDataset, DataLoader
  16. import csv
  17. import matplotlib.pyplot as plt
  18. #  定义一个3x3卷积!kernel_initializer='he_normal','glorot_normal'
  19. def regularized_padded_conv(in_channels, out_channels, kernel_size, stride=1):
  20.     conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size // 2, bias=False)
  21.     # 使用 kaiming_normal_ 进行初始化
  22.     nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='leaky_relu')
  23.     return conv
  24. ####################### 通道注意力机制 ##########################
  25. class ChannelAttention(nn.Module):
  26.     def __init__(self, in_planes, ratio=16):
  27.         super(ChannelAttention, self).__init__()
  28.         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
  29.         self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
  30.         compressed_channels = in_planes // ratio
  31.         self.conv1 = nn.Conv2d(in_planes, compressed_channels, kernel_size=1, stride=1, padding=0)
  32.         self.conv2 = nn.Conv2d(compressed_channels, in_planes, kernel_size=1, stride=1, padding=0)
  33.         self.leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
  34.     def forward(self, inputs):
  35.         avg = self.avg_pool(inputs)
  36.         max = self.max_pool(inputs)
  37.         avg = self.conv2(self.leaky_relu(self.conv1(avg)))
  38.         max = self.conv2(self.leaky_relu(self.conv1(max)))
  39.         out = avg + max
  40.         out = torch.sigmoid(out)
  41.         return out
  42. ########################### 空间注意力机制 ###########################
  43. class SpatialAttention(nn.Module):
  44.     def __init__(self, kernel_size=7):
  45.         super(SpatialAttention, self).__init__()
  46.         self.conv1 = regularized_padded_conv(2, 1, kernel_size, stride=1)
  47.         self.sigmoid = nn.Sigmoid()
  48.         self.leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
  49.     def forward(self, inputs):
  50.         avg_out = torch.mean(inputs, dim=1, keepdim=True)
  51.         max_out, _ = torch.max(inputs, dim=1, keepdim=True)
  52.         out = torch.cat([avg_out, max_out], dim=1)
  53.         out = self.conv1(out)
  54.         out = self.sigmoid(out)
  55.         return out
  56. ####################################csc layer###########################################################
  57. class elasnet_prox(nn.Module):
  58.     r"""Applies the elastic net proximal operator,
  59.     NOTS: it will degenerate to ell1_prox if mu=0.0
  60.     The elastic net proximal operator function is given as the following function
  61.     \argmin_{x} \lambda ||x||_1 + \mu /2 ||x||_2^2 + 0.5 ||x - input||_2^2
  62.     Args:
  63.       lambd: the :math:`\lambda` value on the ell_1 penalty term. Default: 0.5
  64.       mu:    the :math:`\mu` value on the ell_2 penalty term. Default: 0.0
  65.     Shape:
  66.       - Input: :math:`(N, *)` where `*` means, any number of additional
  67.         dimensions
  68.       - Output: :math:`(N, *)`, same shape as the input
  69.     """
  70.     def __init__(self, lambd=0.5, mu=0.0):
  71.         super(elasnet_prox, self).__init__()
  72.         self.lambd = lambd
  73.         self.scaling_mu = 1.0 / (1.0 + mu)
  74.     def forward(self, input):
  75.         return F.softshrink(input * self.scaling_mu, self.lambd * self.scaling_mu)
  76.     def extra_repr(self):
  77.         return '{} {}'.format(self.lambd, self.scaling_mu)
  78. class DictBlock(nn.Module):
  79.     # c = argmin_c lmbd * ||c||_1  +  mu/2 * ||c||_2^2 + 1 / 2 * ||x - weight (@conv) c||_2^2
  80.     def __init__(self, n_channel, dict_size, mu=0.0, lmbd=0.0, n_dict=1, non_negative=True,
  81.                  stride=1, kernel_size=3, padding=1, share_weight=True, square_noise=True,
  82.                  n_steps=10, step_size_fixed=True, step_size=0.1, w_norm=True,
  83.                  padding_mode="constant"):
  84.         super(DictBlock, self).__init__()
  85.         self.mu = mu
  86.         self.lmbd = lmbd  # LAMBDA
  87.         self.n_dict = n_dict
  88.         self.stride = stride
  89.         self.kernel_size = (kernel_size, kernel_size)
  90.         self.padding = padding
  91.         self.padding_mode = padding_mode
  92.         assert self.padding_mode in ['constant', 'reflect', 'replicate', 'circular']
  93.         self.groups = 1
  94.         self.n_steps = n_steps
  95.         self.conv_transpose_output_padding = 0 if stride == 1 else 1
  96.         self.w_norm = w_norm
  97.         self.non_negative = non_negative
  98.         self.v_max = None
  99.         self.v_max_error = 0.
  100.         self.xsize = None
  101.         self.zsize = None
  102.         self.lmbd_ = None
  103.         self.square_noise = square_noise
  104.         self.weight = nn.Parameter(torch.Tensor(dict_size, self.n_dict * n_channel, kernel_size, kernel_size))
  105.         with torch.no_grad():
  106.             init.kaiming_uniform_(self.weight)
  107.         self.nonlinear = elasnet_prox(self.lmbd * step_size, self.mu * step_size)
  108.         self.register_buffer('step_size', torch.tensor(step_size, dtype=torch.float))
  109.     def fista_forward(self, x):
  110.         for i in range(self.n_steps):
  111.             weight = self.weight
  112.             step_size = self.step_size
  113.             if i == 0:
  114.                 c_pre = 0.
  115.                 c = step_size * F.conv2d(x.repeat(1, self.n_dict, 1, 1), weight, bias=None, stride=self.stride,
  116.                                          padding=self.padding)
  117.                 c = self.nonlinear(c)
  118.             elif i == 1:
  119.                 c_pre = c
  120.                 xp = F.conv_transpose2d(c, weight, bias=None, stride=self.stride, padding=self.padding,
  121.                                         output_padding=self.conv_transpose_output_padding)
  122.                 r = x.repeat(1, self.n_dict, 1, 1) - xp
  123.                 if self.square_noise:
  124.                     gra = F.conv2d(r, weight, bias=None, stride=self.stride, padding=self.padding)
  125.                 else:
  126.                     w = r.view(r.size(0), -1)
  127.                     normw = w.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12).expand_as(w).detach()
  128.                     w = (w / normw).view(r.size())
  129.                     gra = F.conv2d(w, weight, bias=None, stride=self.stride, padding=self.padding) * 0.5
  130.                 c = c + step_size * gra
  131.                 c = self.nonlinear(c)
  132.                 t = (math.sqrt(5.0) + 1.0) / 2.0
  133.             else:
  134.                 t_pre = t
  135.                 t = (math.sqrt(1.0 + 4.0 * t_pre * t_pre) + 1) / 2.0
  136.                 a = (t_pre + t - 1.0) / t * c + (1.0 - t_pre) / t * c_pre
  137.                 c_pre = c
  138.                 xp = F.conv_transpose2d(c, weight, bias=None, stride=self.stride, padding=self.padding,
  139.                                         output_padding=self.conv_transpose_output_padding)
  140.                 r = x.repeat(1, self.n_dict, 1, 1) - xp
  141.                 if self.square_noise:
  142.                     gra = F.conv2d(r, weight, bias=None, stride=self.stride, padding=self.padding)
  143.                 else:
  144.                     w = r.view(r.size(0), -1)
  145.                     normw = w.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12).expand_as(w).detach()
  146.                     w = (w / normw).view(r.size())
  147.                     gra = F.conv2d(w, weight, bias=None, stride=self.stride, padding=self.padding) * 0.5
  148.                 c = a + step_size * gra
  149.                 c = self.nonlinear(c)
  150.             if self.non_negative:
  151.                 c = F.leaky_relu(c, negative_slope=0.1)
  152.         return c, weight
  153.     def forward(self, x):
  154.         if self.xsize is None:
  155.             self.xsize = (x.size(-3), x.size(-2), x.size(-1))
  156.             print(self.xsize)
  157.         else:
  158.             assert self.xsize[-3] == x.size(-3) and self.xsize[-2] == x.size(-2) and self.xsize[-1] == x.size(-1)
  159.         if self.w_norm:
  160.             self.normalize_weight()
  161.         c, weight = self.fista_forward(x)
  162.         # Compute loss
  163.         xp = F.conv_transpose2d(c, weight, bias=None, stride=self.stride, padding=self.padding,
  164.                                 output_padding=self.conv_transpose_output_padding)
  165.         r = x.repeat(1, self.n_dict, 1, 1) - xp
  166.         r_loss = torch.sum(torch.pow(r, 2)) / self.n_dict
  167.         c_loss = self.lmbd * torch.sum(torch.abs(c)) + self.mu / 2. * torch.sum(torch.pow(c, 2))
  168.         if self.zsize is None:
  169.             self.zsize = (c.size(-3), c.size(-2), c.size(-1))
  170.             print(self.zsize)
  171.         else:
  172.             assert self.zsize[-3] == c.size(-3) and self.zsize[-2] == c.size(-2) and self.zsize[-1] == c.size(-1)
  173.         if self.lmbd_ is None and config.MODEL.ADAPTIVELAMBDA:
  174.             self.lmbd_ = self.lmbd * self.xsize[-3] * self.xsize[-2] * self.xsize[-1] / (
  175.                         self.zsize[-3] * self.zsize[-2] * self.zsize[-1])
  176.             self.lmbd = self.lmbd_
  177.             print("======")
  178.             print("xsize", self.xsize)
  179.             print("zsize", self.zsize)
  180.             print("new lmbd: ", self.lmbd)
  181.         return c, (r_loss, c_loss)
  182.     def update_stepsize(self):
  183.         step_size = 0.9 / self.power_iteration(self.weight)
  184.         self.step_size = self.step_size * 0. + step_size
  185.         self.nonlinear.lambd = self.lmbd * step_size
  186.         self.nonlinear.scaling_mu = 1.0 / (1.0 + self.mu * step_size)
  187.     def normalize_weight(self):
  188.         with torch.no_grad():
  189.             w = self.weight.view(self.weight.size(0), -1)
  190.             normw = w.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12).expand_as(w)
  191.             w = (w / normw).view(self.weight.size())
  192.             self.weight.data = w.data
  193.     def power_iteration(self, weight):
  194.         max_iteration = 50
  195.         v_max_error = 1.0e5
  196.         tol = 1.0e-5
  197.         k = 0
  198.         with torch.no_grad():
  199.             if self.v_max is None:
  200.                 c = weight.shape[0]
  201.                 v = torch.randn(size=(1, c, self.zsize[-2], self.zsize[-1])).to(weight.device)
  202.             else:
  203.                 v = self.v_max.clone()
  204.             while k < max_iteration and v_max_error > tol:
  205.                 tmp = F.conv_transpose2d(
  206.                     v, weight, bias=None, stride=self.stride, padding=self.padding,
  207.                     output_padding=self.conv_transpose_output_padding
  208.                 )
  209.                 v_ = F.conv2d(tmp, weight, bias=None, stride=self.stride, padding=self.padding)
  210.                 v_ = F.normalize(v_.view(-1), dim=0, p=2).view(v.size())
  211.                 v_max_error = torch.sum((v_ - v) ** 2)
  212.                 k += 1
  213.                 v = v_
  214.             v_max = v.clone()
  215.             Dv_max = F.conv_transpose2d(
  216.                 v_max, weight, bias=None, stride=self.stride, padding=self.padding,
  217.                 output_padding=self.conv_transpose_output_padding
  218.             )  # Dv
  219.             lambda_max = torch.sum(Dv_max ** 2).item()  # vTDTDv / vTv, ignore the vTv since vTv = 1
  220.         self.v_max = v_max
  221.         return lambda_max
################################# SDNet ################################################################
# `cfg` aliases the project config object; DictConv2d below reads its MODEL.* hyper-parameters.
# NOTE(review): `config` from the same module is already imported at the top of this file.
from Lib.config import config as _cfg
cfg = _cfg
  225. class DictConv2d(nn.Module):
  226.     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
  227.         super(DictConv2d, self).__init__()
  228.         self.dn = DictBlock(
  229.             in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding,
  230.             mu=cfg['MODEL']['MU'], lmbd=cfg['MODEL']['LAMBDA'][0], square_noise=cfg['MODEL']['SQUARE_NOISE'],
  231.             n_dict=cfg['MODEL']['EXPANSION_FACTOR'], non_negative=cfg['MODEL']['NONEGATIVE'],
  232.             n_steps=cfg['MODEL']['NUM_LAYERS'], w_norm=cfg['MODEL']['WNORM']
  233.         )
  234.         self.rc = None
  235.         self.r_loss = []
  236.     def get_rc(self):
  237.         if self.rc is None:
  238.             raise ValueError("should call forward first.")
  239.         else:
  240.             return self.rc
  241.     def forward(self, x):
  242.         out, rc = self.dn(x)
  243.         self.rc = rc
  244.         if self.training is False:
  245.             self.r_loss.extend([self.rc[0].item() / len(x)] * len(x))
  246.         return out
  247. #########模型构建###############
  248. class SDNet_model(nn.Module):
  249.     def __init__(self, dropout1, dropout2, num_classes=2):
  250.         super(SDNet_model, self).__init__()
  251.         #  self.layer0 = nn.Sequential(
  252.         #      DictConv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
  253.         #      nn.BatchNorm2d(64),
  254.         #      nn.ReLU(inplace=True),
  255.         #  )
  256.         self.conv0 = nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1))
  257.         self.bn0 = nn.BatchNorm2d(64)
  258.         self.pool0 = nn.MaxPool2d(kernel_size=(2, 2))
  259.         self.conv1 = nn.Conv2d(64, 128, kernel_size=(3, 3), padding=(1, 1))
  260.         self.bn1 = nn.BatchNorm2d(128)
  261.         self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
  262.         self.dropout1 = nn.Dropout2d(p=dropout1)
  263.         self.layer0 = nn.Sequential(
  264.             DictConv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
  265.             nn.BatchNorm2d(256),
  266.             nn.LeakyReLU(inplace=True),
  267.         )
  268.         self.conv2 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1))
  269.         self.bn2 = nn.BatchNorm2d(256)
  270.         self.ca = ChannelAttention(256)
  271.         self.sa = SpatialAttention()
  272.         self.conv3 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1))
  273.         self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
  274.         self.dropout2 = nn.Dropout2d(p=dropout2)
  275.         self.flatten = nn.Flatten()
  276.         self.fc1 = nn.Linear(256 * 12 * 75, 512)
  277.         self.fc2 = nn.Linear(512, 256)
  278.         self.fc3 = nn.Linear(256, num_classes)
  279.         self.leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
  280.         self.sigmoid = nn.Sigmoid()
  281.     def update_stepsize(self):
  282.         for m in self.modules():
  283.             if isinstance(m, DictBlock):
  284.                 m.update_stepsize()
  285.     def get_rc(self):
  286.         rc_list = []
  287.         for m in self.modules():
  288.             if isinstance(m, DictConv2d):
  289.                 rc_list.append(m.get_rc())
  290.         return rc_list
  291.     def forward(self, x):
  292.         #  x = self.layer0(x)
  293.         x = self.conv0(x)
  294.         x = self.bn0(x)
  295.         x = self.pool0(x)
  296.         x = self.conv1(x)
  297.         x = self.bn1(x)
  298.         x = self.pool1(x)
  299.         x = self.dropout1(x)
  300.         x = self.layer0(x)
  301.         x = self.conv2(x)
  302.         x = self.bn2(x)
  303.         x = self.ca(x) * x
  304.         x = self.sa(x) * x
  305.         x = self.conv3(x)
  306.         x = self.pool2(x)
  307.         # print(x.shape)
  308.         x = self.dropout2(x)
  309.         x = self.flatten(x)
  310.         # print(x.shape)
  311.         x = self.leaky_relu(self.fc1(x))
  312.         x = self.fc2(x)
  313.         x = self.leaky_relu(x)
  314.         x = self.fc3(x)
  315.         x = self.sigmoid(x)
  316.         return x
  317. def SDCNN_model(num_classes, dropout1, dropout2):
  318.     model = SDNet_model(num_classes=num_classes, dropout1=dropout1, dropout2=dropout2)
  319.     return model
  320. randomSeed = 1
  321. random.seed(randomSeed)
  322. torch.manual_seed(randomSeed)
  323. np.random.seed(randomSeed)
  324. def main():
  325.     # 数据导入
  326.     dataFile = r'C:\Users\sun\Desktop\SDNET\SDNet-main\data\python_energy_T.mat'
  327.     data = scio.loadmat(dataFile)
  328.     train_input = data['train_input']
  329.     train_output = data['train_output']
  330.     test_input = data['test_input']
  331.     test_output = data['test_output']
  332.     validate_input = data['validate_input']
  333.     validate_output = data['validate_output']
  334.     train_input = train_input.reshape(-1, 1, 100, 300).astype('float32')
  335.     test_input = test_input.reshape(-1, 1, 100, 300).astype('float32')
  336.     validate_input = validate_input.reshape(-1, 1, 100, 300).astype('float32')
  337.     train_input = torch.from_numpy(train_input)
  338.     train_output = torch.from_numpy(train_output)
  339.     validate_input = torch.from_numpy(validate_input)
  340.     validate_output = torch.from_numpy(validate_output)
  341.     test_input = torch.from_numpy(test_input)
  342.     test_output = torch.from_numpy(test_output)
  343.     # 定义超参数搜索空间
  344.     epochs = range(50, 201)
  345.     batch_sizes = [64, 128, 256]
  346.     dropouts1 = [0.1, 0.3, 0.5]
  347.     dropouts2 = [0.1, 0.3, 0.5]
  348.     # 初始化最优超参数和最高准确度
  349.     best_hyperparams = {'epoch': None, 'batch_size': None, 'dropout1': None, 'dropout2': None}
  350.     best_accuracy = 0.0
  351.     # 定义随机搜索算法的迭代次数
  352.     num_iterations = 10
  353.     # 随机搜索算法
  354.     for i in range(num_iterations):
  355.         # 随机选择超参数组合
  356.         epoch = random.choice(epochs)
  357.         batch_size = random.choice(batch_sizes)
  358.         dropout1 = random.choice(dropouts1)
  359.         dropout2 = random.choice(dropouts2)
  360.         print(f"Iteration {i+1}/{num_iterations}: epoch={epoch}, batch_size={batch_size}, dropout1={dropout1}, dropout2={dropout2}")
  361.         # 实例化模型、损失函数和优化器
  362.         model = SDCNN_model(num_classes=2, dropout1=dropout1, dropout2=dropout2)
  363.         criterion = nn.BCELoss()
  364.         optimizer = optim.Adam(model.parameters(), lr=0.001)
  365.         # 将数据转换为PyTorch DataLoader
  366.         train_dataset = TensorDataset(train_input, torch.tensor(train_output).float())
  367.         valid_dataset = TensorDataset(validate_input, torch.tensor(validate_output).float())
  368.         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  369.         valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
  370.         # 实例化学习率调度器 #diff 添加学习率调度器
  371.         scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
  372.         # 训练模型
  373.         for e in range(epoch):
  374.             model.train()
  375.             for inputs, targets in train_loader:
  376.                 inputs, targets = inputs, targets
  377.                 optimizer.zero_grad()
  378.                 outputs = model(inputs)
  379.                 loss = criterion(outputs, targets)
  380.                 loss.backward()
  381.                 optimizer.step()
  382.             scheduler.step()
  383.         # 评估模型
  384.         model.eval()
  385.         correct = 0
  386.         total = 0
  387.         with torch.no_grad():
  388.             for inputs, targets in valid_loader:
  389.                 inputs, targets = inputs, targets
  390.                 outputs = model(inputs)
  391.                 predicted = torch.argmax(outputs, dim=1)
  392.                 total += targets.size(0)
  393.                 targets_index = torch.argmax(targets, dim=1)
  394.                 correct += (predicted == targets_index).sum().item()
  395.         accuracy = 100 * correct / total
  396.         print(f"Iteration {i+1}: Accuracy={accuracy:.2f}%")
  397.         # 更新最优超参数和最高准确度
  398.         if accuracy > best_accuracy:
  399.             best_hyperparams['epoch'] = epoch
  400.             best_hyperparams['batch_size'] = batch_size
  401.             best_hyperparams['dropout1'] = dropout1
  402.             best_hyperparams['dropout2'] = dropout2
  403.             best_accuracy = accuracy
  404.     print(f"New best accuracy: {best_accuracy:.2f}% with hyperparameters {best_hyperparams}")
  405.     # 使用找到的最佳超参数进行最终训练
  406.     best_epoch = best_hyperparams['epoch']
  407.     best_batch_size = best_hyperparams['batch_size']
  408.     best_dropout1 = best_hyperparams['dropout1']
  409.     best_dropout2 = best_hyperparams['dropout2']
  410.     def weights_init(m):
  411.         if isinstance(m, (nn.Conv2d, nn.Linear)):
  412.             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
  413.             if m.bias is not None:
  414.                 nn.init.constant_(m.bias, 0)
  415.     # 重新实例化模型以确保权重是新的
  416.     model = SDCNN_model(num_classes=2, dropout1=best_dropout1, dropout2=best_dropout2)
  417.     model.apply(weights_init)
  418.     optimizer = optim.Adam(model.parameters(), lr=0.001)
  419.     # 使用最佳批量大小创建数据加载器
  420.     train_loader = DataLoader(train_dataset, batch_size=best_batch_size, shuffle=True)
  421.     valid_loader = DataLoader(valid_dataset, batch_size=best_batch_size, shuffle=False)
  422.     # 实例化学习率调度器 #diff 添加学习率调度器
  423.     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
  424.     # 特征可视化准备
  425.     feature_maps = {}
  426.     def get_activation(name):
  427.         def hook(model, input, output):
  428.             feature_maps[name] = output.detach()
  429.         return hook
  430.     # 注册钩子 #diff 注册前向钩子以提取特征图
  431.     for name, layer in model.named_modules():
  432.         if isinstance(layer, nn.Conv2d) or isinstance(layer, DictConv2d):
  433.             layer.register_forward_hook(get_activation(name))
  434.     # 训练模型
  435.     for e in range(best_epoch):
  436.         model.train()
  437.         running_loss = 0.0
  438.         for inputs, targets in train_loader:
  439.             inputs, targets = inputs, targets
  440.             optimizer.zero_grad()
  441.             outputs = model(inputs)
  442.             loss = criterion(outputs.squeeze(), targets.squeeze())
  443.             loss.backward()
  444.             optimizer.step()
  445.             running_loss += loss.item()  # 累加损失以计算平均损失
  446.         scheduler.step()
  447.         print(f'Epoch {e + 1}/{best_epoch}, Loss: {running_loss / len(train_loader):.4f}')
  448.         # 评估模型
  449.         model.eval()  # 设置模型为评估模式
  450.         validation_loss = 0.0
  451.         with torch.no_grad():
  452.             for inputs, targets in valid_loader:
  453.                 inputs, targets = inputs, targets
  454.                 outputs = model(inputs)
  455.                 validation_loss += criterion(outputs.squeeze(), targets.squeeze()).item()
  456.         print(f'Validation Loss: {validation_loss / len(valid_loader):.4f}')
  457.     model.eval()
  458.     with torch.no_grad():
  459.         sample_inputs = validate_input[:1]
  460.         model(sample_inputs)
  461.     def visualize_features(feature_maps, layer_names, num_images=5):
  462.         for layer_name in layer_names:
  463.             act = feature_maps.get(layer_name)
  464.             if act is None:
  465.                 continue
  466.             act = act.cpu().numpy()
  467.             num_channels = act.shape[1]
  468.             plt.figure(figsize=(20, 10))
  469.             for i in range(min(num_channels, 64)):
  470.                 plt.subplot(8, 8, i + 1)
  471.                 plt.imshow(act[0, i, :, :], cmap='viridis')
  472.                 plt.axis('off')
  473.             plt.suptitle(f'Feature Maps of {layer_name}')
  474.             plt.savefig(f'feature_maps_{layer_name}.png')
  475.             plt.close()
  476.     layers_to_visualize = ['conv0', 'conv1', 'DictConv2d', 'conv2', 'conv3']
  477.     visualize_features(feature_maps, layers_to_visualize)
  478.     model.eval()
  479.     with torch.no_grad():
  480.         predictions = model(test_input.float())
  481.         probabilities = predictions
  482.         predicted_labels = torch.argmax(probabilities, dim=1)
  483.         predict = predicted_labels.cpu().numpy()
  484.         print(predict)
  485.     with open(r'C:\Users\sun\Desktop\SDNET\SDNet-main\predict_label.csv', 'w', newline='') as pr_file:
  486.         writer = csv.writer(pr_file)
  487.         for label in predict:
  488.             writer.writerow([label])
  489.     with open(r'C:\Users\sun\Desktop\SDNET\SDNet-main\pr.csv', 'w+') as pr_file:
  490.         out = [f"{i[0]},{i[1]}" for i in probabilities]
  491.         pr_file.write("\n".join(out))
  492.     # 调用函数保存预测结果
  493.     # save_predictions_to_csv(probabilities.cpu().numpy(), 'pr.csv')
  494.     def save_model_complete(model, filename=r'C:\Users\sun\Desktop\SDNET\SDNet-main\sdnet_model.pth'):
  495.         torch.save(model.state_dict(), filename)
  496.         print(f"Complete model saved as {filename}")
  497.     save_model_complete(model)
  498. if __name__ == '__main__':
  499.     main()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!qidao123.com:ToB企服之家,中国第一个企服评测及软件市场,开放入驻,技术点评得现金
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表