Transformer中Decoder的计算过程及各部分维度变化

打印 上一主题 下一主题

主题 1807|帖子 1807|积分 5421

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
在Transformer模型中,解码器的计算过程涉及多个步调,告急包括自注意力机制、编码器-解码器注意力和前馈神经网络。以下是解码器的详细计算过程及数据维度变化:
1. 输入嵌入和位置编码

解码器的输入起首经过嵌入层和位置编码:
                                              Input                            d                                  =                         Embedding                         (                         x                         )                         +                         PositionEncoding                         (                         x                         )                              \text{Input}_d = \text{Embedding}(x) + \text{PositionEncoding}(x)                  Inputd​=Embedding(x)+PositionEncoding(x)


  • 维度变化:                                        x                                  x                     x: 输入序列的标记,维度为                                         (                            n                            ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (n, d_{model})                     (n,dmodel​)                                        Embedding                            (                            x                            )                                  \text{Embedding}(x)                     Embedding(x): 输出维度为                                         (                            n                            ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (n, d_{model})                     (n,dmodel​)                                        PositionEncoding                            (                            x                            )                                  \text{PositionEncoding}(x)                     PositionEncoding(x): 输出维度为                                         (                            n                            ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (n, d_{model})                     (n,dmodel​)
2. 自注意力机制

自注意力机制计算如下:
                                    Q                         =                                   Input                            d                                            W                            Q                                  ,                                 K                         =                                   Input                            d                                            W                            K                                  ,                                 V                         =                                   Input                            d                                            W                            V                                       Q = \text{Input}_d W_Q, \quad K = \text{Input}_d W_K, \quad V = \text{Input}_d W_V                  Q=Inputd​WQ​,K=Inputd​WK​,V=Inputd​WV​


  • 这里                                                    W                               Q                                      ,                                       W                               K                                      ,                                       W                               V                                            W_Q, W_K, W_V                     WQ​,WK​,WV​ 是参数矩阵,维度为                                         (                                       d                                           m                                  o                                  d                                  e                                  l                                                 ,                                       d                               k                                      )                                  (d_{model}, d_k)                     (dmodel​,dk​),假设                                                    d                               k                                      =                                       d                                           m                                  o                                  d                                  e                                  l                                                       d_k = d_{model}                     dk​=dmodel​。
  • 维度变化:                                        Q                            ,                            K                            ,                            V                                  Q, K, V                     Q,K,V: 输出维度为                                         (                            n                            ,                                       d                               k                                      )                                  (n, d_k)                     (n,dk​)
    自注意力的计算为:
                                             Attention                            (                            Q                            ,                            K                            ,                            V                            )                            =                            softmax                                       (                                                        Q                                                   K                                        T                                                                                      d                                        k                                                                   +                               M                               )                                      V                                  \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V                     Attention(Q,K,V)=softmax(dk​                      ​QKT​+M)V
  • 维度变化:                                        Q                                       K                               T                                            QK^T                     QKT: 维度为                                         (                            n                            ,                            n                            )                                  (n, n)                     (n,n)                                        softmax                                  \text{softmax}                     softmax: 结果维度为                                         (                            n                            ,                            n                            )                                  (n, n)                     (n,n)终极输出的维度为                                         (                            n                            ,                                       d                               v                                      )                                  (n, d_v)                     (n,dv​)(假设                                                    d                               v                                      =                                       d                                           m                                  o                                  d                                  e                                  l                                                       d_v = d_{model}                     dv​=dmodel​)。
3. 残差连接与层归一化

自注意力的输出与输入相加,然后进行层归一化:
                                              Output                            d                                       (                               l                               )                                            =                         LayerNorm                         (                         Attention                         +                                   Input                            d                                  )                              \text{Output}_d^{(l)} = \text{LayerNorm}(\text{Attention} + \text{Input}_d)                  Outputd(l)​=LayerNorm(Attention+Inputd​)


  • 维度变化:维度保持为                                         (                            n                            ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (n, d_{model})                     (n,dmodel​)。
4. 编码器-解码器注意力

接下来,解码器会对编码器的输出进行注意力计算:
                                              Q                            ′                                  =                                   Output                            d                                       (                               l                               )                                                      W                            Q                            ′                                  ,                                           K                            ′                                  =                         EncoderOutput                                   W                            K                            ′                                  ,                                           V                            ′                                  =                         EncoderOutput                                   W                            V                            ′                                       Q' = \text{Output}_d^{(l)} W_Q', \quad K' = \text{EncoderOutput} W_K', \quad V' = \text{EncoderOutput} W_V'                  Q′=Outputd(l)​WQ′​,K′=EncoderOutputWK′​,V′=EncoderOutputWV′​


  • 这里                                                    W                               Q                               ′                                      ,                                       W                               K                               ′                                      ,                                       W                               V                               ′                                            W_Q', W_K', W_V'                     WQ′​,WK′​,WV′​ 的维度也是                                         (                                       d                                           m                                  o                                  d                                  e                                  l                                                 ,                                       d                               k                                      )                                  (d_{model}, d_k)                     (dmodel​,dk​)。
  • 编码器输出的维度为                                         (                                       T                               e                                      ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (T_e, d_{model})                     (Te​,dmodel​)。
    注意力计算为:
                                             Attention                            (                                       Q                               ′                                      ,                                       K                               ′                                      ,                                       V                               ′                                      )                            =                            softmax                                       (                                                                      Q                                        ′                                                                K                                                       ′                                           T                                                                                                    d                                        k                                                                   )                                                 V                               ′                                            \text{Attention}(Q', K', V') = \text{softmax}\left(\frac{Q'K'^T}{\sqrt{d_k}}\right)V'                     Attention(Q′,K′,V′)=softmax(dk​                      ​Q′K′T​)V′
  • 维度变化:                                                   Q                               ′                                                 K                                           ′                                  T                                                       Q'K'^T                     Q′K′T: 维度为                                         (                            n                            ,                                       T                               e                                      )                                  (n, T_e)                     (n,Te​)终极输出的维度为                                         (                            n                            ,                                       d                               v                                      )                                  (n, d_v)                     (n,dv​)。
    然后与自注意力的输出进行残差连接和层归一化:
                                                        Output                               d                                           (                                  l                                  )                                                 =                            LayerNorm                            (                            EncoderDecoderAttention                            +                                       Output                               d                                           (                                  l                                  )                                                 )                                  \text{Output}_d^{(l)} = \text{LayerNorm}(\text{EncoderDecoderAttention} + \text{Output}_d^{(l)})                     Outputd(l)​=LayerNorm(EncoderDecoderAttention+Outputd(l)​)
5. 前馈神经网络

接下来是前馈神经网络的处置惩罚:
                                    FFN                         (                         x                         )                         =                         ReLU                         (                         x                                   W                            1                                  +                                   b                            1                                  )                                   W                            2                                  +                                   b                            2                                       \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2                  FFN(x)=ReLU(xW1​+b1​)W2​+b2​


  •                                                    W                               1                                            W_1                     W1​ 维度为                                         (                                       d                                           m                                  o                                  d                                  e                                  l                                                 ,                                       d                                           f                                  f                                                 )                                  (d_{model}, d_{ff})                     (dmodel​,dff​),                                                   W                               2                                            W_2                     W2​ 维度为                                         (                                       d                                           f                                  f                                                 ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (d_{ff}, d_{model})                     (dff​,dmodel​),其中                                                    d                                           f                                  f                                                       d_{ff}                     dff​ 是前馈层的隐藏单元数。
  • 维度变化:输入维度为                                         (                            n                            ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (n, d_{model})                     (n,dmodel​)输出维度为                                         (                            n                            ,                                       d                                           m                                  o                                  d                                  e                                  l                                                 )                                  (n, d_{model})                     (n,dmodel​)。
6. 终极输出

在最后一步,再次进行残差连接和层归一化:
                                              Output                            d                                       (                               l                               )                                            =                         LayerNorm                         (                         FFN                         +                                   Output                            d                                       (                               l                               )                                            )                              \text{Output}_d^{(l)} = \text{LayerNorm}(\text{FFN} + \text{Output}_d^{(l)})                  Outputd(l)​=LayerNorm(FFN+Outputd(l)​)
接下来,解码器的终极输出通过线性层和Softmax层生成词汇表的概率分布:
                                    Logits                         =                                   Output                            d                                       (                               l                               )                                                      W                                       o                               u                               t                                            +                                   b                                       o                               u                               t                                                 \text{Logits} = \text{Output}_d^{(l)} W_{out} + b_{out}                  Logits=Outputd(l)​Wout​+bout​
                                    Probabilities                         =                         softmax                         (                         Logits                         )                              \text{Probabilities} = \text{softmax}(\text{Logits})                  Probabilities=softmax(Logits)


  • 维度变化:                                                   W                                           o                                  u                                  t                                                       W_{out}                     Wout​ 维度为                                         (                                       d                                           m                                  o                                  d                                  e                                  l                                                 ,                            V                            )                                  (d_{model}, V)                     (dmodel​,V),其中                                         V                                  V                     V 是词汇表的巨细。                                        Logits                                  \text{Logits}                     Logits 的维度为                                         (                            n                            ,                            V                            )                                  (n, V)                     (n,V),                                        Probabilities                                  \text{Probabilities}                     Probabilities 的维度同样为                                         (                            n                            ,                            V                            )                                  (n, V)                     (n,V),表现每个时间步上各个词汇的概率。
    通过这些步调,解码器可以大概生成序列的下一个标记。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

商道如狼道

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表