Transformer 代码剖析2 - 模型训练 (pytorch实现)

打印 上一主题 下一主题

主题 860|帖子 860|积分 2580

一、模型初始化模块

参考:项目代码
1.1 参数统计函数

  1. def count_parameters(model):
  2.     return sum(p.numel() for p in model.parameters() if p.requires_grad)
复制代码
    技术剖析:


  • numel()方法盘算张量元素总数
  • requires_grad筛选必要梯度更新的参数
  • 统计效果反映模型复杂度,典型Transformer-base约65M参数
1.2 权重初始化

  1. def initialize_weights(m):
  2.     if hasattr(m, 'weight') and m.weight.dim() > 1:
  3.         nn.init.kaiming_uniform_(m.weight.data)
复制代码
    初始化原理:


  • Kaiming初始化针对ReLU族激活函数设计
  • 保持前向传播时方差划一性
  • 公式:                                        W                            ∼                            U                            (                            −                                                   6                                  /                                               n                                                   i                                        n                                                                          ,                                                   6                                  /                                               n                                                   i                                        n                                                                          )                                  W \sim U(-\sqrt{6/n_{in}}, \sqrt{6/n_{in}})                     W∼U(−6/nin​             ​,6/nin​             ​)
1.3 模型实例化

  1. model = Transformer(
  2.     src_pad_idx=src_pad_idx,
  3.     trg_pad_idx=trg_pad_idx,
  4.     trg_sos_idx=trg_sos_idx,
  5.     d_model=d_model,
  6.     enc_voc_size=enc_voc_size,
  7.     dec_voc_size=dec_voc_size,
  8.     max_len=max_len,
  9.     ffn_hidden=ffn_hidden,
  10.     n_head=n_heads,
  11.     n_layers=n_layers,
  12.     drop_prob=drop_prob,
  13.     device=device).to(device)
复制代码
关键参数剖析:
参数典型值作用d_model512向量表征维度n_head8留意力头数量ffn_hidden2048前馈网络隐层维度n_layers6编码器/解码器堆叠层数drop_prob0.1Dropout概率
二、训练准备模块

2.1 优化器设置

  1. optimizer = Adam(
  2.     params=model.parameters(),
  3.     lr=init_lr,
  4.     weight_decay=weight_decay,
  5.     eps=adam_eps)
复制代码
Adam优化器数学原理:
                                                    θ                                           t                                  +                                  1                                                 =                                       θ                               t                                      −                                       η                                                                                     v                                           ^                                                      t                                                           +                                  ϵ                                                                        m                                  ^                                          t                                            \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\hat{m}_t                     θt+1​=θt​−v^t​                    ​+ϵη​m^t​
其中                                                        m                               ^                                      t                                       \hat{m}_t                  m^t​和                                                        v                               ^                                      t                                       \hat{v}_t                  v^t​为一阶、二阶矩估计的偏差修正项
2.2 学习率调理器

  1. scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  2.     optimizer=optimizer,
  3.     verbose=True,
  4.     factor=factor,
  5.     patience=patience)
复制代码
调理策略:


  • 监控验证集损失变化
  • 当损失停滞时按factor比例(典型0.5)衰减学习率
  • patience=5表示一连5次无改善触发衰减
2.3 损失函数

  1. criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
复制代码
Padding处理机制:


  • 通过ignore_index屏蔽添补符的梯度盘算
  • 数学表达式修正为:
                                                  L                               =                               −                                           ∑                                               i                                     =                                     1                                              n                                                      y                                  i                                          log                               ⁡                                           p                                  i                                          ⋅                               I                               (                                           y                                  i                                          ≠                               pad                               )                                      \mathcal{L} = -\sum_{i=1}^{n} y_i \log p_i \cdot \mathbb{I}(y_i \neq \text{pad})                        L=−i=1∑n​yi​logpi​⋅I(yi​=pad)

三、训练与评估模块

3.1 训练函数

  1. def train(model, iterator, optimizer, criterion, clip):
  2.     model.train()
  3.     epoch_loss = 0
  4.     for i, batch in enumerate(iterator):
  5.         src = batch.src
  6.         trg = batch.trg
  7.         optimizer.zero_grad()
  8.         output = model(src, trg[:, :-1])
  9.         output_reshape = output.contiguous().view(-1, output.shape[-1])
  10.         trg = trg[:, 1:].contiguous().view(-1)
  11.         loss = criterion(output_reshape, trg)
  12.         loss.backward()
  13.         torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
  14.         optimizer.step()
  15.         epoch_loss += loss.item()
  16.         print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())
  17.     return epoch_loss / len(iterator)
复制代码
    关键技术点:

  • 教师逼迫(Teacher Forcing):使用真实目标序列作为解码器输入
  • 序列切片trg[:, :-1]去除停止符
  • 梯度裁剪防止梯度爆炸
3.2 评估函数

  1. def evaluate(model, iterator, criterion):
  2.     model.eval()
  3.     epoch_loss = 0
  4.     batch_bleu = []
  5.     with torch.no_grad():
  6.         for i, batch in enumerate(iterator):
  7.             src = batch.src
  8.             trg = batch.trg
  9.             output = model(src, trg[:, :-1])
  10.             output_reshape = output.contiguous().view(-1, output.shape[-1])
  11.             trg = trg[:, 1:].contiguous().view(-1)
  12.             loss = criterion(output_reshape, trg)
  13.             epoch_loss += loss.item()
  14.             total_bleu = []
  15.             for j in range(batch_size):
  16.                 try:
  17.                     trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
  18.                     output_words = output[j].max(dim=1)[1]
  19.                     output_words = idx_to_word(output_words, loader.target.vocab)
  20.                     bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
  21.                     total_bleu.append(bleu)
  22.                 except:
  23.                     pass
  24.             total_bleu = sum(total_bleu) / len(total_bleu)
  25.             batch_bleu.append(total_bleu)
  26.     batch_bleu = sum(batch_bleu) / len(batch_bleu)
  27.     return epoch_loss / len(iterator), batch_bleu
复制代码
    BLEU盘算原理:
                                         B                            L                            E                            U                            =                            B                            P                            ⋅                            exp                            ⁡                                       (                                           ∑                                               n                                     =                                     1                                              N                                                      w                                  n                                          log                               ⁡                                           p                                  n                                          )                                            BLEU = BP \cdot \exp\left(\sum_{n=1}^N w_n \log p_n\right)                     BLEU=BP⋅exp(n=1∑N​wn​logpn​)
其中:


  • BP为简洁惩罚因子
  •                                                    p                               n                                            p_n                     pn​为n-gram精度
  •                                                    w                               n                                            w_n                     wn​为各阶权重(通常平均加权)

四、运行控制模块

4.1 训练循环

  1. def run(total_epoch, best_loss):
  2.     train_losses, test_losses, bleus = [], [], []
  3.     for step in range(total_epoch):
  4.         start_time = time.time()
  5.         train_loss = train(model, train_iter, optimizer, criterion, clip)
  6.         valid_loss, bleu = evaluate(model, valid_iter, criterion)
  7.         end_time = time.time()
  8.         if step > warmup:
  9.             scheduler.step(valid_loss)
  10.         train_losses.append(train_loss)
  11.         test_losses.append(valid_loss)
  12.         bleus.append(bleu)
  13.         epoch_mins, epoch_secs = epoch_time(start_time, end_time)
  14.         if valid_loss < best_loss:
  15.             best_loss = valid_loss
  16.             torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))
  17.         f = open('result/train_loss.txt', 'w')
  18.         f.write(str(train_losses))
  19.         f.close()
  20.         f = open('result/bleu.txt', 'w')
  21.         f.write(str(bleus))
  22.         f.close()
  23.         f = open('result/test_loss.txt', 'w')
  24.         f.write(str(test_losses))
  25.         f.close()
  26.         print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
  27.         print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
  28.         print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
  29.         print(f'\tBLEU Score: {bleu:.3f}')
复制代码
    模型保存策略:


  • 接纳验证损失作为保存标准
  • 使用model.state_dict()保存参数快照
  • 文件命名包罗验证损失便于版本管理

五、工程实践要点

5.1 训练技巧


  • Warm-up策略:前warmup个epoch不启动学习率衰减
  • 混淆精度训练:可结合torch.cuda.amp加快训练
  • 梯度累积:小批量数据累积梯度模仿大批量效果
5.2 性能优化

  1. torch.backends.cudnn.benchmark = True  # 启用cuDNN自动优化器
  2. torch.autograd.set_detect_anomaly(False)  # 禁用异常检测提升速度
复制代码
5.3 扩展实现

  1. 模型并行改造示例
  2. class ParallelTransformer(Transformer):
  3.     def __init__(self, ...):
  4.         super().__init__(...)
  5.         self.encoder = nn.DataParallel(self.encoder)
  6.         self.decoder = nn.DataParallel(self.decoder)
复制代码

   本节从代码实现到理论机制进行了多角度剖析,完整保留原始代码结构的同时通过流程图解耦了各模块的运作机制。现实应用中可根据任务规模调解超参数,建议在8*V100 GPU环境下进行大规模预训练,结合混淆精度训练提升训练效率。
  
源码(附):

  1. """@author : Hyunwoong@when : 2019-10-22@homepage : https://github.com/gusdnd852"""import mathimport timefrom torch import nn, optimfrom torch.optim import Adamfrom data import *from models.model.transformer import Transformerfrom util.bleu import idx_to_word, get_bleufrom util.epoch_timer import epoch_timedef count_parameters(model):
  2.     return sum(p.numel() for p in model.parameters() if p.requires_grad)
  3. def initialize_weights(m):    if hasattr(m, 'weight') and m.weight.dim() > 1:        nn.init.kaiming_uniform(m.weight.data)model = Transformer(src_pad_idx=src_pad_idx,                    trg_pad_idx=trg_pad_idx,                    trg_sos_idx=trg_sos_idx,                    d_model=d_model,                    enc_voc_size=enc_voc_size,                    dec_voc_size=dec_voc_size,                    max_len=max_len,                    ffn_hidden=ffn_hidden,                    n_head=n_heads,                    n_layers=n_layers,                    drop_prob=drop_prob,                    device=device).to(device)print(f'The model has {count_parameters(model):,} trainable parameters')model.apply(initialize_weights)optimizer = Adam(params=model.parameters(),                 lr=init_lr,                 weight_decay=weight_decay,                 eps=adam_eps)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,                                                 verbose=True,                                                 factor=factor,                                                 patience=patience)criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
  4. def train(model, iterator, optimizer, criterion, clip):    model.train()    epoch_loss = 0    for i, batch in enumerate(iterator):        src = batch.src        trg = batch.trg        optimizer.zero_grad()        output = model(src, trg[:, :-1])        output_reshape = output.contiguous().view(-1, output.shape[-1])        trg = trg[:, 1:].contiguous().view(-1)        loss = criterion(output_reshape, trg)        loss.backward()        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)        optimizer.step()        epoch_loss += loss.item()        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())    return epoch_loss / len(iterator)def evaluate(model, iterator, criterion):
  5.     model.eval()
  6.     epoch_loss = 0
  7.     batch_bleu = []
  8.     with torch.no_grad():
  9.         for i, batch in enumerate(iterator):
  10.             src = batch.src
  11.             trg = batch.trg
  12.             output = model(src, trg[:, :-1])
  13.             output_reshape = output.contiguous().view(-1, output.shape[-1])
  14.             trg = trg[:, 1:].contiguous().view(-1)
  15.             loss = criterion(output_reshape, trg)
  16.             epoch_loss += loss.item()
  17.             total_bleu = []
  18.             for j in range(batch_size):
  19.                 try:
  20.                     trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
  21.                     output_words = output[j].max(dim=1)[1]
  22.                     output_words = idx_to_word(output_words, loader.target.vocab)
  23.                     bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
  24.                     total_bleu.append(bleu)
  25.                 except:
  26.                     pass
  27.             total_bleu = sum(total_bleu) / len(total_bleu)
  28.             batch_bleu.append(total_bleu)
  29.     batch_bleu = sum(batch_bleu) / len(batch_bleu)
  30.     return epoch_loss / len(iterator), batch_bleu
  31. def run(total_epoch, best_loss):
  32.     train_losses, test_losses, bleus = [], [], []
  33.     for step in range(total_epoch):
  34.         start_time = time.time()
  35.         train_loss = train(model, train_iter, optimizer, criterion, clip)
  36.         valid_loss, bleu = evaluate(model, valid_iter, criterion)
  37.         end_time = time.time()
  38.         if step > warmup:
  39.             scheduler.step(valid_loss)
  40.         train_losses.append(train_loss)
  41.         test_losses.append(valid_loss)
  42.         bleus.append(bleu)
  43.         epoch_mins, epoch_secs = epoch_time(start_time, end_time)
  44.         if valid_loss < best_loss:
  45.             best_loss = valid_loss
  46.             torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))
  47.         f = open('result/train_loss.txt', 'w')
  48.         f.write(str(train_losses))
  49.         f.close()
  50.         f = open('result/bleu.txt', 'w')
  51.         f.write(str(bleus))
  52.         f.close()
  53.         f = open('result/test_loss.txt', 'w')
  54.         f.write(str(test_losses))
  55.         f.close()
  56.         print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
  57.         print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
  58.         print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
  59.         print(f'\tBLEU Score: {bleu:.3f}')
  60. if __name__ == '__main__':    run(total_epoch=epoch, best_loss=inf)
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

没腿的鸟

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表