从代码学习深度学习 - 序列到序列学习 GRU编解码器 PyTorch 版 ...

打印 上一主题 下一主题

主题 1508|帖子 1508|积分 4524

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
<hr> 前言

Seq2Seq 模型的核心头脑是将一个输入序列(例如英语句子)通过编码器(Encoder)转化为一个固定长度的上下文向量,再由解码器(Decoder)根据该向量生成目标序列(例如法语句子)。这种编码-解码的架构最初由 RNN 实现,厥后发展出 LSTM 和 Transformer 等变种。在本文中,我们将聚焦于基于 RNN 的经典实现,并通过 PyTorch 代码逐步拆解其关键组件。
本文的代码来源于一个完整的呆板翻译使命示例,数据集为英语-法语翻译对。我们将从数据加载与预处理开始,逐步构建编码器和解码器,末了通过 BLEU 分数评估翻译效果。全部代码都经过解释,确保易于明确,同时保留了附件中的完整性。
让我们开始吧!
<hr> 一、数据加载与预处理

Seq2Seq 模型的第一步是准备数据。我们需要将原始的英语-法语翻译对数据加载到内存中,并对其进行预处理和词元化(tokenization),以便后续输入到模型中。以下是相关代码及其解释:
1.1 读取数据

  1. from collections import Counter  # 用于词频统计
  2. import torch  # PyTorch 核心库
  3. from torch.utils import data  # PyTorch 数据加载工具
  4. import numpy as np  # NumPy 用于数组操作
  5. def read_data_nmt():
  6.     """
  7.     载入“英语-法语”数据集
  8.    
  9.     返回值:
  10.         str: 文件内容的完整字符串
  11.     """
  12.     with open('fra.txt', 'r', encoding='utf-8') as f:
  13.         return f.read()
复制代码
read_data_nmt 函数简单地读取名为 fra.txt 的文件,该文件包含英语和法语的翻译对,每行以制表符分隔。它返回整个文件的字符串内容,为后续处理奠定底子。
1.2 预处理数据

  1. def preprocess_nmt(text):
  2.     """
  3.     预处理“英语-法语”数据集
  4.    
  5.     参数:
  6.         text (str): 输入的原始文本字符串
  7.    
  8.     返回值:
  9.         str: 处理后的文本字符串
  10.     """
  11.     def no_space(char, prev_char):
  12.         """
  13.         判断当前字符是否需要前置空格
  14.         """
  15.         return char in set(',.!?') and prev_char != ' '
  16.     # 使用空格替换不间断空格(\u202f)和非断行空格(\xa0),并转换为小写
  17.     text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
  18.    
  19.     # 在单词和标点符号之间插入空格
  20.     out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
  21.            for i, char in enumerate(text)]
  22.            
  23.     return ''.join(out)
复制代码
preprocess_nmt 函数对文本进行标准化处理:

  • 将特殊空格字符替换为平凡空格,并将全部字符转换为小写。
  • 在标点符号(如逗号、句号)前插入空格,便于后续按空格分割词元。这种处理确保标点符号被视为独立的词元,而不是粘附在单词上。
1.3 词元化

  1. def tokenize_nmt(text, num_examples=None):
  2.     """
  3.     词元化“英语-法语”数据集
  4.    
  5.     参数:
  6.         text (str): 输入的文本字符串,每行包含英语和法语句子,用制表符分隔
  7.         num_examples (int, optional): 最大处理样本数,默认值为 None 表示处理全部
  8.    
  9.     返回值:
  10.         tuple: 包含两个列表的元组
  11.             - source (list): 英语句子词元列表
  12.             - target (list): 法语句子词元列表
  13.     """
  14.     source, target = [], []
  15.     for i, line in enumerate(text.split('\n')):
  16.         if num_examples and i > num_examples:
  17.             break
  18.         parts = line.split('\t')
  19.         if len(parts) == 2:
  20.             source.append(parts[0].split(' '))
  21.             target.append(parts[1].split(' '))
  22.     return source, target
复制代码
tokenize_nmt 函数将预处理后的文本按行分割,并进一步将每行按制表符分为英语和法语部分,然后按空格分割成词元列表。它返回两个列表:source(英语词元列表)和 target(法语词元列表)。
1.4 词频统计

  1. def count_corpus(tokens):
  2.     """
  3.     统计词元的频率
  4.    
  5.     参数:
  6.         tokens: 词元列表,可以是一维或二维列表
  7.    
  8.     返回值:
  9.         Counter: Counter 对象,统计每个词元的出现次数
  10.     """
  11.     if not tokens:
  12.         return Counter()
  13.     if isinstance(tokens[0], list):
  14.         flattened_tokens = [token for sublist in tokens for token in sublist]
  15.     else:
  16.         flattened_tokens = tokens
  17.     return Counter(flattened_tokens)
复制代码
count_corpus 函数使用 Counter 类统计词元的出现频率,支持一维和二维列表输入。它是构建词汇表的底子工具。
1.5 构建词汇表

  1. class Vocab:
  2.     """文本词表类,用于管理词元及其索引的映射关系"""
  3.     def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
  4.         """初始化词表"""
  5.         self.tokens = tokens if tokens is not None else []
  6.         self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
  7.         counter = self._count_corpus(self.tokens)
  8.         self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
  9.         self.idx_to_token = ['<unk>'] + self.reserved_tokens
  10.         self.token_to_idx = {
  11.    token: idx for idx, token in enumerate(self.idx_to_token)}
  12.         for token, freq in self._token_freqs:
  13.             if freq < min_freq:
  14.                 break
  15.             if token not in self.token_to_idx:
  16.                 self.idx_to_token.append(token)
  17.                 self.token_to_idx[token] = len(self.idx_to_token) - 1
  18.     @staticmethod
  19.     def _count_corpus(tokens):
  20.         """统计词元频率"""
  21.         if not tokens:
  22.             return Counter()
  23.         if isinstance(tokens[0], list):
  24.             tokens = [token for sublist in tokens for token in sublist]
  25.         return Counter(tokens)
  26.     def __len__(self):
  27.         return len(self.idx_to_token)
  28.     def __getitem__(self, tokens):
  29.         if not isinstance(tokens, (list, tuple)):
  30.             return self.token_to_idx.get(tokens, self.unk)
  31.         return [self[token] for token in tokens]
  32.     def to_tokens(self, indices):
  33.         if not isinstance(indices, (list, tuple)):
  34.             return self.idx_to_token[indices]
  35.         return [self.idx_to_token[index] for index in indices]
  36.     @property
  37.     def unk(self):
  38.         return 0
  39.     @property
  40.     def token_freqs(self):
  41.         return self._token_freqs
复制代码
Vocab 类用于构建词汇表并管理词元与索引之间的映射:


  • 初始化时接受词元列表、最小频率阈值和预留特殊词元(如 <pad>、<bos>、<eos>)。
  • 内部使用 count_corpus 统计词频,并按频率排序。
  • 提供 __getitem__ 和 to_tokens 方法,分别用于词元到索引和索引到词元的转换。
  • <unk> 表示未知词元,默认索引为 0。
1.6 截断与填充

  1. def truncate_pad(line, num_steps, padding_token):
  2.     """
  3.     截断或填充文本序列
  4.    
  5.     参数:
  6.         line (list): 输入的文本序列(词元列表)
  7.         num_steps (int): 目标序列长度
  8.         padding_token (str): 用于填充的标记
  9.    
  10.     返回值:
  11.         list: 截断或填充后的序列,长度为 num_steps
  12.     """
  13.     if len(line) > num_steps:
  14.         return line[:num_steps]
  15.     return line + [padding_token] * (num_steps - len(line))
复制代码
truncate_pad 函数确保全部序列长度同等:


  • 假如序列长度凌驾 num_steps,则截断。
  • 假如不足,则用 padding_token(通常是 <pad>)填充。
1.7 转换为张量

  1. def build_array_nmt(lines, vocab, num_steps):
  2.     """
  3.     将机器翻译的文本序列转换为小批量
  4.    
  5.     参数:
  6.         lines (list): 文本序列列表,每个元素是一个词元列表
  7.         vocab (dict): 词汇表,将词元映射为索引
  8.         num_steps (int): 目标序列长度
  9.    
  10.     返回值:
  11.         tuple: 包含两个元素的元组
  12.             - array (torch.Tensor): 转换后的张量,形状为 (样本数, num_steps)
  13.             - valid_len (np.ndarray): 每个序列的有效长度,形状为 (样本数,)
  14.     """
  15.     lines =
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

干翻全岛蛙蛙

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表