ToB企服应用市场:ToB评测及商务社交产业平台

标题: python-pytorch编写transformer模型实现翻译0.5.00-训练与预测 [打印本页]

作者: 傲渊山岳    时间: 2024-6-21 05:35
标题: python-pytorch编写transformer模型实现翻译0.5.00-训练与预测
接上一篇文章
https://blog.csdn.net/m0_60688978/article/details/139359541?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22139359541%22%2C%22source%22%3A%22m0_60688978%22%7D
  1. sentences = [
  2.     ['咖哥 喜欢 小冰', 'KaGe likes XiaoBing'],
  3.     ['我 爱 学习 人工智能', 'I love studying AI'],
  4.     ['深度学习 改变 世界', ' DL changed the world'],
  5.     ['自然语言处理 很 强大', 'NLP is powerful'],
  6.     ['神经网络 非常 复杂', 'Neural-networks are complex'] ]
  7. class TranslationCorpus:
  8.     def __init__(self, sentences):
  9.         self.sentences = sentences
  10.         # 计算源语言和目标语言的最大句子长度,并分别加 1 和 2 以容纳填充符和特殊符号
  11.         self.src_len = max(len(sentence[0].split()) for sentence in sentences) + 1
  12.         self.tgt_len = max(len(sentence[1].split()) for sentence in sentences) + 2
  13.         # 创建源语言和目标语言的词汇表
  14.         self.src_vocab, self.tgt_vocab = self.create_vocabularies()
  15.         # 创建索引到单词的映射
  16.         self.src_idx2word = {v: k for k, v in self.src_vocab.items()}
  17.         self.tgt_idx2word = {v: k for k, v in self.tgt_vocab.items()}
  18.     # 定义创建词汇表的函数
  19.     def create_vocabularies(self):
  20.         # 统计源语言和目标语言的单词频率
  21.         src_counter = Counter(word for sentence in self.sentences for word in sentence[0].split())
  22.         tgt_counter = Counter(word for sentence in self.sentences for word in sentence[1].split())        
  23.         # 创建源语言和目标语言的词汇表,并为每个单词分配一个唯一的索引
  24.         src_vocab = {'<pad>': 0, **{word: i+1 for i, word in enumerate(src_counter)}}
  25.         tgt_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2,
  26.                      **{word: i+3 for i, word in enumerate(tgt_counter)}}        
  27.         return src_vocab, tgt_vocab
  28.     # 定义创建批次数据的函数
  29.     def make_batch(self, batch_size, test_batch=False):
  30.         input_batch, output_batch, target_batch = [], [], []
  31.         # 随机选择句子索引
  32.         sentence_indices = torch.randperm(len(self.sentences))[:batch_size]
  33.         for index in sentence_indices:
  34.             src_sentence, tgt_sentence = self.sentences[index]
  35.             # 将源语言和目标语言的句子转换为索引序列
  36.             src_seq = [self.src_vocab[word] for word in src_sentence.split()]
  37.             tgt_seq = [self.tgt_vocab['<sos>']] + [self.tgt_vocab[word] \
  38.                          for word in tgt_sentence.split()] + [self.tgt_vocab['<eos>']]            
  39.             # 对源语言和目标语言的序列进行填充
  40.             src_seq += [self.src_vocab['<pad>']] * (self.src_len - len(src_seq))
  41.             tgt_seq += [self.tgt_vocab['<pad>']] * (self.tgt_len - len(tgt_seq))            
  42.             # 将处理好的序列添加到批次中
  43.             input_batch.append(src_seq)
  44.             output_batch.append([self.tgt_vocab['<sos>']] + ([self.tgt_vocab['<pad>']] * \
  45.                                     (self.tgt_len - 2)) if test_batch else tgt_seq[:-1])
  46.             target_batch.append(tgt_seq[1:])        
  47.           # 将批次转换为 LongTensor 类型
  48.         input_batch = torch.LongTensor(input_batch)
  49.         output_batch = torch.LongTensor(output_batch)
  50.         target_batch = torch.LongTensor(target_batch)            
  51.         return input_batch, output_batch, target_batch
  52. # 创建语料库类实例
  53. corpus = TranslationCorpus(sentences)
  54. #训练
  55. import torch # 导入 torch
  56. import torch.optim as optim # 导入优化器
  57. model = Transformer(corpus) # 创建模型实例
  58. criterion = nn.CrossEntropyLoss() # 损失函数
  59. optimizer = optim.Adam(model.parameters(), lr=0.00001) # 优化器
  60. epochs = 1 # 训练轮次
  61. for epoch in range(epochs): # 训练 100 轮
  62.     optimizer.zero_grad() # 梯度清零
  63.     enc_inputs, dec_inputs, target_batch = corpus.make_batch(batch_size) # 创建训练数据   
  64.     print(enc_inputs, dec_inputs, target_batch)
  65.    
  66.    
  67.     outputs, _, _, _ = model(enc_inputs, dec_inputs) # 获取模型输出
  68.     loss = criterion(outputs.view(-1, len(corpus.tgt_vocab)), target_batch.view(-1)) # 计算损失
  69.     if (epoch + 1) % 1 == 0: # 打印损失
  70.         print(f"Epoch: {epoch + 1:04d} cost = {loss:.6f}")
  71.     loss.backward()# 反向传播        
  72.     optimizer.step()# 更新参数
  73. #预测
  74. # 创建一个大小为 1 的批次,目标语言序列 dec_inputs 在测试阶段,仅包含句子开始符号 <sos>
  75. enc_inputs, dec_inputs, target_batch = corpus.make_batch(batch_size=1,test_batch=True)
  76. # enc_inputs=torch.tensor([[14, 15, 16,  0,  0]])
  77. dec_inputs=torch.tensor([[1, 0, 0,  0,  0]])
  78. outt=1
  79. for i in range(5):
  80.     dec_inputs[0][i]=outt
  81.     print("+++",i,dec_inputs[0][i],dec_inputs,outt)
  82.     predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs) # 用模型进行翻译
  83.     predict = predict.view(-1, len(corpus.tgt_vocab)) # 将预测结果维度重塑
  84.     predict = predict.data.max(1, keepdim=True)[1] # 找到每个位置概率最大的词汇的索引
  85.     print(predict)
  86.     outt=predict[i].item()
  87.    
  88. print("编码器输入 :", enc_inputs) # 打印编码器输入
  89. print("解码器输入 :", dec_inputs) # 打印解码器输入
  90. print("目标数据 :", target_batch) # 打印目标数据
  91. predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs) # 用模型进行翻译
  92. print(predict.data.max(-1))
  93. predict = predict.view(-1, len(corpus.tgt_vocab)) # 将预测结果维度重塑
  94. predict = predict.data.max(1, keepdim=True)[1] # 找到每个位置概率最大的词汇的索引
  95. # 解码预测的输出,将所预测的目标句子中的索引转换为单词
  96. translated_sentence = [corpus.tgt_idx2word[idx.item()] for idx in predict.squeeze()]
  97. # 将输入的源语言句子中的索引转换为单词
  98. input_sentence = ' '.join([corpus.src_idx2word[idx.item()] for idx in enc_inputs[0]])
  99. print(input_sentence, '->', translated_sentence) # 打印原始句子和翻译后的句子
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4