马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
<hr> 前言
Seq2Seq 模型的核心头脑是将一个输入序列(例如英语句子)通过编码器(Encoder)转化为一个固定长度的上下文向量,再由解码器(Decoder)根据该向量生成目标序列(例如法语句子)。这种编码-解码的架构最初由 RNN 实现,厥后发展出 LSTM 和 Transformer 等变种。在本文中,我们将聚焦于基于 RNN 的经典实现,并通过 PyTorch 代码逐步拆解其关键组件。
本文的代码来源于一个完整的呆板翻译使命示例,数据集为英语-法语翻译对。我们将从数据加载与预处理开始,逐步构建编码器和解码器,末了通过 BLEU 分数评估翻译效果。全部代码都经过解释,确保易于明确,同时保留了附件中的完整性。
让我们开始吧!
<hr> 一、数据加载与预处理
Seq2Seq 模型的第一步是准备数据。我们需要将原始的英语-法语翻译对数据加载到内存中,并对其进行预处理和词元化(tokenization),以便后续输入到模型中。以下是相关代码及其解释:
1.1 读取数据
- from collections import Counter # 用于词频统计
- import torch # PyTorch 核心库
- from torch.utils import data # PyTorch 数据加载工具
- import numpy as np # NumPy 用于数组操作
- def read_data_nmt():
- """
- 载入“英语-法语”数据集
-
- 返回值:
- str: 文件内容的完整字符串
- """
- with open('fra.txt', 'r', encoding='utf-8') as f:
- return f.read()
复制代码 read_data_nmt 函数简单地读取名为 fra.txt 的文件,该文件包含英语和法语的翻译对,每行以制表符分隔。它返回整个文件的字符串内容,为后续处理奠定底子。
1.2 预处理数据
- def preprocess_nmt(text):
- """
- 预处理“英语-法语”数据集
-
- 参数:
- text (str): 输入的原始文本字符串
-
- 返回值:
- str: 处理后的文本字符串
- """
- def no_space(char, prev_char):
- """
- 判断当前字符是否需要前置空格
- """
- return char in set(',.!?') and prev_char != ' '
- # 使用空格替换不间断空格(\u202f)和非断行空格(\xa0),并转换为小写
- text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
-
- # 在单词和标点符号之间插入空格
- out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
- for i, char in enumerate(text)]
-
- return ''.join(out)
复制代码 preprocess_nmt 函数对文本进行标准化处理:
- 将特殊空格字符替换为平凡空格,并将全部字符转换为小写。
- 在标点符号(如逗号、句号)前插入空格,便于后续按空格分割词元。这种处理确保标点符号被视为独立的词元,而不是粘附在单词上。
1.3 词元化
- def tokenize_nmt(text, num_examples=None):
- """
- 词元化“英语-法语”数据集
-
- 参数:
- text (str): 输入的文本字符串,每行包含英语和法语句子,用制表符分隔
- num_examples (int, optional): 最大处理样本数,默认值为 None 表示处理全部
-
- 返回值:
- tuple: 包含两个列表的元组
- - source (list): 英语句子词元列表
- - target (list): 法语句子词元列表
- """
- source, target = [], []
- for i, line in enumerate(text.split('\n')):
- if num_examples and i > num_examples:
- break
- parts = line.split('\t')
- if len(parts) == 2:
- source.append(parts[0].split(' '))
- target.append(parts[1].split(' '))
- return source, target
复制代码 tokenize_nmt 函数将预处理后的文本按行分割,并进一步将每行按制表符分为英语和法语部分,然后按空格分割成词元列表。它返回两个列表:source(英语词元列表)和 target(法语词元列表)。
1.4 词频统计
- def count_corpus(tokens):
- """
- 统计词元的频率
-
- 参数:
- tokens: 词元列表,可以是一维或二维列表
-
- 返回值:
- Counter: Counter 对象,统计每个词元的出现次数
- """
- if not tokens:
- return Counter()
- if isinstance(tokens[0], list):
- flattened_tokens = [token for sublist in tokens for token in sublist]
- else:
- flattened_tokens = tokens
- return Counter(flattened_tokens)
复制代码 count_corpus 函数使用 Counter 类统计词元的出现频率,支持一维和二维列表输入。它是构建词汇表的底子工具。
1.5 构建词汇表
- class Vocab:
- """文本词表类,用于管理词元及其索引的映射关系"""
- def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
- """初始化词表"""
- self.tokens = tokens if tokens is not None else []
- self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
- counter = self._count_corpus(self.tokens)
- self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
- self.idx_to_token = ['<unk>'] + self.reserved_tokens
- self.token_to_idx = {
- token: idx for idx, token in enumerate(self.idx_to_token)}
- for token, freq in self._token_freqs:
- if freq < min_freq:
- break
- if token not in self.token_to_idx:
- self.idx_to_token.append(token)
- self.token_to_idx[token] = len(self.idx_to_token) - 1
- @staticmethod
- def _count_corpus(tokens):
- """统计词元频率"""
- if not tokens:
- return Counter()
- if isinstance(tokens[0], list):
- tokens = [token for sublist in tokens for token in sublist]
- return Counter(tokens)
- def __len__(self):
- return len(self.idx_to_token)
- def __getitem__(self, tokens):
- if not isinstance(tokens, (list, tuple)):
- return self.token_to_idx.get(tokens, self.unk)
- return [self[token] for token in tokens]
- def to_tokens(self, indices):
- if not isinstance(indices, (list, tuple)):
- return self.idx_to_token[indices]
- return [self.idx_to_token[index] for index in indices]
- @property
- def unk(self):
- return 0
- @property
- def token_freqs(self):
- return self._token_freqs
复制代码 Vocab 类用于构建词汇表并管理词元与索引之间的映射:
- 初始化时接受词元列表、最小频率阈值和预留特殊词元(如 <pad>、<bos>、<eos>)。
- 内部使用 count_corpus 统计词频,并按频率排序。
- 提供 __getitem__ 和 to_tokens 方法,分别用于词元到索引和索引到词元的转换。
- <unk> 表示未知词元,默认索引为 0。
1.6 截断与填充
- def truncate_pad(line, num_steps, padding_token):
- """
- 截断或填充文本序列
-
- 参数:
- line (list): 输入的文本序列(词元列表)
- num_steps (int): 目标序列长度
- padding_token (str): 用于填充的标记
-
- 返回值:
- list: 截断或填充后的序列,长度为 num_steps
- """
- if len(line) > num_steps:
- return line[:num_steps]
- return line + [padding_token] * (num_steps - len(line))
复制代码 truncate_pad 函数确保全部序列长度同等:
- 假如序列长度凌驾 num_steps,则截断。
- 假如不足,则用 padding_token(通常是 <pad>)填充。
1.7 转换为张量
- def build_array_nmt(lines, vocab, num_steps):
- """
- 将机器翻译的文本序列转换为小批量
-
- 参数:
- lines (list): 文本序列列表,每个元素是一个词元列表
- vocab (dict): 词汇表,将词元映射为索引
- num_steps (int): 目标序列长度
-
- 返回值:
- tuple: 包含两个元素的元组
- - array (torch.Tensor): 转换后的张量,形状为 (样本数, num_steps)
- - valid_len (np.ndarray): 每个序列的有效长度,形状为 (样本数,)
- """
- lines =
复制代码 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |