在PyTorch中,如果不对网络参数举行显式初始化,各层会使用其默认的初始化方法。不同层范例的初始化策略有所不同,以下是常见层的默认初始化方式:
1. 全连接层 (nn.Linear)
• 权重初始化:使用Kaiming匀称分布(He初始化),假设激活函数为Leaky ReLU(负斜率a=sqrt(5))。初始化范围根据输入维度(fan_in)计算,公式为:
bound = 1 fan_in \text{bound} = \frac{1}{\sqrt{\text{fan\_in}}} bound=fan_in 1
权重从匀称分布 ( U(-\text{bound}, \text{bound}) ) 中采样。
• 偏置初始化:匀称分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。
2. 卷积层 (nn.Conv2d, nn.Conv1d, nn.Conv3d)
• 权重初始化:与全连接层类似,使用Kaiming匀称分布,但fan_in计算为输入通道数乘以卷积核面积(比方,对于Conv2d,fan_in = in_channels * kernel_height * kernel_width)。
• 偏置初始化:与全连接层雷同,匀称分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。
3. LSTM/GRU层 (nn.LSTM, nn.GRU)
• 权重初始化:权重从匀称分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
• 偏置初始化:偏置分为两部门,一部门初始化为零,另一部门从匀称分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
4. 批归一化层 (nn.BatchNorm1d, nn.BatchNorm2d)
• 缩放参数(weight):初始化为1。
• 偏移参数(bias):初始化为0。
5. 嵌入层 (nn.Embedding)
• 权重初始化:从正态分布 ( N(0, 1) ) 中采样。
默认初始化的潜在问题
• 激活函数不匹配:Kaiming初始化默认假设使用Leaky ReLU(a=sqrt(5)),若使用ReLU或其他激活函数,大概需要手动调整初始化方式以制止梯度不稳固。
• 深层网络训练:默认初始化在较浅网络中表现精良,但在深层网络中大概需要更精细的初始化(如Xavier或正交初始化)。
代码示例:查看默认初始化
- import torch.nn as nn
- # 定义层
- linear = nn.Linear(100, 50)
- conv = nn.Conv2d(3, 16, kernel_size=3)
- lstm = nn.LSTM(input_size=10, hidden_size=20)
- # 打印权重范围和标准差
- def print_init_info(module):
- for name, param in module.named_parameters():
- if 'weight' in name:
- print(f"{name} mean: {param.data.mean():.4f}, std: {param.data.std():.4f}, range: [{param.data.min():.4f}, {param.data.max():.4f}]")
- elif 'bias' in name:
- print(f"{name} mean: {param.data.mean():.4f}")
- print("Linear层初始化信息:")
- print_init_info(linear)
- print("\nConv2d层初始化信息:")
- print_init_info(conv)
- print("\nLSTM层初始化信息:")
- print_init_info(lstm)
复制代码 手动初始化推荐
若默认初始化不实用,可手动初始化以适配激活函数:
- # 针对ReLU的Kaiming初始化
- for module in model.modules():
- if isinstance(module, (nn.Linear, nn.Conv2d)):
- nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
- if module.bias is not None:
- nn.init.zeros_(module.bias)
- elif isinstance(module, nn.LSTM):
- for name, param in module.named_parameters():
- if 'weight' in name:
- nn.init.xavier_uniform_(param)
- elif 'bias' in name:
- nn.init.zeros_(param)
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |