PyTorch使用教程(4)-torch.nn

打印 上一主题 下一主题

主题 841|帖子 841|积分 2523

torch.nn 是 PyTorch 深度学习框架中的一个焦点模块,专门用于构建和训练神经网络。它提供了一系列用于构建神经网络所需的组件,包括层(Layers)、激活函数(Activation Functions)、损失函数(Loss Functions)等。orch.nn模块的焦点是nn.Module类,它是所有神经网络组件的基类。通过继承nn.Module类并实现其forward方法,我们可以定义自己的神经网络模子。torch.nn的常用组件预览如下图:

一、基本功能

torch.nn 模块的主要功能是提供神经网络构建所需的各种类和函数。这些类和函数使得开发者能够轻松地定义、初始化和训练神经网络模子。无论是简朴的全连接网络照旧复杂的卷积神经网络(CNN)和循环神经网络(RNN),torch.nn 都能提供必要的组件和工具。
二、关键组件

2.1 层(Layers)

1.卷积层
在计算机视觉方面,卷积层最常用的就是torch.nn.Conv2d,函数原型如下:
  1. torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
复制代码
对由多个输入平面组成的输入张量应用二维卷积,在最简朴的情况下,具有输入大小                                   (                         N                         ,                                   C                                       i                               n                                            ,                         H                         ,                         W                         )                              (N,C_{in},H,W)                  (N,Cin​,H,W)和输出                                   (                         N                         ,                                   C                                       o                               u                               t                                            ,                                   H                                       o                               u                               t                                            ,                                   W                                       o                               u                               t                                            )                              (N,C_{out},H_{out},W_{out})                  (N,Cout​,Hout​,Wout​)的层的输出值可以正确地描述为:
                                         out                            ⁡                            (                                       N                               i                                      ,                                       C                                                        out                                     ⁡                                              j                                                 )                            =                            bias                            ⁡                            (                                       C                                                        out                                     ⁡                                              j                                                 )                            +                                       ∑                                           k                                  =                                  0                                                      C                                               in                                     ⁡                                     −                                     1                                                             weight                            ⁡                            (                                       C                                                        out                                     ⁡                                              j                                                 ,                            k                            )                            ⋆                            input                            ⁡                            (                                       N                               i                                      ,                            k                            )                                  \operatorname{out}(N_i,C_{\operatorname{out}_j})=\operatorname{bias}(C_{\operatorname{out}_j})+\sum_{k=0}^{C_{\operatorname{in}-1}}\operatorname{weight}(C_{\operatorname{out}_j},k)\star\operatorname{input}(N_i,k)                     out(Ni​,Coutj​​)=bias(Coutj​​)+k=0∑Cin−1​​weight(Coutj​​,k)⋆input(Ni​,k)
                                    b                         i                         a                         s                              bias                  bias是偏置量、                                   w                         e                         i                         g                         h                         t                              weight                  weight就是常说的卷积核权重参数。

常用参数


  • in_channels (int) – 输入图像中的通道数
  • out_channels (int) – 卷积产生的通道数
  • kernel_size (int 或 tuple) – 卷积核的大小
  • stride (int 或 tuple, 可选) – 卷积的步幅。默认值:1
  • padding (int, tuple 或 str, 可选) – 添加到输入所有四边的添补。它可以是字符串 {‘valid’, ‘same’} 或整数/整数元组,表示应用于两侧的隐式添补量。默认值:0
  • dilation (int 或 tuple, 可选) – 内核元素之间的间距。默认值:1
  • groups (int, 可选) – 从输入通道到输出通道的阻塞连接数。默认值:1
  • bias (bool, 可选) – 如果为 True,则将可学习偏差添加到输出。默认值:True
  • padding_mode (str, 可选) – ‘zeros’,‘reflect’,‘replicate’ 或 ‘circular’。默认值:‘zeros’
