第N11周:seq2seq翻译实战-Pytorch复现

打印 上一主题 下一主题

主题 560|帖子 560|积分 1680

使命:
●为解码器添加上注意力机制
一、前期预备工作
  1. from __future__ import unicode_literals, print_function, division
  2. from io import open
  3. import unicodedata
  4. import string
  5. import re
  6. import random
  7. import torch
  8. import torch.nn as nn
  9. from torch import optim
  10. import torch.nn.functional as F
  11. device = torch.device("cuda" if torch.cuda.is_available() else "cpu
  12. ")
  13. print(device)
复制代码
代码输出
  1. cpu
复制代码

  • 搭建语言类
  1. SOS_token = 0
  2. EOS_token = 1
  3. # 语言类,方便对语料库进行操作
  4. class Lang:
  5.     def __init__(self, name):
  6.         self.name = name
  7.         self.word2index = {}
  8.         self.word2count = {}
  9.         self.index2word = {0: "SOS", 1: "EOS"}
  10.         self.n_words    = 2  # Count SOS and EOS
  11.     def addSentence(self, sentence):
  12.         for word in sentence.split(' '):
  13.             self.addWord(word)
  14.     def addWord(self, word):
  15.         if word not in self.word2index:
  16.             self.word2index[word] = self.n_words
  17.             self.word2count[word] = 1
  18.             self.index2word[self.n_words] = word
  19.             self.n_words += 1
  20.         else:
  21.             self.word2count[word] += 1
复制代码

  • 文本处理处罚函数
  1. def unicodeToAscii(s):
  2.     return ''.join(
  3.         c for c in unicodedata.normalize('NFD', s)
  4.         if unicodedata.category(c) != 'Mn'
  5.     )
  6. # 小写化,剔除标点与非字母符号
  7. def normalizeString(s):
  8.     s = unicodeToAscii(s.lower().strip())
  9.     s = re.sub(r"([.!?])", r" \1", s)
  10.     s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
  11.     return s
复制代码

  • 文件读取函数
  1. def readLangs(lang1, lang2, reverse=False):
  2.     print("Reading lines...")
  3.     # 以行为单位读取文件
  4.     lines = open('N11/%s-%s.txt'%(lang1,lang2), encoding='utf-8').\
  5.             read().strip().split('\n')
  6.     # 将每一行放入一个列表中
  7.     # 一个列表中有两个元素,A语言文本与B语言文本
  8.     pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
  9.     # 创建Lang实例,并确认是否反转语言顺序
  10.     if reverse:
  11.         pairs       = [list(reversed(p)) for p in pairs]
  12.         input_lang  = Lang(lang2)
  13.         output_lang = Lang(lang1)
  14.     else:
  15.         input_lang  = Lang(lang1)
  16.         output_lang = Lang(lang2)
  17.     return input_lang, output_lang, pairs
复制代码
.startswith(eng_prefixes) 是字符串方法 startswith() 的调用。它用于查抄一个字符串是否以指定的前缀开始。
  1. MAX_LENGTH = 10      # 定义语料最长长度
  2. eng_prefixes = (
  3.     "i am ", "i m ",
  4.     "he is", "he s ",
  5.     "she is", "she s ",
  6.     "you are", "you re ",
  7.     "we are", "we re ",
  8.     "they are", "they re "
  9. )
  10. def filterPair(p):
  11.     return len(p[0].split(' ')) < MAX_LENGTH and \
  12.            len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)
  13. def filterPairs(pairs):
  14.     # 选取仅仅包含 eng_prefixes 开头的语料
  15.     return [pair for pair in pairs if filterPair(pair)]
复制代码
  1. def prepareData(lang1, lang2, reverse=False):
  2.     # 读取文件中的数据
  3.     input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
  4.     print("Read %s sentence pairs" % len(pairs))
  5.    
  6.     # 按条件选取语料
  7.     pairs = filterPairs(pairs[:])
  8.     print("Trimmed to %s sentence pairs" % len(pairs))
  9.     print("Counting words...")
  10.    
  11.     # 将语料保存至相应的语言类
  12.     for pair in pairs:
  13.         input_lang.addSentence(pair[0])
  14.         output_lang.addSentence(pair[1])
  15.         
  16.     # 打印语言类的信息   
  17.     print("Counted words:")
  18.     print(input_lang.name, input_lang.n_words)
  19.     print(output_lang.name, output_lang.n_words)
  20.     return input_lang, output_lang, pairs
  21. input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
  22. print(random.choice(pairs))
