马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
一、引言
在机器学习的练习流程中,模型构建是核心环节之一。从传统机器学习的线性模型到深度学习的神经网络,模型的复杂度呈指数级增长。PyTorch 作为主流深度学习框架,通过nn.Module类提供了同一的模型构建接口,使得复杂网络结构的界说与管理变得高效且规范。
二、三要素
2.1 网络层构建
深度学习模型的底子是各类网络层,常见范例包罗:
- 卷积层:nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0),用于提取空间特征(如 LeNet 中的 Conv1/Conv2)
- 池化层:nn.MaxPool2d(kernel_size, stride=None),实现特征降维(如 LeNet 中的 Pool1/Pool2)
- 全连接层:nn.Linear(in_features, out_features),完成特征到输出的映射(如 LeNet 中的 Fc1/Fc2/Fc3)
- 激活函数层:nn.ReLU()、nn.Sigmoid()等,引入非线性(现实代码中常使用nn.functional中的函数以减少参数)
2.2 网络拼接
以 LeNet 为例,其网络结构可拆解为:
- 输入`32x32x3`
- → Conv1(6核,5x5,步长1)
- → 输出`28x28x6`
- → MaxPool1(2x2,步长2)
- → 输出`14x14x6`
- → Conv2(16核,5x5,步长1)
- → 输出`10x10x16`
- → MaxPool2(2x2,步长2)
- → 输出`5x5x16`
- → 展平为400维
- → Fc1(400→120)
- → Fc2(120→84)
- → Fc3(84→10)
- → Softmax输出
复制代码 2.3 权值初始化
公道的初始化可制止梯度消失/爆炸,常见方法:
- Xavier初始化(实用于sigmoid/tanh):
nn.init.xavier_uniform_(weight, gain=1.0),保证输入输出方差划一
- Kaiming初始化(实用于ReLU系列):
nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='relu'),思量激活函数的非线性特性
- 匀称/正态分布:
nn.init.uniform_(weight, -a, a)、nn.init.normal_(weight, mean=0, std=0.01)
三、nn.Module
3.1 两大要素
每个自界说模型需继承nn.Module,并实现两大核心方法:
- __init__():界说子模块(如self.conv1 = nn.Conv2d(...)),初始化可学习参数
- forward():界说前向传播逻辑,调用子模块并组合运算(克制直接修改输入张量的内存,需返回新张量)
代码示例:LeNet 模型界说
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- class LeNet(nn.Module):
- def __init__(self):
- super(LeNet, self).__init__()
- # 卷积层与池化层
- self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,5x5卷积核
- self.conv2 = nn.Conv2d(6, 16, 5)
- self.pool = nn.MaxPool2d(2, 2) # 2x2池化,步长2
- # 全连接层
- self.fc1 = nn.Linear(16*5*5, 120) # 5x5是池化后尺寸(10/2=5)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 10)
-
- def forward(self, x): # x形状:(batch_size, 3, 32, 32)
- x = self.pool(F.relu(self.conv1(x))) # (6, 28, 28) → (6, 14, 14)
- x = self.pool(F.relu(self.conv2(x))) # (16, 10, 10) → (16, 5, 5)
- x = x.view(-1, 16*5*5) # 展平为批量维度+特征维度
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = self.fc3(x) # 输出logits,Softmax在损失函数中处理
- return x
复制代码 3.2 四大属性
nn.Module通过8个OrderedDict管理内部状态,核心属性包罗:
- parameters():迭代所有可学习参数(如weight、bias),用于优化器更新
- model = LeNet()
- for name, param in model.named_parameters():
- print(f"参数名:{name},形状:{param.shape}")
复制代码 - modules():递归遍历所有子模块(包罗自身),用于模型结构查抄
- for name, module in model.named_modules():
- print(f"模块名:{name},类型:{type(module)}")
复制代码 - buffers():存储非可学习状态(如 BN 层的running_mean、running_var)
- for name, buf in model.named_buffers():
- print(f"缓冲区名:{name},形状:{buf.shape}")
复制代码 - hooks:注册前向/反向钩子函数,用于获取中心层输出(调试或特征可视化)
- def hook_fn(module, input, output):
- print(f"模块{type(module).__name__}的输出形状:{output.shape}")
- handle = model.conv1.register_forward_hook(hook_fn) # 注册钩子
- handle.remove() # 使用后移除避免内存泄漏
复制代码 四、进阶本领
4.1 层次化计划:子模块复用
复杂模型(如 ResNet)通过界说子模块(如残差块)提拔代码复用性:
- class ResidualBlock(nn.Module):
- def __init__(self, in_channels, out_channels, stride=1):
- super().__init__()
- self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
- self.bn1 = nn.BatchNorm2d(out_channels)
- self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
- self.bn2 = nn.BatchNorm2d(out_channels)
- self.shortcut = nn.Sequential()
- if stride != 1 or in_channels != out_channels:
- self.shortcut = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
- nn.BatchNorm2d(out_channels)
- )
-
- def forward(self, x):
- out = F.relu(self.bn1(self.conv1(x)))
- out = self.bn2(self.conv2(out))
- out += self.shortcut(x)
- return F.relu(out)
复制代码 4.2 动态外形处理:制止硬编码尺寸
在forward中通过x.shape动态获取维度,提拔模型通用性(如支持不同输入尺寸的图像)。
4.3 混合使用nn.Module与nn.functional
- 保举场景:
- 具有可学习参数的层(如 Conv、Linear、BN)使用nn.Module子类
- 无参数的运算(如激活函数、池化、展平)使用nn.functional函数(减少内存占用,更灵活)
五、注意事项
- 制止在forward中使用Python控制流:如需条件判断或循环,只管使用PyTorch内置函数(如torch.where),以保证模型可序列化和JIT编译
- 参数初始化的显式调用:在__init__中对自界说层举行初始化,制止使用默认初始化(如全零初始化可能导致对称性破缺)
- 模型生存与加载:使用torch.save(model.state_dict(), 'model.pth')生存参数,加载时通过model.load_state_dict(torch.load('model.pth'))恢复,保持模块命名划一性
微语录:不要由于走得太远,而忘记为什么出发。— — 卡里·纪伯伦
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |