Title: Week N11: seq2seq Translation in Practice - PyTorch Reproduction
Author: 反转基因福娃
Date: 2024-8-24 12:57
Task:
● Add an attention mechanism to the decoder
I. Preliminaries
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
Code output:
cpu
Building the language class
SOS_token = 0
EOS_token = 1

# Language class for convenient manipulation of the corpus
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
Text processing functions
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase the text, trim whitespace, and strip everything except letters and .!?
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s
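As a quick sanity check, normalizeString strips accents, lowercases the text, pads punctuation with a space, and collapses everything else to single spaces. A small illustrative call (the input sentence is made up for this example):

print(normalizeString("Je suis très heureux!"))
# prints something like: je suis tres heureux !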
File reading function
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file line by line
    lines = open('N11/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Each line becomes a list with two elements:
    # the text in language A and the text in language B
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Create Lang instances and decide whether to reverse the language order
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs
.startswith(eng_prefixes) is a call to the string method startswith(). It checks whether a string begins with the specified prefix; when it is given a tuple, as here, it returns True if the string starts with any of the prefixes in the tuple.
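A minimal illustration of the tuple form (the prefixes below are shortened purely for the example):

demo_prefixes = ("i am ", "he is")
print("i am hungry .".startswith(demo_prefixes))   # True
print("we are here .".startswith(demo_prefixes))   # False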
MAX_LENGTH = 10  # Maximum sentence length kept in the corpus

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)

def filterPairs(pairs):
    # Keep only the pairs whose English side starts with one of eng_prefixes
    return [pair for pair in pairs if filterPair(pair)]
def prepareData(lang1, lang2, reverse=False):
    # Read the data from the file
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))

    # Filter the corpus according to the conditions above
    pairs = filterPairs(pairs[:])
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")

    # Register the corpus with the corresponding language classes
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    # Print the language class statistics
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))
Code output:
Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words...
Counted words:
fra 4345
eng 2803
['je volerai vers la lune .', 'i m going to fly to the moon .']
II. The Seq2Seq Model
Encoder
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
Decoder
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
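For readers who want to verify the tensor shapes, here is a one-step shape check of the attention decoder. It is not part of the original post; the demo_* names and the hidden size of 256 are chosen only for illustration, and it assumes prepareData has already been run so that output_lang exists.

# Illustrative one-step shape check of the attention decoder (not in the original post)
demo_decoder = AttnDecoderRNN(hidden_size=256, output_size=output_lang.n_words).to(device)
demo_input = torch.tensor([[SOS_token]], device=device)            # a single SOS token
demo_hidden = demo_decoder.initHidden()                             # shape (1, 1, 256)
demo_encoder_outputs = torch.zeros(MAX_LENGTH, 256, device=device)  # placeholder encoder outputs
out, hid, attn = demo_decoder(demo_input, demo_hidden, demo_encoder_outputs)
print(out.shape, hid.shape, attn.shape)  # (1, vocab_size), (1, 1, 256), (1, MAX_LENGTH)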
III. Training
Data preprocessing
# Convert text to numbers: look up each word's index
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

# Turn the indexed text into a tensor
def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

# Given a text pair, return the preprocessed tensors
def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)
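A quick check of what the preprocessing produces (the names sample_input/sample_target are illustrative and the printed sizes depend on the randomly drawn pair): each tensor has shape (sequence_length, 1) with the EOS token appended.

sample_input, sample_target = tensorsFromPair(random.choice(pairs))
print(sample_input.shape, sample_target.shape)  # e.g. torch.Size([6, 1]) torch.Size([7, 1])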
Training function
teacher_forcing_ratio = 0.5

def train(input_tensor, target_tensor,
          encoder, decoder,
          encoder_optimizer, decoder_optimizer,
          criterion, max_length=MAX_LENGTH):

    # Initialize the encoder hidden state
    encoder_hidden = encoder.initHidden()

    # Zero the gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # All-zero tensor of the given size to hold the encoder outputs
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # Feed the preprocessed input into the encoder
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    # Default decoder input: the SOS token
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # Feed the encoder outputs into the decoder
    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing
    else:
        # Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input
            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
In sequence generation tasks such as machine translation or text generation, the decoder's input is normally what the decoder itself produced at the previous time step. This autoregressive approach has a drawback: during training, the decoder's errors can accumulate, so the output gradually drifts away from the target sequence.
To address this, a technique called "Teacher Forcing" is used. During training, Teacher Forcing feeds the ground-truth target sequence to the decoder as its input instead of the decoder's own predictions. This provides a more accurate guidance signal and helps the decoder learn the correct outputs faster.
In this code, the use_teacher_forcing variable decides which strategy the decoder uses for its next input during training.
When use_teacher_forcing is True, the "Teacher Forcing" strategy is used: the true label from the target sequence becomes the decoder's next input. When use_teacher_forcing is False, the "Without Teacher Forcing" strategy is used: the decoder's own prediction becomes its next input.
The purpose of use_teacher_forcing is to balance the decoder's predictive ability and its stability during training. The two strategies compare as follows:
Teacher Forcing: at every time step (the di loop), the decoder's input is the true label from the target sequence. The advantage is that the decoder receives correct inputs directly, which speeds up training and provides more accurate gradient signals early on, helping the decoder learn. However, relying too heavily on the ground truth can make the model fragile: once it has to continue from its own, possibly wrong, outputs, errors can accumulate inside the decoder.
Without Teacher Forcing: at every time step, the decoder's input is its own prediction from the previous step. The advantage is that the decoder must rely on its own predictive ability to generate the next input, which better matches the conditions it will face in real applications and makes the model more robust. The downside is that training can become harder, especially in the early stages.
In general, Teacher Forcing helps the model converge quickly during training, while Without Teacher Forcing is closer to real generation at inference time. A common practice is to apply Teacher Forcing with some probability and gradually reduce that probability over training, so the model transitions to generating more autonomously.
In summary, selecting between the two strategies via use_teacher_forcing balances the decoder's predictive ability and stability during training, while also offering more flexible generation behavior.
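The code above keeps teacher_forcing_ratio fixed at 0.5. A minimal sketch of the gradual decay mentioned above might look like the following; the linear schedule and the names scheduled_teacher_forcing_ratio, start_ratio and end_ratio are hypothetical and not part of the original code.

# Hypothetical linear decay of the teacher forcing ratio over training iterations
def scheduled_teacher_forcing_ratio(iteration, n_iters, start_ratio=1.0, end_ratio=0.0):
    progress = iteration / max(1, n_iters)                 # fraction of training completed
    return start_ratio + (end_ratio - start_ratio) * progress

# Inside the training loop one could then write, for example:
# use_teacher_forcing = random.random() < scheduled_teacher_forcing_ratio(iter, n_iters)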
topv, topi = decoder_output.topk(1)
This line uses .topk(1) to extract the largest element of decoder_output together with its index. decoder_output is a tensor holding the decoder's output, here a log-probability distribution over the output vocabulary (the result of log_softmax). .topk(1) returns two tensors: topv contains the largest value and topi contains the corresponding index.
decoder_input = topi.squeeze().detach() processes topi so it can serve as the next decoder input. First, .squeeze() removes the dimensions of size 1, compressing topi's shape. Then, .detach() separates the tensor from the computation graph, so no gradients are computed through it in later steps. Finally, the processed tensor is assigned to decoder_input and used as the decoder's next input.
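A small, self-contained illustration of these two calls (the scores below are arbitrary):

scores = torch.tensor([[0.1, 2.5, -0.3]])     # shape (1, 3): pretend vocabulary of 3 words
topv, topi = scores.topk(1)                    # topv = tensor([[2.5]]), topi = tensor([[1]])
next_input = topi.squeeze().detach()           # scalar tensor(1), detached from the graph
print(topv, topi, next_input)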
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
def trainIters(encoder, decoder, n_iters, print_every=1000,
               plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0   # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # Randomly pick n_iters pairs from `pairs` as the training set
    training_pairs = [tensorsFromPair(random.choice(pairs)) for i in range(n_iters)]
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    return plot_losses
Evaluation
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]
def evaluateRandomly(encoder, decoder, n=5):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')
IV. Training and Evaluation
hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
plot_losses = trainIters(encoder1, attn_decoder1, 10000, print_every=5000)
Code output:
6m 41s (- 6m 41s) (5000 50%) 2.8497
13m 28s (- 0m 0s) (10000 100%) 2.2939
evaluateRandomly(encoder1, attn_decoder1)
Code output:
> tu es en grave danger .
= you re in serious danger .
< you are the of . . <EOS>
> il est parfait pour le poste .
= he is just right for the job .
< he is out to the . . <EOS>
> je te quitte demain .
= i m leaving you tomorrow .
< i am glad to . . <EOS>
> c est un auteur .
= he s an author .
< he s a good . <EOS>
> nous sommes des prisonniers .
= we re prisoners .
< we re in . <EOS>
Loss plot
import matplotlib.pyplot as plt
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")              # Ignore warning messages
# plt.rcParams['font.sans-serif'] = ['SimHei'] # Needed to display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False     # Display minus signs correctly
plt.rcParams['figure.dpi'] = 100               # Resolution

epochs_range = range(len(plot_losses))

plt.figure(figsize=(8, 3))

plt.subplot(1, 1, 1)
plt.plot(epochs_range, plot_losses, label='Training Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')
plt.show()
Code output: training loss curve (figure not included in the text).
Visualizing attention
import matplotlib.pyplot as plt
output_words, attentions = evaluate(encoder1, attn_decoder1, "je suis trop froid .")
plt.matshow(attentions.numpy())
Code output:
<matplotlib.image.AxesImage at 0x1f912b9d600>
import matplotlib.ticker as ticker
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")  # Ignore warning messages

def showAttention(input_sentence, output_words, attentions):
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()

def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(
        encoder1, attn_decoder1, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions)

evaluateAndShowAttention("elle a cinq ans de moins que moi .")
evaluateAndShowAttention("elle est trop petit .")
evaluateAndShowAttention("je ne crains pas de mourir .")
evaluateAndShowAttention("c est un jeune directeur plein de talent .")
Code output (everything below is produced by running the code; the attention heatmaps are shown as figures and are not reproduced here):
input = elle a cinq ans de moins que moi .
output = she s taller than me than me me . .
input = elle est trop petit .
output = she s too old . <EOS>
input = je ne crains pas de mourir .
output = i m not going to . . . <EOS>
input = c est un jeune directeur plein de talent .
output = he s a good at . . <EOS>