神经网络初始化 (init) 介绍

打印 上一主题 下一主题

主题 984|帖子 984|积分 2954

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
弁言

在深度学习的世界中,构建一个高效且性能优异的神经网络模子需要综合考虑多个因素。尽管选择合适的架构和优化算法至关紧张,但权重初始化这一环节同样不容忽视。合适的初始化策略不仅能加速模子的收敛速率,提升训练稳定性,还能明显影响终极的模子性能。
1. 初始化的紧张性

在深入探讨各种初始化方法之前,首先需要明确权重初始化在神经网络训练中的关键作用。
1.1 打破对称性

如果全部神经元的权重初始化为相同的值,网络在训练初期将无法学习到多样化的特性。这是由于在前向传播和反向传播过程中,每个神经元都会计算出相同的输出和梯度,导致它们在训练过程中同步更新,学习到相同的内容。打破对称性通过确保每个神经元的初始权重具有一定的随机性,使得每个神经元能够独立地探索差别的特性空间,从而进步模子的表达能力。
1.2 控制方差

在深层网络中,信号通过多层传播可能会渐渐放大或缩小,导致梯度消失或爆炸。这些问题会严重影响模子的训练结果,尤其是在反向传播阶段。合理的权重初始化能够保持每一层输出的方差稳定,确保信号在整个网络中均匀传播,避免梯度消失或爆炸,从而促进稳定的训练过程。
1.3 加速收敛与进步泛化能力

正确的初始化策略能够引导丧失函数的优化过程,使其更轻易找到好的局部最小值或全局最小值。这不仅能加快训练速率,还能提升模子的终极性能。别的,合理的初始化还能够资助模子更快地进入一个具备精良泛化能力的参数地区,提升其在未见数据上的表现。
2. 常见的初始化方法及其应用场景

根据差别的激活函数和网络架构,存在多种权重初始化方法。以下是几种常见且有效的初始化策略及其适用场景。
2.1 Xavier/Glorot 初始化

适用场景:适用于激活函数为 Sigmoid 或 Tanh 的神经网络。
原理:Xavier 初始化通过维持每一层输入和输出信号的方差一致,防止梯度在传播过程中渐渐消失或爆炸。详细来说,它根据输入和输出神经元的数量来设定权重的初始化范围,通常接纳均匀分布或正态分布。
示例
  1. import torch.nn as nn
  2. # Xavier 初始化示例
  3. linear = nn.Linear(in_features=256, out_features=128)
  4. nn.init.xavier_uniform_(linear.weight)
复制代码
2.2 He 初始化

适用场景:专为 ReLU 及其变体设计的神经网络。
原理:He 初始化考虑到 ReLU 激活函数的非负输出特性,调整了初始化时权重的方差,使其更恰当 ReLU 的特性。这样可以有效地保持信号在前向传播过程中的尺度差不变,避免梯度消失。
示例
  1. import torch.nn as nn
  2. # He 初始化示例
  3. conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
  4. nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
复制代码
2.3 正交初始化

适用场景:特别恰当循环神经网络(RNNs)和非常深的前馈网络。
原理:正交初始化通过确保权重矩阵的列相互正交且具有单位长度,有效地防止了梯度消失或爆炸的问题。对于 RNN 来说,正交矩阵能够保持序列数据的长时间依靠关系,提升模子的表现。
示例
  1. import torch.nn as nn
  2. # 正交初始化示例
  3. linear = nn.Linear(in_features=256, out_features=256)
  4. nn.init.orthogonal_(linear.weight)
复制代码
2.4 其他初始化方法

除了上述重要方法外,还有一些其他初始化策略,尽管在现代实践中使用较少,但在特定场景下也有其应用代价。


  • 零初始化:将权重初始化为零。这种方法通常不推荐用于隐蔽层,由于会导致对称性问题,但可以用于初始化偏置项。
    1. import torch.nn as nn
    2. # 零初始化示例
    3. linear = nn.Linear(in_features=256, out_features=128)
    4. nn.init.constant_(linear.weight, 0)
    复制代码
  • 随机初始化:使用尺度正态分布或均匀分布随机初始化权重。虽然简单,但在深层网络中可能导致梯度问题。
    1. import torch.nn as nn
    2. # 随机初始化示例
    3. linear = nn.Linear(in_features=256, out_features=128)
    4. nn.init.normal_(linear.weight, mean=0.0, std=0.02)
    复制代码
  • 希罕初始化:初始化部分权重为非零,其他为零。适用于希望网络具有希罕连接的场景。
    1. import torch.nn as nn
    2. # 稀疏初始化示例
    3. linear = nn.Linear(in_features=256, out_features=128)
    4. nn.init.sparse_(linear.weight, sparsity=0.1)
    复制代码
3. 如何设置初始化

在深度学习框架如 PyTorch 中,权重初始化可以通过自界说初始化方法或直接利用内置函数来实现。以下以一个简单的卷积神经网络(CNN)为例,展示如安在构造函数中应用差别的初始化策略。
  1. import torch.nn as nn
  2. class SimpleCNN(nn.Module):
  3.     def __init__(self, num_classes=10):
  4.         super(SimpleCNN, self).__init__()
  5.         # 定义网络层
  6.         self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
  7.         self.bn1 = nn.BatchNorm2d(32)
  8.         self.relu = nn.ReLU(inplace=True)
  9.         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  10.         self.fc = nn.Linear(32 * 16 * 16, num_classes)  # 假设输入图像大小为32x32
  11.         # 初始化权重
  12.         self._initialize_weights()
  13.     def _initialize_weights(self):
  14.         for m in self.modules():
  15.             if isinstance(m, nn.Conv2d):
  16.                 # 使用 He 初始化
  17.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  18.                 if m.bias is not None:
  19.                     nn.init.constant_(m.bias, 0)
  20.             elif isinstance(m, nn.BatchNorm2d):
  21.                 # 批归一化层权重初始化为1,偏置初始化为0
  22.                 nn.init.constant_(m.weight, 1)
  23.                 nn.init.constant_(m.bias, 0)
  24.             elif isinstance(m, nn.Linear):
  25.                 # 使用 Xavier 初始化
  26.                 nn.init.xavier_uniform_(m.weight)
  27.                 if m.bias is not None:
  28.                     nn.init.constant_(m.bias, 0)
  29.     def forward(self, x):
  30.         x = self.relu(self.bn1(self.conv1(x)))
  31.         x = self.pool(x)
  32.         x = x.view(x.size(0), -1)
  33.         x = self.fc(x)
  34.         return x
复制代码
分析


  • 卷积层(Conv2d):接纳 He 初始化,恰当 ReLU 激活函数,确保信号在前向传播中的稳定。
  • 批归一化层(BatchNorm2d):权重初始化为1,偏置初始化为0,保证初始状态下批归一化层的尺度化结果。
  • 全连接层(Linear):接纳 Xavier 初始化,恰当 Sigmoid 或 Tanh 激活函数,保持输入和输出信号的方差一致。
4. 基于 BERT 的文天职类如何进行初始化

为了更深入地明确权重初始化的实际应用,本文将通过一个详细的文天职类任务,展示如安在预训练模子 BERT 的基础上进行初始化和微调。
4.1 项目配景

文天职类是天然语言处理中的基础任务之一,广泛应用于情绪分析、垃圾邮件检测、话题分类等场景。比年来,预训练语言模子如 BERT(Bidirectional Encoder Representations from Transformers)因其强盛的语言明确能力,成为文天职类任务的首选基础模子。
4.2 模子构建

以下代码展示了如何构建一个基于 BERT 的文天职类器,并在新增长的分类层上应用权重初始化。
  1. from transformers import BertModel, BertTokenizer
  2. import torch.nn as nn
  3. class BertForTextClassification(nn.Module):
  4.     def __init__(self, num_labels=2, dropout_rate=0.3):
  5.         super(BertForTextClassification, self).__init__()
  6.         # 加载预训练的 BERT 模型
  7.         self.bert = BertModel.from_pretrained('bert-base-uncased')
  8.         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')        
  9.         # 冻结 BERT 模型的参数以加快训练速度(可选)
  10.         for param in self.bert.parameters():
  11.             param.requires_grad = False      
  12.         # 定义分类头
  13.         self.dropout = nn.Dropout(dropout_rate)
  14.         self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)        
  15.         # 初始化分类层的权重
  16.         self._initialize_weights()
  17.     def _initialize_weights(self):
  18.         # 使用 Xavier 初始化分类层权重
  19.         nn.init.xavier_uniform_(self.classifier.weight)
  20.         if self.classifier.bias is not None:
  21.             nn.init.zeros_(self.classifier.bias)
  22.     def forward(self, input_ids, attention_mask=None, token_type_ids=None):
  23.         # 获取 BERT 的输出
  24.         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
  25.         # 取 [CLS] 标记的输出作为分类依据
  26.         cls_output = outputs.last_hidden_state[:, 0, :]
  27.         cls_output = self.dropout(cls_output)
  28.         logits = self.classifier(cls_output)
  29.         return logits
复制代码
关键步骤分析

  • 加载预训练模子
    1. self.bert = BertModel.from_pretrained('bert-base-uncased')
    2. self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    复制代码
    通过 transformers 库加载预训练的 BERT 模子及其对应的分词器。
  • 冻结预训练模子参数(可选)
    1. for param in self.bert.parameters():
    2.     param.requires_grad = False
    复制代码
    冻结 BERT 模子的参数,仅训练新增长的分类层,能够明显淘汰训练时间和计算资源消耗,适用于数据量较小的场景。
  • 界说分类头
    1. self.dropout = nn.Dropout(dropout_rate)
    2. self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
    复制代码
    使用 dropout 层防止过拟合,并通过全连接层将 BERT 的输出映射到分类标签空间。
  • 初始化分类层权重
    1. self._initialize_weights()
    复制代码
    为新增长的分类层应用 Xavier 初始化,确保其在训练开始时具有精良的表现。
4.3 模子训练与评估

在训练过程中,合理的初始化策略能够资助模子更快地收敛,并在有限的训练迭代中到达较好的性能。以下是一个简单的训练和评估流程示例:
  1. import torch
  2. from torch.utils.data import DataLoader, Dataset
  3. from transformers import AdamW
  4. from sklearn.metrics import accuracy_score
  5. # 假设已定义好文本数据集和数据加载器
  6. class TextDataset(Dataset):
  7.     def __init__(self, texts, labels, tokenizer, max_length=128):
  8.         self.texts = texts
  9.         self.labels = labels
  10.         self.tokenizer = tokenizer
  11.         self.max_length = max_length   
  12.     def __len__(self):
  13.         return len(self.texts)   
  14.     def __getitem__(self, idx):
  15.         encoding = self.tokenizer.encode_plus(
  16.             self.texts[idx],
  17.             add_special_tokens=True,
  18.             max_length=self.max_length,
  19.             padding='max_length',
  20.             truncation=True,
  21.             return_attention_mask=True,
  22.             return_tensors='pt'
  23.         )
  24.         return {
  25.             'input_ids': encoding['input_ids'].flatten(),
  26.             'attention_mask': encoding['attention_mask'].flatten(),
  27.             'labels': torch.tensor(self.labels[idx], dtype=torch.long)
  28.         }
  29. # 初始化数据集和数据加载器
  30. train_dataset = TextDataset(train_texts, train_labels, model.tokenizer)
  31. train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
  32. val_dataset = TextDataset(val_texts, val_labels, model.tokenizer)
  33. val_loader = DataLoader(val_dataset, batch_size=32)
  34. # 初始化模型、优化器和损失函数
  35. model = BertForTextClassification(num_labels=2)
  36. optimizer = AdamW(model.parameters(), lr=2e-5)
  37. criterion = nn.CrossEntropyLoss()
  38. # 训练循环
  39. for epoch in range(3):  # 假设训练3个epoch
  40.     model.train()
  41.     for batch in train_loader:
  42.         optimizer.zero_grad()
  43.         input_ids = batch['input_ids']
  44.         attention_mask = batch['attention_mask']
  45.         labels = batch['labels']
  46.         
  47.         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
  48.         loss = criterion(outputs, labels)
  49.         loss.backward()
  50.         optimizer.step()   
  51.     # 评估
  52.     model.eval()
  53.     all_preds, all_labels = [], []
  54.    
  55.     with torch.no_grad():
  56.         for batch in val_loader:
  57.             input_ids = batch['input_ids']
  58.             attention_mask = batch['attention_mask']
  59.             labels = batch['labels']
  60.             
  61.             outputs = model(input_ids=input_ids, attention_mask=attention_mask)
  62.             preds = torch.argmax(outputs, dim=1)
  63.             all_preds.extend(preds.cpu().numpy())
  64.             all_labels.extend(labels.cpu().numpy())
  65.             
  66.     acc = accuracy_score(all_labels, all_preds)
  67.     print(f"Epoch {epoch + 1} - Validation Accuracy: {acc:.4f}")
复制代码
关键点


  • 数据预处理:使用 BERT 的分词器将文本转换为模子可接受的输入格式,包罗 input_ids 和 attention_mask。
  • 冻结与微调:根据详细需求,可以选择冻结 BERT 的部分或全部参数,仅训练新增长的层,或进行全模子微调。
  • 优化器与丧失函数:使用 AdamW 优化器和交叉熵丧失函数,适用于分类任务。
4.4 结果分析

通过合理的权重初始化和预训练模子的优势,基于 BERT 的文天职类器在多个尺度数据集上表现精彩。例如,在情绪分析任务中,冻结 BERT 参数并仅训练分类层的模子,能够在较短的训练时间内到达接近全模子微调的性能,同时明显淘汰计算资源的消耗。
结论

权重初始化在神经网络训练中饰演着至关紧张的角色。合适的初始化策略不仅能够打破对称性,控制信号的方差,还能加速模子的收敛,进步泛化能力。本文系统地介绍了几种常见的初始化方法及其适用场景,并通过基于 BERT 的文天职类示例,展示了如安在实际项目中应用这些初始化策略。
参考资料



  • Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks.
  • He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification.
  • BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
  • PyTorch 官方文档 - 权重初始化

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

祗疼妳一个

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表