呆板学习周记(第四十二周:AT-LSTM)2024.6.3~2024.6.9

打印 上一主题 下一主题

主题 668|帖子 668|积分 2004

择要

本周阅读了题为Water Quality Prediction Based on LSTM and Attention Mechanism: A Case Study of the Burnett River, Australia的论文。这项工作提出了一种基于长期短期记忆的神经网络和 注意力机制的肴杂模子——AT-LSTM。其中,LSTM缺乏对子窗口特征进行差别程度关注的能力,这大概会导致一些干系信息被忽略,无法重视时间序列的紧张特征。该文应用注意力机制来有用捕获更远的关键信息,并通过在每个时间步对隐蔽层元素进行加权来增强紧张特征对预测模子的影响。如许就提出了一个集成模子并有用地得到了统计特征。基于真实数据的实验证明,注意力机制的加入进步了 LSTM 模子的预测性能。
Abstract

This week read the paper titled Water Quality Prediction Based on LSTM and Attention Mechanism: A Case Study of the Burnett River, Australia. This work proposes a hybrid model of neural network and attention mechanism based on long-term short-term memory, AT-LSTM. Among them, LSTM lacks the ability to pay attention to the features of sub-windows to varying degrees, which may lead to some relevant information being ignored and unable to pay attention to the important features of the time series. In this paper, the attention mechanism is applied to effectively capture the key information in the distance, and the influence of important features on the prediction model is enhanced by weighting the hidden layer elements at each time step. In this way, an ensemble model is proposed and statistical features are effectively obtained. Experiments based on real data show that the addition of attention mechanism improves the prediction performance of the LSTM model.
一、文献阅读

1. 题目

标题:Water Quality Prediction Based on LSTM and Attention Mechanism: A Case Study of the Burnett River, Australia
作者:Honglei Chen, etc.
期刊名:Sustainability
链接:https://doi.org/10.3390/su142013231
2. abstract

该研究旨在开发是非期记忆(LSTM)网络及其基于注意力的(AT-LSTM)模子,以实现澳大利亚伯内特河水质的预测。该研究开发的模子在对伯内特河断面水质数据进行特征提取后,考虑差别时间序列对预测结果的影响,引入注意力机制,增强关键特征对预测结果的影响。该研究利用 LSTM 和 AT-LSTM 模子对伯内特河的溶解氧 (DO) 进行一步预报和多步预报,并对结果进行比力。研究结果表明,注意力机制的加入进步了 LSTM 模子的预测性能。因此,本研究开发的基于 AT-LSTM 的水质预测模子显示出比 LSTM 模子更强的能力,可以为澳大利亚昆士兰州水质改善计划提供准确预测伯内特河水质的能力。
The present study aims to develop a long short-term memory (LSTM) network and its attention-based (AT-LSTM) model to achieve the prediction of water quality in the Burnett River of Australia. The models developed in this study introduced an attention mechanism after feature extraction of water quality data in the section of Burnett River considering the effect of the sequences on the prediction results at different moments to enhance the influence of key features on the prediction results. This study provides one-step-ahead forecasting and multistep forward forecasting of dissolved oxygen (DO) of the Burnett River utilizing LSTM and AT-LSTM models and the comparison of the results. The research outcomes demonstrated that the inclusion of the attention mechanism improves the prediction performance of the LSTM model. Therefore, the AT-LSTM-based water quality forecasting model, developed in this study, demonstrated its stronger capability than the LSTM model for informing the Water Quality Improvement Plan of Queensland, Australia, to accurately predict water quality in the Burnett River.
3. 网络架构

该文提出的模子改进了编码器-解码器网络结构,使其能够更好地适应多步时间序列预测。同时,为相识决时间序列数据的降噪问题,该工作采用SG滤波器对原始数据进行去噪。G 滤波器可以有用保存时间序列的特征,并去除其噪声。同时,联合基于LSTM的编码器-解码器模子,模子显着进步了多步预测的准确性。
3.1 LSTM

Long Short-term Memory的结构:


  • LSTM中的memory存储空间是受网络控制的
  • 输入门/写门(input gate):当收到对应的网络网络信号时,才可以向空间中写入信息
  • 记忆门/读门(output gate):收到信号后,才能从空间中读取信息
  • 忘记门(forget gate):收到信号后,会选择性的删除存储空间中的信息
这个结构的记忆是相对短时的,故为short-term;而不但仅像RNN中仅保存上次输入的记忆,故Long Short-term;同时需要forget gate删除一些信息来保持信息的有用性,故不为long-term
LSTM的运行逻辑如下


  • 通常信号控制为sigmoid function,如许可以保证数值分布在0-1
  • 假设输入为                                        g                            (                            z                            )                                  g(z)                     g(z),输入门为                                        f                            (                                       z                               i                                      )                                  f(z_i)                     f(zi​),memory存储空间中为c,忘记门为                                        f                            (                                       z                               f                                      )                                  f(z_f)                     f(zf​)
  • 在上述情况下,存储空间中的值                                                   c                               ′                                      =                            g                            (                            z                            )                            f                            (                                       z                               i                                      )                            +                            c                            f                            (                                       z                               f                                      )                                  c'=g(z)f(z_i)+cf(z_f)                     c′=g(z)f(zi​)+cf(zf​)
  • 若输出门为                                        f                            (                                       z                               0                                      )                                  f(z_0)                     f(z0​),则                                        a                            =                            h                            (                                       c                               ′                                      )                            f                            (                                       z                               0                                      )                                  a=h(c')f(z_0)                     a=h(c′)f(z0​)
LSTM总体框架扼要概括如下,下图上半部分为LSTM的结构,后半部分为各个门对应盘算

3.2 注意力机制概述

注意力机制的扼要框架概括如下图

紧张用到的方法是dot-product
在网络中利用dot-product盘算干系性的流程如下,假设要查询                                             a                            1                                       a^1                  a1与其他向量的干系性


  • 首先,盘算                                        q                            u                            e                            r                            y                                  query                     query向量                                                   q                               1                                      =                                       W                               q                                                 a                               1                                            q^1=W^qa^1                     q1=Wqa1,之后盘算                                        k                            e                            y                                  key                     key向量                                                   k                               i                                      =                                       W                               k                                                 a                               i                                            k^i=W^ka^i                     ki=Wkai,                                                   a                               i                                            a^i                     ai为输入序列中的全部向量。
  • 其次,盘算                                        a                            t                            t                            e                            n                            t                            i                            o                            n                                                         s                            c                            o                            r                            e                                  attention\ score                     attention score,若查询向量对应                                                   a                               1                                            a^1                     a1、关键词向量对应                                                   a                               2                                            a^2                     a2,则有                                                   α                                           1                                  ,                                  2                                                 =                                       q                               1                                      ⋅                                       k                               2                                            \alpha_{1,2}=q^1\cdot k^2                     α1,2​=q1⋅k2

    • 以此类推盘算全部向量的attention score

  • 之后,将全部的                                        a                            t                            t                            e                            n                            t                            i                            o                            n                                                         s                            c                            o                            r                            e                                  attention\ score                     attention score输入soft-max中,将其映射为一个分布,                                                   α                                           1                                  ,                                  2                                                       \alpha_{1,2}                     α1,2​对应的输出为                                                   α                                           1                                  ,                                  2                                          ′                                            \alpha'_{1,2}                     α1,2′​
  • 末了,将                                                   a                               i                                            a^i                     ai乘上矩阵                                                   W                               v                                            W^v                     Wv,得到                                                   v                               i                                            v^i                     vi,用                                                   α                                           1                                  ,                                  i                                          ′                                            \alpha'_{1,i}                     α1,i′​乘上                                                   v                               i                                            v^i                     vi,将全部的按照该流程的得到的结果累加                                                   b                               1                                      =                                       ∑                               i                                                             α                                               1                                     ,                                     i                                              ′                                                      v                                  i                                                       b^1=\sum_i{\alpha'_{1,i}v^i}                     b1=∑i​α1,i′​vi,                                                   a                               i                                            a^i                     ai为输入序列中的全部向量。若其他向量                                                   b                               i                                            b^i                     bi与                                                   b                               1                                            b^1                     b1越相近,则                                                   a                               i                                            a^i                     ai与                                                   a                               1                                            a^1                     a1越相近
大致盘算过程如下

多头注意力机制
以下以两个head为例,盘算过程如下

3.3 AT-LSTM

差别时序输入序列首先经过LSTM网络,然后作为注意力层的输入,经过全连接层和softmax激活后输入多头注意力层,经过flattern扁平化操作后输入全连接层,末了输出结果

该模子的紧张思想是通过对神经网络隐含层元素进行自适应加权,减少无关因素对结果的影响,突出干系因素的影响,从而进步预测精度。模子框架如图7所示,紧张构成部分是LSTM层和注意力层。
3.4 数据预处理


  • 将数据清洗,将非常值设定为控制,然后通过缺失值补充填充空值。
  • 首先,应用Pearson干系性检验选取特征,差别水质参数之间的干系性分析、执行,并且与要预测的特征干系的关键特征被用作模子的输入。然后,利用窗口巨细为100的滑动窗口技能来捕获水质变量的趋势。末了通过最小-最大归一化用于减轻差别特征标准对模子练习的影响。
  • 末了将经过数据预处理的数据用于练习网络
综上,算法的整体结构如下

4. 文献解读

4.1 Introduction

研究以水质评价的关键参数DO作为模子构建和预测评价的目的。以往的方法仍然没有充实学习时间序列中隐蔽的干系特征,这显着影响了预测精度。LSTM缺乏对子窗口特征进行差别程度关注的能力,这大概会导致一些干系信息被忽略,无法重视时间序列的紧张特征。该文应用注意力机制来有用捕获更远的关键信息,并通过在每个时间步对隐蔽层元素进行加权来增强紧张特征对预测模子的影响。在此底子上,引入了注意力机制,并在LSTM模子的底子上开发了AT-LSTM模子,重点是更好地捕捉水质变量。利用水质监测原始数据预测了澳大利亚伯内特河河段的溶解氧浓度。末了将预测结果与LSTM模子进行比力。我们的目的是实现多元时间数据的长期依靠性和隐蔽干系性特征的自适应学习,以使河道水质预测更加准确。伯内特河被认为是一个案例研究,以阐明所提出的 AT-LSTM 模子的适用性。
4.2 创新点

这项工作设计了一种肴杂模子,利用基于 LSTM和注意力机制的神经网络(称为 AT-LSTM)来预测未来的水质。紧张贡献总结如下:

  • 提出了一种改进的网络结构,可以更好地预测多步水质时间序列数据。因此,所提出的AT-LSTM可以更好地处理时间序列数据中的长序列。
  • 创新性地将注意力机制同LSTM联合和集成,显着进步了多步预测精度。
4.3 实验过程

4.3.1 练习参数

模子结构和紧张参数如表3所示。该文通过试错法将时间窗口设置为100,利用贝叶斯优化[56]进行模子超参数优化,辨认出相对较好的超参数和激活函数。

4.3.2 数据集

研究利用的数据为伯内特河自动监测点的水质数据,其位置及流域边界如图1所示。

为保证模子的可靠性和适用性,利用了伯内特河2015年1月至2020年1月采集的水质监测数据。每半小时收集一次数据,包罗五个特征:水温 (Temp)、pH、溶解氧 (DO)、电导率 (EC)、叶绿素-a (Chl-a) 和浊度 (NTU)。该文采用每小时39752个特征的水质数据和溶解氧作为输出变量。表1显示了数据的描述性统计。

研究期间DO的变化如图2所示。

水质数据按照8:1:1的比例分为三个数据集:练习集、验证集和测试集。在本研究中,练习集包罗31,802个每小时条目(从2015年1月1日到2019年2月4日),验证集包罗3975个每小时条目(从2019年2月4日到2019年7月20日),测试集包罗3975个每小时条目(从2019年2月4日到2019年7月20日)。 2019年7月20日至2020年1月1日)。
4.3.3 实验设置

评估指标:采用匀称绝对偏差(MAE)、均方根偏差(RMSE)和决定系数(                                             R                            2                                       R^2                  R2)来定量评价模子预测效果,盘算方法如下

利用MSE作为模子的损失函数,并利用以下标准方程进行盘算:

两个模子均利用 Adam 优化器在练习集上进行练习,批量巨细为 64。为了加快偏差的收敛,利用了反向传播学习方法。验证集用作early stop方法,以确保模子不会过分练习。
利用LSTM作为基准模子
4.3.4 实验结果

从下图中,可以看到AT-LSTM模子在Burnett River测试集的水质预测方面优于LSTM模子:

在图(b)中,蓝色曲线体现现实值,橙色曲线体现建模的预测值。固然LSTM可以预测水质变化,但AT-LSTM的预测与现实值的差异较小,表明AT-LSTM的泛化能力比LSTM更强。

下表总结了LSTM和AT-LSTM模子在监测段中预测溶解氧DO使命的性能

以下是两种模子多步预测结果的比力:

5. 基于pytorch的transformer

掩码部分代码如下图
  1. class Embeddings(nn.Module):
  2.     def __init__(self, d_model, vocab):
  3.         """
  4.         类的初始化函数
  5.         d_model:指词嵌入的维度
  6.         vocab:指词表的大小
  7.         """
  8.         super(Embeddings, self).__init__()
  9.         #之后就是调用nn中的预定义层Embedding,获得一个词嵌入对象self.lut
  10.         self.lut = nn.Embedding(vocab, d_model)
  11.         #最后就是将d_model传入类中
  12.         self.d_model =d_model
  13.     def forward(self, x):
  14.         """
  15.         Embedding层的前向传播逻辑
  16.         参数x:这里代表输入给模型的单词文本通过词表映射后的one-hot向量
  17.         将x传给self.lut并与根号下self.d_model相乘作为结果返回
  18.         """
  19.         embedds = self.lut(x)
  20.         return embedds * math.sqrt(self.d_model)
复制代码
位置编码部分代码如下
  1. class PositionalEncoding(nn.Module):
  2.     def __init__(self, d_model, dropout, max_len=5000):
  3.         """
  4.         位置编码器类的初始化函数
  5.         
  6.         共有三个参数,分别是
  7.         d_model:词嵌入维度
  8.         dropout: dropout触发比率
  9.         max_len:每个句子的最大长度
  10.         """
  11.         super(PositionalEncoding, self).__init__()
  12.         self.dropout = nn.Dropout(p=dropout)
  13.         
  14.         # Compute the positional encodings
  15.         # 注意下面代码的计算方式与公式中给出的是不同的,但是是等价的,你可以尝试简单推导证明一下。
  16.         # 这样计算是为了避免中间的数值计算结果超出float的范围,
  17.         pe = torch.zeros(max_len, d_model)
  18.         position = torch.arange(0, max_len).unsqueeze(1)
  19.         div_term = torch.exp(torch.arange(0, d_model, 2) *
  20.                              -(math.log(10000.0) / d_model))
  21.         pe[:, 0::2] = torch.sin(position * div_term)
  22.         pe[:, 1::2] = torch.cos(position * div_term)
  23.         pe = pe.unsqueeze(0)
  24.         self.register_buffer('pe', pe)
  25.         
  26.     def forward(self, x):
  27.         x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
  28.         return self.dropout(x)
复制代码
编码器代码如下
  1. # 定义一个clones函数,来更方便的将某个结构复制若干份
  2. def clones(module, N):
  3.     "Produce N identical layers."
  4.     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
  5. class Encoder(nn.Module):
  6.     """
  7.     Encoder
  8.     The encoder is composed of a stack of N=6 identical layers.
  9.     """
  10.     def __init__(self, layer, N):
  11.         super(Encoder, self).__init__()
  12.         # 调用时会将编码器层传进来,我们简单克隆N分,叠加在一起,组成完整的Encoder
  13.         self.layers = clones(layer, N)
  14.         self.norm = LayerNorm(layer.size)
  15.         
  16.     def forward(self, x, mask):
  17.         "Pass the input (and mask) through each layer in turn."
  18.         for layer in self.layers:
  19.             x = layer(x, mask)
  20.         return self.norm(x)
复制代码
编码器层代码如下
  1. class EncoderLayer(nn.Module):
  2.     "EncoderLayer is made up of two sublayer: self-attn and feed forward"                                                                                                         
  3.     def __init__(self, size, self_attn, feed_forward, dropout):
  4.         super(EncoderLayer, self).__init__()
  5.         self.self_attn = self_attn
  6.         self.feed_forward = feed_forward
  7.         self.sublayer = clones(SublayerConnection(size, dropout), 2)
  8.         self.size = size   # embedding's dimention of model, 默认512
  9.     def forward(self, x, mask):
  10.         # attention sub layer
  11.         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
  12.         # feed forward sub layer
  13.         z = self.sublayer[1](x, self.feed_forward)
  14.         return z
复制代码
注意力机制层代码如下
  1. def attention(query, key, value, mask=None, dropout=None):
  2.     "Compute 'Scaled Dot Product Attention'"
  3.     #首先取query的最后一维的大小,对应词嵌入维度
  4.     d_k = query.size(-1)
  5.     #按照注意力公式,将query与key的转置相乘,这里面key是将最后两个维度进行转置,再除以缩放系数得到注意力得分张量scores
  6.     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
  7.    
  8.     #接着判断是否使用掩码张量
  9.     if mask is not None:
  10.         #使用tensor的masked_fill方法,将掩码张量和scores张量每个位置一一比较,如果掩码张量则对应的scores张量用-1e9这个置来替换
  11.         scores = scores.masked_fill(mask == 0, -1e9)
  12.         
  13.     #对scores的最后一维进行softmax操作,使用F.softmax方法,这样获得最终的注意力张量
  14.     p_attn = F.softmax(scores, dim = -1)
  15.    
  16.     #之后判断是否使用dropout进行随机置0
  17.     if dropout is not None:
  18.         p_attn = dropout(p_attn)
  19.    
  20.     #最后,根据公式将p_attn与value张量相乘获得最终的query注意力表示,同时返回注意力张量
  21.     return torch.matmul(p_attn, value), p_attn
复制代码
多头注意力机制代码如下
  1. class MultiHeadedAttention(nn.Module):
  2.     def __init__(self, h, d_model, dropout=0.1):
  3.         #在类的初始化时,会传入三个参数,h代表头数,d_model代表词嵌入的维度,dropout代表进行dropout操作时置0比率,默认是0.1
  4.         super(MultiHeadedAttention, self).__init__()
  5.         #在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除,这是因为我们之后要给每个头分配等量的词特征,也就是embedding_dim/head个
  6.         assert d_model % h == 0
  7.         #得到每个头获得的分割词向量维度d_k
  8.         self.d_k = d_model // h
  9.         #传入头数h
  10.         self.h = h
  11.         
  12.         #创建linear层,通过nn的Linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim,然后使用,为什么是四个呢,这是因为在多头注意力中,Q,K,V各需要一个,最后拼接的矩阵还需要一个,因此一共是四个
  13.         self.linears = clones(nn.Linear(d_model, d_model), 4)
  14.         #self.attn为None,它代表最后得到的注意力张量,现在还没有结果所以为None
  15.         self.attn = None
  16.         self.dropout = nn.Dropout(p=dropout)
  17.         
  18.     def forward(self, query, key, value, mask=None):
  19.         #前向逻辑函数,它输入参数有四个,前三个就是注意力机制需要的Q,K,V,最后一个是注意力机制中可能需要的mask掩码张量,默认是None
  20.         if mask is not None:
  21.             # Same mask applied to all h heads.
  22.             #使用unsqueeze扩展维度,代表多头中的第n头
  23.             mask = mask.unsqueeze(1)
  24.         #接着,我们获得一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本
  25.         nbatches = query.size(0)
  26.         
  27.         # 1) Do all the linear projections in batch from d_model => h x d_k
  28.         # 首先利用zip将输入QKV与三个线性层组到一起,然后利用for循环,将输入QKV分别传到线性层中,做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结构进行维度重塑,多加了一个维度h代表头,这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度,计算机会根据这种变换自动计算这里的值,然后对第二维和第三维进行转置操作,为了让代表句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系,从attention函数中可以看到,利用的是原始输入的倒数第一和第二维,这样我们就得到了每个头的输入
  29.         query, key, value = \
  30.             [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
  31.              for l, x in zip(self.linears, (query, key, value))]
  32.         # 2) Apply attention on all the projected vectors in batch.
  33.         # 得到每个头的输入后,接下来就是将他们传入到attention中,这里直接调用我们之前实现的attention函数,同时也将mask和dropout传入其中
  34.         x, self.attn = attention(query, key, value, mask=mask,
  35.                                  dropout=self.dropout)
  36.         # 3) "Concat" using a view and apply a final linear.
  37.         # 通过多头注意力计算后,我们就得到了每个头计算结果组成的4维张量,我们需要将其转换为输入的形状以方便后续的计算,因此这里开始进行第一步处理环节的逆操作,先对第二和第三维进行转置,然后使用contiguous方法。这个方法的作用就是能够让转置后的张量应用view方法,否则将无法直接使用,所以,下一步就是使用view重塑形状,变成和输入形状相同。  
  38.         x = x.transpose(1, 2).contiguous() \
  39.              .view(nbatches, -1, self.h * self.d_k)
  40.         #最后使用线性层列表中的最后一个线性变换得到最终的多头注意力结构的输出
  41.         return self.linears[-1](x)
复制代码
解码器整体结构代码结构如下
  1. #使用类Decoder来实现解码器
  2. class Decoder(nn.Module):
  3.     "Generic N layer decoder with masking."
  4.     def __init__(self, layer, N):
  5.         #初始化函数的参数有两个,第一个就是解码器层layer,第二个是解码器层的个数N
  6.         super(Decoder, self).__init__()
  7.         #首先使用clones方法克隆了N个layer,然后实例化一个规范化层,因为数据走过了所有的解码器层后最后要做规范化处理。
  8.         self.layers = clones(layer, N)
  9.         self.norm = LayerNorm(layer.size)
  10.         
  11.     def forward(self, x, memory, src_mask, tgt_mask):
  12.         #forward函数中的参数有4个,x代表目标数据的嵌入表示,memory是编码器层的输出,source_mask,target_mask代表源数据和目标数据的掩码张量,然后就是对每个层进行循环,当然这个循环就是变量x通过每一个层的处理,得出最后的结果,再进行一次规范化返回即可。
  13.         for layer in self.layers:
  14.             x = layer(x, memory, src_mask, tgt_mask)
  15.         return self.norm(x)
复制代码
解码器层代码如下
  1. #使用DecoderLayer的类实现解码器层
  2. class DecoderLayer(nn.Module):
  3.     "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
  4.     def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
  5.         #初始化函数的参数有5个,分别是size,代表词嵌入的维度大小,同时也代表解码器的尺寸,第二个是self_attn,多头自注意力对象,也就是说这个注意力机制需要Q=K=V,第三个是src_attn,多头注意力对象,这里Q!=K=V,第四个是前馈全连接层对象,最后就是dropout置0比率
  6.         super(DecoderLayer, self).__init__()
  7.         self.size = size
  8.         self.self_attn = self_attn
  9.         self.src_attn = src_attn
  10.         self.feed_forward = feed_forward
  11.         #按照结构图使用clones函数克隆三个子层连接对象
  12.         self.sublayer = clones(SublayerConnection(size, dropout), 3)
  13.     def forward(self, x, memory, src_mask, tgt_mask):
  14.         #forward函数中的参数有4个,分别是来自上一层的输入x,来自编码器层的语义存储变量memory,以及源数据掩码张量和目标数据掩码张量,将memory表示成m之后方便使用。
  15.         "Follow Figure 1 (right) for connections."
  16.         m = memory
  17.         #将x传入第一个子层结构,第一个子层结构的输入分别是x和self-attn函数,因为是自注意力机制,所以Q,K,V都是x,最后一个参数时目标数据掩码张量,这时要对目标数据进行遮掩,因为此时模型可能还没有生成任何目标数据。
  18.         #比如在解码器准备生成第一个字符或词汇时,我们其实已经传入了第一个字符以便计算损失,但是我们不希望在生成第一个字符时模型能利用这个信息,因此我们会将其遮掩,同样生成第二个字符或词汇时,模型只能使用第一个字符或词汇信息,第二个字符以及之后的信息都不允许被模型使用。
  19.         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
  20.         #接着进入第二个子层,这个子层中常规的注意力机制,q是输入x;k,v是编码层输出memory,同样也传入source_mask,但是进行源数据遮掩的原因并非是抑制信息泄露,而是遮蔽掉对结果没有意义的padding。
  21.         x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
  22.         
  23.         #最后一个子层就是前馈全连接子层,经过它的处理后就可以返回结果,这就是我们的解码器结构
  24.         return self.sublayer[2](x, self.feed_forward)
复制代码
整体网络框架如下
  1. # Model Architecture
  2. #使用EncoderDecoder类来实现编码器-解码器结构
  3. class EncoderDecoder(nn.Module):
  4.     """
  5.     A standard Encoder-Decoder architecture.
  6.     Base for this and many other models.
  7.     """
  8.     def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
  9.         #初始化函数中有5个参数,分别是编码器对象,解码器对象,源数据嵌入函数,目标数据嵌入函数,以及输出部分的类别生成器对象.
  10.         super(EncoderDecoder, self).__init__()
  11.         self.encoder = encoder
  12.         self.decoder = decoder
  13.         self.src_embed = src_embed    # input embedding module(input embedding + positional encode)
  14.         self.tgt_embed = tgt_embed    # ouput embedding module
  15.         self.generator = generator    # output generation module
  16.         
  17.     def forward(self, src, tgt, src_mask, tgt_mask):
  18.         "Take in and process masked src and target sequences."
  19.         #在forward函数中,有四个参数,source代表源数据,target代表目标数据,source_mask和target_mask代表对应的掩码张量,在函数中,将source source_mask传入编码函数,得到结果后与source_mask target 和target_mask一同传给解码函数
  20.         memory = self.encode(src, src_mask)
  21.         res = self.decode(memory, src_mask, tgt, tgt_mask)
  22.         return res
  23.    
  24.     def encode(self, src, src_mask):
  25.         #编码函数,以source和source_mask为参数,使用src_embed对source做处理,然后和source_mask一起传给self.encoder
  26.         src_embedds = self.src_embed(src)
  27.         return self.encoder(src_embedds, src_mask)
  28.    
  29.     def decode(self, memory, src_mask, tgt, tgt_mask):
  30.         #解码函数,以memory即编码器的输出,source_mask target target_mask为参数,使用tgt_embed对target做处理,然后和source_mask,target_mask,memory一起传给self.decoder
  31.         target_embedds = self.tgt_embed(tgt)
  32.         return self.decoder(target_embedds, memory, src_mask, tgt_mask)
  33. # Full Model
  34. def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
  35.     """
  36.     构建模型
  37.     params:
  38.         src_vocab:
  39.         tgt_vocab:
  40.         N: 编码器和解码器堆叠基础模块的个数
  41.         d_model: 模型中embedding的size,默认512
  42.         d_ff: FeedForward Layer层中embedding的size,默认2048
  43.         h: MultiHeadAttention中多头的个数,必须被d_model整除
  44.         dropout:
  45.     """
  46.     c = copy.deepcopy
  47.     attn = MultiHeadedAttention(h, d_model)
  48.     ff = PositionwiseFeedForward(d_model, d_ff, dropout)
  49.     position = PositionalEncoding(d_model, dropout)
  50.     model = EncoderDecoder(
  51.         Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
  52.         Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
  53.         nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
  54.         nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
  55.         Generator(d_model, tgt_vocab))
  56.    
  57.     # This was important from their code.
  58.     # Initialize parameters with Glorot / fan_avg.
  59.     for p in model.parameters():
  60.         if p.dim() > 1:
  61.             nn.init.xavier_uniform_(p)
  62.     return model
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

九天猎人

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表