示例
  1. #创建一个随机张量
  2. >>> input = torch.rand(1,1,3,3)
  3. >>> input
  4. tensor([[[[0.7609, 0.8803, 0.2294],
  5.           [0.8200, 0.8600, 0.5657],
  6.           [0.2421, 0.0077, 0.9762]]]])
  7. #使用默认的padding模式,创建一个输入为1个通道,输出为3个通道,卷积核尺寸为3的卷积层
  8. >>> conv=torch.nn.Conv2d(1, 3, 3)
  9. >>> conv(input)
  10. tensor([[[[-0.0243]],
  11.          [[-0.1504]],
  12.          [[-1.2081]]]], grad_fn=<ConvolutionBackward0>)
  13. #使用padding为‘same’的模式
  14. >>> conv=torch.nn.Conv2d(1, 3, 3,padding='same')  
  15. >>> conv(input)                                 
  16. tensor([[[[-0.0795, -0.2914, -0.4644],
  17.           [-0.0520, -0.1241, -0.2314],
  18.           [-0.1710, -0.1893,  0.2325]],
  19.          [[-0.6988, -0.4558, -0.1104],
  20.           [-0.8113, -0.6543, -0.4082],
  21.           [-0.3180, -0.4468, -0.2224]],
  22.          [[ 0.2227,  0.2641,  0.3581],
  23.           [ 0.5351,  0.2648,  0.5343],
  24.           [ 0.3264,  0.7865,  0.3547]]]], grad_fn=<ConvolutionBackward0>)
复制代码
打印出卷积层中的                                   b                         i                         a                         s                              bias                  bias、                                   w                         e                         i                         g                         h                         t                              weight                  weight的参数值。
  1. >>> conv.bias
  2. Parameter containing:
  3. tensor([-0.2056,  0.0071,  0.0876], requires_grad=True)
  4. >>> conv.weight
  5. Parameter containing:
  6. tensor([[[[ 0.0767,  0.2002, -0.2259],
  7.           [-0.3162,  0.2678, -0.0168],
  8.           [-0.0362, -0.0189, -0.0551]]],
  9.          
  10.         [[[ 0.1302, -0.0640, -0.2265],
  11.           [-0.2353, -0.3108, -0.3259],
  12.           [ 0.2809, -0.1426, -0.0763]]],
  13.         [[[ 0.2515,  0.2941,  0.0281],
  14.           [-0.1044, -0.1175,  0.2562],
  15.           [ 0.2423,  0.3199, -0.3061]]]], requires_grad=True)
复制代码
2. 池化层
池化层在卷积神经网络中起到了特征降维、特征提取、平移稳定性、扩大感受野、减少计算量以及防止过拟合等多重作用。这些特性使得卷积神经网络在处理图像、视频等复杂数据时具有出色的性能。常见的池化操纵包括最大池化(Max Pooling)和均匀池化(Average Pooling)。这里以最大池化层为例具体说明。
最大池化层函数原型如下:
  1. torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
复制代码
常用参数


  • kernel_size – 需要进行最大值运算的窗口大小
  • stride – 窗口的步长。默认值为 kernel_size

示例
  1. >>> input = torch.randn(1,1,4,4)        
  2. >>> input  
  3. tensor([[[[-0.3371,  0.1317,  0.9040,  1.3387],
  4.           [-2.2840,  2.3960,  0.2795,  0.5269],
  5.           [ 0.5753, -0.7660,  0.3067, -0.4039],
  6.           [-0.8702, -0.2735,  0.7680, -2.3174]]]])
  7. >>> maxpool=torch.nn.MaxPool2d(2)
  8. >>> maxpool(input)
  9. tensor([[[[2.3960, 1.3387],
  10.           [0.5753, 0.7680]]]])
复制代码
3.线性层
线性层也称为全连接层(Fully Connected Layer),它对输入数据执行                                   y                         =                         x                                   A                            T                                  +                         b                              y=xA^T+b                  y=xAT+b。函数原型如下:
  1. torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
复制代码
参数


  • in_features (int) – 每个输入样本的大小
  • out_features (int) – 每个输出样本的大小
  • bias (bool) – 如果设置为 False,则该层不会学习加性偏差。默认值:True

示例
  1. >>> m = nn.Linear(20, 30)
  2. >>> input = torch.randn(128, 20)
  3. >>> output = m(input)
  4. >>> print(output.size())
  5. torch.Size([128, 30])
复制代码
2.2 激活函数(Activation Functions)

1. ReLU
函数原型如下:
  1. torch.nn.ReLU(inplace=False)
复制代码
ReLU激活函数的数学定义如下:
                                                    R                               e                               L                               U                                      (                            x                            )                            =                            (                            x                                       )                               +                                      =                            max                            ⁡                            (                            0                            ,                            x                            )                                  \mathrm{ReLU}(x)=(x)^+=\max(0,x)                     ReLU(x)=(x)+=max(0,x)
ReLU函数的特点:


  • 非线性:ReLU函数通过引入非线性因素,使得神经网络能够学习更复杂的模式。尽管其定义简朴,但ReLU实际上是非线性的,这增强了模子的表达能力。
  • 计算简朴高效:ReLU函数的计算非常高效,因为它仅涉及简朴的阈值判断,无需复杂的数学运算。这使得ReLU函数在处理大规模数据时具有显著上风。
  • 缓解梯度消失问题:在正区间内,ReLU函数的梯度为常数1,这有助于梯度的有效通报,从而缓解了梯度消失问题。这一特点使得ReLU函数在训练深层网络时体现更为出色。

示例
  1. >>> input = torch.randn(1,3)
  2. >>> input                    
  3. tensor([[-1.1377,  0.1660, -0.8894]])
  4. >>> relu = torch.nn.ReLU()        
  5. >>> relu(input)            
  6. tensor([[0.0000, 0.1660, 0.0000]])
复制代码
2.LeakyReLU
函数原型如下:
  1. torch.nn.LeakyReLU(negative_slope=0.01, inplace=False)
复制代码
LeakyReLU激活函数的数学定义如下:
                                                    L                               e                               a                               k                               y                               R                               e                               L                               U                                      (                            x                            )                            =                                       {                                                                                                     x                                              ,                                                                                                                                             i                                                 f                                                              x                                              ≥                                              0                                                                                                                                                  negative slope                                              ×                                              x                                              ,                                                                                                                            o                                              t                                              h                                              e                                              r                                              w                                              i                                              s                                              e                                                                                                                                                                             \mathrm{LeakyReLU}(x)=\begin{cases}x,&\mathrm{if}x\geq0\\\text{negative slope}\times x,&\mathrm{otherwise}&\end{cases}                     LeakyReLU(x)={x,negative slope×x,​ifx≥0otherwise​​
LeakyReLU激活函数的特点如下:


  • 在ReLU函数中,当输入值小于0时,输出值为0,这可能导致神经元在训练过程中“死亡”或停止学习。LeakyReLU函数通过允许负输入值有一个很小的斜率(通常为0.01),使得即使输入为负,输出也不会完全为零。这有助于避免神经元死亡的问题,使网络能够学习更多的特征1。
  • LeakyReLU函数在负值地区引入一个小的正斜率,这有助于梯度在网络中的传播,从而可以加快网络的收敛速度1。
  • LeakyReLU函数包含一个超参数α(alpha),它决定了负输入时的梯度大小。在实际应用中,α通常被设置为一个小于1的正数,如0.01。当α小于1时,LeakyReLU函数可以防止梯度消失;当α大于1时,它可以使梯度更快收敛;当α即是1时,LeakyReLU函数等价于ReLU激活函数1。

示例
  1. >>> input = torch.randn(4)
  2. >>> input
  3. tensor([-0.3135, -0.3198,  0.2046,  2.0089])
  4. >>> m = torch.nn.LeakyReLU(0.1)
  5. >>> m(input)
  6. tensor([-0.0313, -0.0320,  0.2046,  2.0089])
复制代码
3.Sigmoid
函数原型如下:
  1. torch.nn.Sigmoid(*args, **kwargs)
复制代码
Sigmoid激活函数的数学定义如下:
                                                    S                               i                               g                               m                               o                               i                               d                                      (                            x                            )                            =                            σ                            (                            x                            )                            =                                       1                                           1                                  +                                  exp                                  ⁡                                  (                                  −                                  x                                  )                                                       \mathrm{Sigmoid}(x)=\sigma(x)=\frac{1}{1+\exp(-x)}                     Sigmoid(x)=σ(x)=1+exp(−x)1​
Sigmoid激活函数的特点如下:


  • Sigmoid函数能够将任何实数输入映射到0和1之间的输出值。这一特性使得Sigmoid函数特殊适适用于二分类问题,如“是”与“否”的判断,为分类任务提供了坚实的基础。
  • Sigmoid函数是平滑且一连的,这使得它在数学上易于处理,求导也相对简朴。这一优点在神经网络的训练过程中尤为重要,因为激活函数的导数会用于权重的更新。
  • Sigmoid函数具有非线性特性,能够捕捉数据间复杂的非线性关系。这为神经网络提供了强大的建模能力,使其能够处理更加复杂的问题。

示例
  1. >>> m = torch.nn.Sigmoid()
  2. >>> input = torch.randn(2)
  3. >>> input
  4. tensor([ 0.2905, -0.1586])
  5. >>> m(input)
  6. tensor([0.5721, 0.4604])
复制代码
4.SiLU
SiLU(Sigmoid-Weighted Linear Unit)激活函数,也被称为Swish,是一种结合了线性和非线性特性的现代激活函数,由Google的研究职员在2017年提出。
函数原型定义如下:
  1. torch.nn.SiLU(inplace=False)
复制代码
SiLU激活函数的数学定义如下:
                                                    s                               i                               l                               u                                      (                            x                            )                            =                            x                            ∗                            σ                            (                            x                            )                                  \mathrm{silu}(x)=x*\sigma(x)                     silu(x)=x∗σ(x)
SiLU激活函数的特点:


  • SiLU函数是一连且光滑的,其梯度也一连变革,避免了ReLU等激活函数中的不一连点。这种平滑性有助于优化算法更快地收敛,而且在整个实数域上都是可导的,使得在反向传播算法中梯度更加平滑和一连。
  • SiLU激活函数是非线性的,能够帮助神经网络学习复杂的非线性模式和特征。这种非线性特性是神经网络能够处理复杂任务的关键。
  • 随着输入值的增大,SiLU激活函数的输出值会趋向于线性变革,这有助于防止梯度消失或梯度爆炸问题。同时,通过x⋅σ(x)的情势,SiLU会动态调解激活值的大小,体现出平滑且动态的特性。

2.3 损失函数(Loss Functions)

损失函数(Loss Function)用于权衡模子预测效果与真实标签之间的差距,反映了模子的预测误差大小。在神经网络的训练和推理过程中,损失函数是一个关键概念,训练的目的就是最小化损失函数,以提高模子的正确度。损失函数的主要作用包括:


  • 评估模子的好坏‌:每一次前向传播后,将模子的输出与真实标签进行比力,计算损失值。模子再通过反向传播基于该损失调解参数,使其逐渐优化‌。
  • 指导模子参数的更新‌:损失函数告诉我们模子预测有多大误差,并指导模子参数的更新方向和幅度‌。
    这里,介绍几种常用的损失函数。
1. MSELoss均方误差损失
均方误差MSE作为损失函数常用于回归问题,特殊是在需要优化图像像素级正确性的任务中。在PyTorch中,函数原型如下:
  1. torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
复制代码
均方误差MSE计算输入 x 和目的 y 中每个元素之间的均方误差(平方 L2 范数),数学表达如下:
                                         ℓ                            (                            x                            ,                            y                            )                            =                            L                            =                            {                                       l                               1                                      ,                            …                            ,                                       l                               N                                                 }                               ⊤                                      ,                                                l                               n                                      =                                                   (                                               x                                     n                                              −                                               y                                     n                                              )                                          2                                            \ell(x,y)=L=\{l_1,\ldots,l_N\}^\top,\quad l_n=\left(x_n-y_n\right)^2                     ℓ(x,y)=L={l1​,…,lN​}⊤,ln​=(xn​−yn​)2
此中                                   N                              N                  N 是批次大小。如果 reduction 不是 ‘none’(默认值为 ‘mean’),则
                                         ℓ                            (                            x                            ,                            y                            )                            =                                       {                                                                                                                      m                                                 e                                                 a                                                 n                                                              (                                              L                                              )                                              ,                                                                                                                            if reduction                                              =                                                                                                  ′                                                                  m                                                 e                                                 a                                                                   n                                                    ′                                                                               ;                                                                                                                                                                   s                                                 u                                                 m                                                              (                                              L                                              )                                              ,                                                                                                                            if reduction                                              =                                                                                                  ′                                                                  s                                                 u                                                                   m                                                    ′                                                                               .                                                                                                                                                                             \ell(x,y)=\begin{cases}\mathrm{mean}(L),&\text{if reduction}=\mathrm{'mean'};\\\mathrm{sum}(L),&\text{if reduction}=\mathrm{'sum'}.&\end{cases}                     ℓ(x,y)={mean(L),sum(L),​if reduction=′mean′;if reduction=′sum′.​​
2. CrossEntropyLoss交叉熵损失
交叉熵损失函数(Cross-Entropy Loss)是深度学习中分类问题常用的损失函数,特殊适用于多分类问题。在PyTorch中,函数原型如下:
  1. torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
复制代码
交叉熵损失函数是深度学习中分类问题常用的损失函数,特殊适用于多分类问题。它通过度量预测分布与真实分布之间的差别,来权衡模子输出的正确性。数学定义如下:
                                         C                            r                            o                            s                            s                            E                            n                            t                            r                            o                            y                            L                            o                            s                            s                            =                            −                                       ∑                                           i                                  =                                  1                                          N                                                 y                               i                                      ⋅                            l                            o                            g                            (                                                   y                                  ^                                          i                                      )                                  CrossEntroyLoss=-\sum_{i=1}^Ny_i\cdot log(\hat{y}_i)                     CrossEntroyLoss=−i=1∑N​yi​⋅log(y^​i​)


  • N:类别数
  •                                                    y                               i                                            y_i                     yi​:真实的标签(用 one-hot 编码表示,只有目的类别对应的位置为 1,其他位置为 0)。
  •                                                                y                                  ^                                          i                                            \hat{y}_i                     y^​i​ :模子的预测概率,即 softmax 的输出值。
3. L1Loss均匀绝对误差
MAE计算输入                                   x                              x                  x 和目的                                   y                              y                  y 中每个元素之间的均匀绝对误差 (MAE)。在PyTorch中,函数原型如下:
  1. torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')
复制代码
L1 Loss(也称为绝对误差损失或曼哈顿损失)是一种常用的损失函数,广泛应用于回归任务中,其主要目的是权衡模子预测值与真实值之间的差别。数学定义如下:
                                         ℓ                            (                            x                            ,                            y                            )                            =                            L                            =                            {                                       l                               1                                      ,                            …                            ,                                       l                               N                                                 }                               ⊤                                      ,                                                l                               n                                      =                            ∣                                       x                               n                                      −                                       y                               n                                      ∣                            ,                                  \ell(x,y)=L=\{l_1,\ldots,l_N\}^\top,\quad l_n=|x_n-y_n|,                     ℓ(x,y)=L={l1​,…,lN​}⊤,ln​=∣xn​−yn​∣,
2.4 容器模块(Container Modules)

在PyTorch中,容器模块(Container Modules)用于组织和管理神经网络中的各个层。这些容器模块使得模子的构建和管理变得更加灵活和方便。以下是一些主要的PyTorch容器模块:
1. nn.Sequentia
这是一个非常常用的容器模块,它按照在构造函数中添加它们的次序来组织多个子模块(通常是网络层)。使用nn.Sequential,每个添加的模块或层的输出自动成为下一个模块的输入,这简化了模子的构建过程,使代码更加清晰和易于理解。它适用于大多数前馈神经网络(feed-forward neural networks),如简朴的卷积神经网络、全连接网络等。但对于需要复杂数据流的模子,如具有跳跃连接或多输入/多输出的网络,可能不太适合。
示例
  1. model = nn.Sequential(
  2.           nn.Conv2d(1,20,5),
  3.           nn.ReLU(),
  4.           nn.Conv2d(20,64,5),
  5.           nn.ReLU()
  6.         )
复制代码
2. nn.ModuleLis
这是一个简朴的列表容器,可以包含多个子模块。与Python的常规列表不同,nn.ModuleList中的模块会被正确注册,从而确保它们能够参与到模子的参数更新和生存/加载过程中。
示例
  1. class MyModule(nn.Module):
  2.     def __init__(self) -> None:
  3.         super().__init__()
  4.         #使用列表容器构造模型
  5.         self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
  6.     def forward(self, x):
  7.         # ModuleList can act as an iterable, or be indexed using ints
  8.         for i, l in enumerate(self.linears):
  9.             x = self.linears[i // 2](x) + l(x)
  10.         return x
复制代码
3.nn.ModuleDic
类似于nn.ModuleList,但它是基于字典的容器,可以通过键值对的情势存储子模块。这对于需要根据名称动态访问模块的情况非常有用。
示例
  1. class MyModule(nn.Module):
  2.     def __init__(self) -> None:
  3.         super().__init__()
  4.         self.choices = nn.ModuleDict({
  5.                 'conv': nn.Conv2d(10, 10, 3),
  6.                 'pool': nn.MaxPool2d(3)
  7.         })
  8.         self.activations = nn.ModuleDict([
  9.                 ['lrelu', nn.LeakyReLU()],
  10.                 ['prelu', nn.PReLU()]
  11.         ])
  12.         #在推理时,可以根据choice传入名字选择网络层进行推理
  13.     def forward(self, x, choice, act):
  14.         x = self.choices[choice](x)
  15.         x = self.activations[act](x)
  16.         return x
复制代码
三、 nn.Module

nn.Module 类是 PyTorch 中构建所有神经网络的基类。它提供了模子构建所需的基本功能,包括参数的注册、前向传播的定义、以及模子生存和加载等。
以下是 nn.Module 类的一些关键特性和方法:
1. 参数注册
在 nn.Module 子类中,你可以通过 self.register_parameter 方法注册模子的参数(通常是权重和偏置)。PyTorch 会自动跟踪这些参数,并在训练过程中进行更新。
  1. import torch
  2. import torch.nn as nn
  3. class MyModule(nn.Module):
  4.     def __init__(self):
  5.         super(MyModule, self).__init__()
  6.         # 创建一个参数,并注册它
  7.         self.weight = nn.Parameter(torch.randn(3, 3))
  8.     def forward(self, x):
  9.         return x @ self.weight
  10. # 实例化模块
  11. model = MyModule()
  12. print(list(model.parameters()))  # 输出包含已注册参数的列表
复制代码
2.前向传播
你需要重写 forward 方法来定义模子的前向传播逻辑。
在调用模子时,PyTorch 会自动调用 forward 方法,并将输入数据通报给它。
  1. class SimpleNN(nn.Module):
  2.     def __init__(self):
  3.         super(SimpleNN, self).__init__()
  4.         self.fc = nn.Linear(10, 1)  # 添加一个全连接层
  5.     def forward(self, x):
  6.         return self.fc(x)
  7. # 实例化模型
  8. model = SimpleNN()
  9. input_data = torch.randn(1, 10)  # 创建一个输入张量
  10. output = model(input_data)  # 调用模型,自动执行 forward 方法
  11. print(output)
复制代码
3.子模块
nn.Module 支持将其他 nn.Module 实例作为子模块。子模块可以通过 self.add_module 方法添加,但更常见的是直接在构造函数中通过赋值给 self 的属性来创建和添加子模块。
  1. class ComplexModule(nn.Module):
  2.     def __init__(self):
  3.         super(ComplexModule, self).__init__()
  4.         self.submodule1 = nn.Linear(10, 20)
  5.         self.submodule2 = nn.ReLU()
  6.     def forward(self, x):
  7.         x = self.submodule1(x)
  8.         x = self.submodule2(x)
  9.         return x
  10. # 实例化模型
  11. model = ComplexModule()
  12. input_data = torch.randn(1, 10)
  13. output = model(input_data)
  14. print(output)
复制代码
4. 模子生存和读取
nn.Module 提供了 save 和 load_state_dict 方法,用于生存和加载模子的参数。通常,会生存模子的 state_dict(一个包含所有参数和缓冲区的字典),而不是直接生存模子对象。
模子生存
  1. # 假设你有一个已经训练好的模型
  2. torch.save(model.state_dict(), 'model.pth')
复制代码
模子读取
  1. # 实例化一个新的模型对象
  2. new_model = SimpleNN()
  3. # 加载参数
  4. new_model.load_state_dict(torch.load('model.pth'))
复制代码
5.设备迁徙
你可以使用 to 方法将模子移动到不同的设备(如 CPU、GPU)上。这对于在多设备情况下进行模子训练和推理非常有用。
  1. # 将模型移动到 GPU 上(如果可用)
  2. if torch.cuda.is_available():
  3.     model = model.to('cuda')
复制代码
四、小结

torch.nn模块是PyTorch深度学习框架中不可或缺的一部分,它为开发者提供了构建和训练神经网络所需的所有工具和功能。通过深入相识torch.nn模块的使用方法和技巧,开发者可以更好地应用深度学习技术来解决实际问题。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

傲渊山岳

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表