Pytorch参数初始化设置

打印 上一主题 下一主题

主题 1044|帖子 1044|积分 3132

在PyTorch中,如果不对网络参数举行显式初始化,各层会使用其默认的初始化方法。不同层范例的初始化策略有所不同,以下是常见层的默认初始化方式:
1. 全连接层 (nn.Linear)

权重初始化:使用Kaiming匀称分布(He初始化),假设激活函数为Leaky ReLU(负斜率a=sqrt(5))。初始化范围根据输入维度(fan_in)计算,公式为:
                                         bound                            =                                       1                                           fan_in                                                       \text{bound} = \frac{1}{\sqrt{\text{fan\_in}}}                     bound=fan_in                    ​1​
权重从匀称分布 ( U(-\text{bound}, \text{bound}) ) 中采样。
偏置初始化:匀称分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。
2. 卷积层 (nn.Conv2d, nn.Conv1d, nn.Conv3d)

权重初始化:与全连接层类似,使用Kaiming匀称分布,但fan_in计算为输入通道数乘以卷积核面积(比方,对于Conv2d,fan_in = in_channels * kernel_height * kernel_width)。
偏置初始化:与全连接层雷同,匀称分布在 ([-1/\sqrt{\text{fan_in}}, 1/\sqrt{\text{fan_in}}]) 之间。
3. LSTM/GRU层 (nn.LSTM, nn.GRU)

权重初始化:权重从匀称分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
偏置初始化:偏置分为两部门,一部门初始化为零,另一部门从匀称分布 ( U(-1/\sqrt{\text{hidden_size}}, 1/\sqrt{\text{hidden_size}}) ) 中采样。
4. 批归一化层 (nn.BatchNorm1d, nn.BatchNorm2d)

缩放参数(weight):初始化为1。
偏移参数(bias):初始化为0。
5. 嵌入层 (nn.Embedding)

权重初始化:从正态分布 ( N(0, 1) ) 中采样。
默认初始化的潜在问题

激活函数不匹配:Kaiming初始化默认假设使用Leaky ReLU(a=sqrt(5)),若使用ReLU或其他激活函数,大概需要手动调整初始化方式以制止梯度不稳固。
深层网络训练:默认初始化在较浅网络中表现精良,但在深层网络中大概需要更精细的初始化(如Xavier或正交初始化)。
代码示例:查看默认初始化

  1. import torch.nn as nn
  2. # 定义层
  3. linear = nn.Linear(100, 50)
  4. conv = nn.Conv2d(3, 16, kernel_size=3)
  5. lstm = nn.LSTM(input_size=10, hidden_size=20)
  6. # 打印权重范围和标准差
  7. def print_init_info(module):
  8.     for name, param in module.named_parameters():
  9.         if 'weight' in name:
  10.             print(f"{name} mean: {param.data.mean():.4f}, std: {param.data.std():.4f}, range: [{param.data.min():.4f}, {param.data.max():.4f}]")
  11.         elif 'bias' in name:
  12.             print(f"{name} mean: {param.data.mean():.4f}")
  13. print("Linear层初始化信息:")
  14. print_init_info(linear)
  15. print("\nConv2d层初始化信息:")
  16. print_init_info(conv)
  17. print("\nLSTM层初始化信息:")
  18. print_init_info(lstm)
复制代码
手动初始化推荐

若默认初始化不实用,可手动初始化以适配激活函数:
  1. # 针对ReLU的Kaiming初始化
  2. for module in model.modules():
  3.     if isinstance(module, (nn.Linear, nn.Conv2d)):
  4.         nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
  5.         if module.bias is not None:
  6.             nn.init.zeros_(module.bias)
  7.     elif isinstance(module, nn.LSTM):
  8.         for name, param in module.named_parameters():
  9.             if 'weight' in name:
  10.                 nn.init.xavier_uniform_(param)
  11.             elif 'bias' in name:
  12.                 nn.init.zeros_(param)
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

北冰洋以北

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表