复制代码
代码输出
  1. Reading lines...
  2. Read 135842 sentence pairs
  3. Trimmed to 10599 sentence pairs
  4. Counting words...
  5. Counted words:
  6. fra 4345
  7. eng 2803
  8. ['je volerai vers la lune .', 'i m going to fly to the moon .']
复制代码
二、Seq2Seq 模型

  • 编码器(Encoder)
  1. class EncoderRNN(nn.Module):
  2.     def __init__(self, input_size, hidden_size):
  3.         super(EncoderRNN, self).__init__()
  4.         self.hidden_size = hidden_size
  5.         self.embedding   = nn.Embedding(input_size, hidden_size)
  6.         self.gru         = nn.GRU(hidden_size, hidden_size)
  7.     def forward(self, input, hidden):
  8.         embedded       = self.embedding(input).view(1, 1, -1)
  9.         output         = embedded
  10.         output, hidden = self.gru(output, hidden)
  11.         return output, hidden
  12.     def initHidden(self):
  13.         return torch.zeros(1, 1, self.hidden_size, device=device)
复制代码

  • 解码器(Decoder)
  1. class AttnDecoderRNN(nn.Module):
  2.     def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
  3.         super(AttnDecoderRNN, self).__init__()
  4.         self.hidden_size = hidden_size
  5.         self.output_size = output_size
  6.         self.dropout_p = dropout_p
  7.         self.max_length = max_length
  8.         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
  9.         self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
  10.         self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
  11.         self.dropout = nn.Dropout(self.dropout_p)
  12.         self.gru = nn.GRU(self.hidden_size, self.hidden_size)
  13.         self.out = nn.Linear(self.hidden_size, self.output_size)
  14.     def forward(self, input, hidden, encoder_outputs):
  15.         embedded = self.embedding(input).view(1, 1, -1)
  16.         embedded = self.dropout(embedded)
  17.         attn_weights = F.softmax(
  18.             self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
  19.         attn_applied = torch.bmm(attn_weights.unsqueeze(0),
  20.                                  encoder_outputs.unsqueeze(0))
  21.         output = torch.cat((embedded[0], attn_applied[0]), 1)
  22.         output = self.attn_combine(output).unsqueeze(0)
  23.         output = F.relu(output)
  24.         output, hidden = self.gru(output, hidden)
  25.         output = F.log_softmax(self.out(output[0]), dim=1)
  26.         return output, hidden, attn_weights
  27.     def initHidden(self):
  28.         return torch.zeros(1, 1, self.hidden_size, device=device)
复制代码
三、训练

  • 数据预处理处罚
  1. # 将文本数字化,获取词汇index
  2. def indexesFromSentence(lang, sentence):
  3.     return [lang.word2index[word] for word in sentence.split(' ')]
  4. # 将数字化的文本,转化为tensor数据
  5. def tensorFromSentence(lang, sentence):
  6.     indexes = indexesFromSentence(lang, sentence)
  7.     indexes.append(EOS_token)
  8.     return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
  9. # 输入pair文本,输出预处理好的数据
  10. def tensorsFromPair(pair):
  11.     input_tensor  = tensorFromSentence(input_lang, pair[0])
  12.     target_tensor = tensorFromSentence(output_lang, pair[1])
  13.     return (input_tensor, target_tensor)
复制代码

  • 训练函数
  1. teacher_forcing_ratio = 0.5
  2. def train(input_tensor, target_tensor,
  3.           encoder, decoder,
  4.           encoder_optimizer, decoder_optimizer,
  5.           criterion, max_length=MAX_LENGTH):
  6.    
  7.     # 编码器初始化
  8.     encoder_hidden = encoder.initHidden()
  9.    
  10.     # grad属性归零
  11.     encoder_optimizer.zero_grad()
  12.     decoder_optimizer.zero_grad()
  13.     input_length  = input_tensor.size(0)
  14.     target_length = target_tensor.size(0)
  15.    
  16.     # 用于创建一个指定大小的全零张量(tensor),用作默认编码器输出
  17.     encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
  18.     loss = 0
  19.    
  20.     # 将处理好的语料送入编码器
  21.     for ei in range(input_length):
  22.         encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
  23.         encoder_outputs[ei]            = encoder_output[0, 0]
  24.    
  25.     # 解码器默认输出
  26.     decoder_input  = torch.tensor([[SOS_token]], device=device)
  27.     decoder_hidden = encoder_hidden
  28.     use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
  29.    
  30.     # 将编码器处理好的输出送入解码器
  31.     if use_teacher_forcing:
  32.         # Teacher forcing: Feed the target as the next input
  33.         for di in range(target_length):
  34.             decoder_output, decoder_hidden, decoder_attention = decoder(
  35.                 decoder_input, decoder_hidden, encoder_outputs)
  36.             
  37.             loss         += criterion(decoder_output, target_tensor[di])
  38.             decoder_input = target_tensor[di]  # Teacher forcing
  39.     else:
  40.         # Without teacher forcing: use its own predictions as the next input
  41.         for di in range(target_length):
  42.             decoder_output, decoder_hidden, decoder_attention = decoder(
  43.                 decoder_input, decoder_hidden, encoder_outputs)
  44.             
  45.             topv, topi    = decoder_output.topk(1)
  46.             decoder_input = topi.squeeze().detach()  # detach from history as input
  47.             loss         += criterion(decoder_output, target_tensor[di])
  48.             if decoder_input.item() == EOS_token:
  49.                 break
  50.     loss.backward()
  51.     encoder_optimizer.step()
  52.     decoder_optimizer.step()
  53.     return loss.item() / target_length
复制代码
在序列天生的使掷中,如机器翻译或文本天生,解码器(decoder)的输入通常是由解码器自己天生的预测结果,即前一个时间步的输出。然而,这种自回归方式大概存在一个问题,即在训练过程中,解码器大概会产生累积误差,并导致输出与目标序列渐渐偏离。
为了解决这个问题,引入了一种称为"Teacher Forcing"的技术。在训练过程中,Teacher Forcing将目标序列的真实值作为解码器的输入,而不是使用解码器自己的预测结果。这样可以提供更准确的引导信号,帮助解码器更快地学习到精确的输出。
在这段代码中,use_teacher_forcing变量用于确定解码器在训练阶段使用何种策略作为下一个输入。
当use_teacher_forcing为True时,接纳"Teacher Forcing"的策略,即将目标序列中的真实标签作为解码器的下一个输入。而当use_teacher_forcing为False时,接纳"Without Teacher Forcing"的策略,即将解码器自身的预测作为下一个输入。
使用use_teacher_forcing的目的是在训练过程中平衡解码器的预测能力和稳固性。以下是对两种策略的表明:
   

  • Teacher Forcing: 在每个时间步(di循环中),解码器的输入都是目标序列中的真实标签。这样做的好处是,解码器可以直接获得精确的输入信息,加速训练速度,并且在训练早期提供更准确的梯度信号,帮助解码器更好地学习。然而,过分依靠目标序列大概会导致模型过于敏感,一旦目标序列中出现错误,大概会在解码器中产生累积的误差。
  • Without Teacher Forcing: 在每个时间步,解码器的输入是前一个时间步的预测输出。这样做的好处是,解码器必要依靠自身的预测能力来天生下一个输入,从而更好地顺应真实应用场景中大概出现的输入变化。这种策略可以提高模型的稳固性,但大概会导致训练过程更加困难,特殊是在初始阶段。
  一样平常来说,Teacher Forcing策略在训练过程中可以帮助模型快速收敛,而Without Teacher Forcing策略则更接近真实应用中的天生场景。通常会使用一定比例的Teacher Forcing,在训练过程中渐渐减小这个比例,以便模型渐渐过渡到更自主的天生模式。
综上所述,通过使用use_teacher_forcing来选择不同的策略,可以在训练解码器时平衡模型的预测能力和稳固性,同时也提供了更灵活的天生模式选择。
   

  • topv, topi = decoder_output.topk(1)
    这一行代码使用.topk(1)函数从decoder_output中获取最大的元素及其对应的索引。decoder_output是一个张量(tensor),它包罗了解码器的输出结果,大概是一个概率分布或是其他的数值。.topk(1)函数将返回两个张量:topv和topi。topv是最大的元素值,而topi是对应的索引值。
  • decoder_input = topi.squeeze().detach() 这一行代码对topi进行处理处罚,以便作为下一个解码器的输入。起首,.squeeze()函数被调用,它的作用是去除张量中维度为1的维度,从而将topi的形状进行压缩。然后,.detach()函数被调用,它的作用是将张量从盘算图中分离出来,使得在后续的盘算中不会对该张量进行梯度盘算。最后,将处理处罚后的张量赋值给decoder_input,作为下一个解码器的输入。
  1. import time
  2. import math
  3. def asMinutes(s):
  4.     m = math.floor(s / 60)
  5.     s -= m * 60
  6.     return '%dm %ds' % (m, s)
  7. def timeSince(since, percent):
  8.     now = time.time()
  9.     s = now - since
  10.     es = s / (percent)
  11.     rs = es - s
  12.     return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
复制代码
  1. def trainIters(encoder,decoder,n_iters,print_every=1000,
  2.                plot_every=100,learning_rate=0.01):
  3.    
  4.     start = time.time()
  5.     plot_losses      = []
  6.     print_loss_total = 0  # Reset every print_every
  7.     plot_loss_total  = 0  # Reset every plot_every
  8.     encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
  9.     decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
  10.    
  11.     # 在 pairs 中随机选取 n_iters 条数据用作训练集
  12.     training_pairs    = [tensorsFromPair(random.choice(pairs)) for i in range(n_iters)]
  13.     criterion         = nn.NLLLoss()
  14.     for iter in range(1, n_iters + 1):
  15.         training_pair = training_pairs[iter - 1]
  16.         input_tensor  = training_pair[0]
  17.         target_tensor = training_pair[1]
  18.         loss = train(input_tensor, target_tensor, encoder,
  19.                      decoder, encoder_optimizer, decoder_optimizer, criterion)
  20.         print_loss_total += loss
  21.         plot_loss_total  += loss
  22.         if iter % print_every == 0:
  23.             print_loss_avg   = print_loss_total / print_every
  24.             print_loss_total = 0
  25.             print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
  26.                                          iter, iter / n_iters * 100, print_loss_avg))
  27.         if iter % plot_every == 0:
  28.             plot_loss_avg = plot_loss_total / plot_every
  29.             plot_losses.append(plot_loss_avg)
  30.             plot_loss_total = 0
  31.     return plot_losses
复制代码

  • 评估
  1. def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
  2.     with torch.no_grad():
  3.         input_tensor    = tensorFromSentence(input_lang, sentence)
  4.         input_length    = input_tensor.size()[0]
  5.         encoder_hidden  = encoder.initHidden()
  6.         encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
  7.         for ei in range(input_length):
  8.             encoder_output, encoder_hidden = encoder(input_tensor[ei],encoder_hidden)
  9.             encoder_outputs[ei]           += encoder_output[0, 0]
  10.         decoder_input  = torch.tensor([[SOS_token]], device=device)  # SOS
  11.         decoder_hidden = encoder_hidden
  12.         decoded_words  = []
  13.         decoder_attentions = torch.zeros(max_length, max_length)
  14.         for di in range(max_length):
  15.             decoder_output, decoder_hidden, decoder_attention = decoder(
  16.                 decoder_input, decoder_hidden, encoder_outputs)
  17.             
  18.             decoder_attentions[di] = decoder_attention.data
  19.             topv, topi             = decoder_output.data.topk(1)
  20.             
  21.             if topi.item() == EOS_token:
  22.                 decoded_words.append('<EOS>')
  23.                 break
  24.             else:
  25.                 decoded_words.append(output_lang.index2word[topi.item()])
  26.             decoder_input = topi.squeeze().detach()
  27.         return decoded_words, decoder_attentions[:di + 1]
复制代码
  1. def evaluateRandomly(encoder, decoder, n=5):
  2.     for i in range(n):
  3.         pair = random.choice(pairs)
  4.         print('>', pair[0])
  5.         print('=', pair[1])
  6.         output_words, attentions = evaluate(encoder, decoder, pair[0])
  7.         output_sentence = ' '.join(output_words)
  8.         print('<', output_sentence)
  9.         print('')
复制代码
四、训练与评估
  1. hidden_size   = 256
  2. encoder1      = EncoderRNN(input_lang.n_words, hidden_size).to(device)
  3. attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
  4. plot_losses   = trainIters(encoder1, attn_decoder1, 10000, print_every=5000)
复制代码
代码输出
  1. 6m 41s (- 6m 41s) (5000 50%) 2.8497
  2. 13m 28s (- 0m 0s) (10000 100%) 2.2939
复制代码
  1. evaluateRandomly(encoder1, attn_decoder1)
复制代码
代码输出
  1. > tu es en grave danger .
  2. = you re in serious danger .
  3. < you are the of . . <EOS>
  4. > il est parfait pour le poste .
  5. = he is just right for the job .
  6. < he is out to the . . <EOS>
  7. > je te quitte demain .
  8. = i m leaving you tomorrow .
  9. < i am glad to . . <EOS>
  10. > c est un auteur .
  11. = he s an author .
  12. < he s a good . <EOS>
  13. > nous sommes des prisonniers .
  14. = we re prisoners .
  15. < we re in . <EOS>
复制代码

  • Loss图
  1. import matplotlib.pyplot as plt
  2. #隐藏警告
  3. import warnings
  4. warnings.filterwarnings("ignore")               # 忽略警告信息
  5. # plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
  6. plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
  7. plt.rcParams['figure.dpi']         = 100        # 分辨率
  8. epochs_range = range(len(plot_losses))
  9. plt.figure(figsize=(8, 3))
  10. plt.subplot(1, 1, 1)
  11. plt.plot(epochs_range, plot_losses, label='Training Loss')
  12. plt.legend(loc='upper right')
  13. plt.title('Training Loss')
  14. plt.show()
复制代码
代码输出


  • 可视化注意力
  1. import matplotlib.pyplot as plt
  2. output_words, attentions = evaluate(encoder1, attn_decoder1, "je suis trop froid .")
  3. plt.matshow(attentions.numpy())
复制代码
代码输出
  1. <matplotlib.image.AxesImage at 0x1f912b9d600>
复制代码

  1. import matplotlib.ticker as ticker
  2. #隐藏警告
  3. import warnings
  4. warnings.filterwarnings("ignore")               # 忽略警告信息
  5. def showAttention(input_sentence, output_words, attentions):
  6.     # Set up figure with colorbar
  7.     fig = plt.figure()
  8.     ax = fig.add_subplot(111)
  9.     cax = ax.matshow(attentions.numpy(), cmap='bone')
  10.     fig.colorbar(cax)
  11.     # Set up axes
  12.     ax.set_xticklabels([''] + input_sentence.split(' ') +
  13.                        ['<EOS>'], rotation=90)
  14.     ax.set_yticklabels([''] + output_words)
  15.     # Show label at every tick
  16.     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
  17.     ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
  18.     plt.show()
  19. def evaluateAndShowAttention(input_sentence):
  20.     output_words, attentions = evaluate(
  21.         encoder1, attn_decoder1, input_sentence)
  22.     print('input =', input_sentence)
  23.     print('output =', ' '.join(output_words))
  24.     showAttention(input_sentence, output_words, attentions)
  25. evaluateAndShowAttention("elle a cinq ans de moins que moi .")
  26. evaluateAndShowAttention("elle est trop petit .")
  27. evaluateAndShowAttention("je ne crains pas de mourir .")
  28. evaluateAndShowAttention("c est un jeune directeur plein de talent .")
复制代码
代码输出(下面的内容全都是代码运行输出的结果)
  1. input = elle a cinq ans de moins que moi .
  2. output = she s taller than me than me me . .
复制代码

  1. input = elle est trop petit .
  2. output = she s too old . <EOS>
复制代码

  1. input = je ne crains pas de mourir .
  2. output = i m not going to . . . <EOS>
复制代码

  1. input = c est un jeune directeur plein de talent .
  2. output = he s a good at . . <EOS>
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

反转基因福娃

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表