羊蹓狼 发表于 2025-3-26 07:14:16

深度学习|表示学习|多头留意力在计算时常见的张量维度变换总结|28

如是我闻: 以下是多头留意力(Multi-Headed Attention)在计算时常见的张量维度变换总结,帮助明白从输入到输出是怎样一步步处理惩罚的。为了方便,令:


[*]                                        B                                  B                     B 表示 batch size(批量巨细)
[*]                                        S                                  S                     S 表示 sequence length(序列长度)
[*]                                        m                                  m                     m 表示 num_heads(留意力头数)
[*]                                        h                                  h                     h 表示 head_size(每个头的维度)
[*]                                                   d                                           m                                  o                                  d                                  e                                  l                                                 =                            m                            ×                            h                                  d_{\mathrm{model}} = m \times h                     dmodel​=m×h 表示模型隐层维度

[*] 输入(queries、keys、values)
                                                   外形                                  =                                  (                                  B                                  ,                               S                                  ,                                              d                                                   m                                        o                                        d                                        e                                        l                                                         )                                  .                                          \text{外形} = (B,\, S,\, d_{\mathrm{model}}).                           外形=(B,S,dmodel​).
在「自留意力」(self-attention)场景下,三者通常是同一个张量;在「交叉留意力」(cross-attention)场景下,                                             queries                                    \texttt{queries}                        queries 和                                              keys,                                                               values                                    \texttt{keys, values}                        keys, values 可能来自差别子网络。
[*] 线性映射(                                                               W                                     Q                                              ,                                             W                                     K                                              ,                                             W                                     V                                                      W_Q, W_K, W_V                           WQ​,WK​,WV​)

[*]对                                                   queries                                          \texttt{queries}                           queries 做线性变换得到 (Q):外形仍为                                                   (                                  B                                  ,                                  S                                  ,                                             d                                                   m                                        o                                        d                                        e                                        l                                                         )                                          (B, S, d_{\mathrm{model}})                           (B,S,dmodel​)
[*]对                                                   keys                                          \texttt{keys}                           keys 做线性变换得到                                                   K                                          K                           K:外形同上
[*]对                                                   values                                          \texttt{values}                           values 做线性变换得到                                                   V                                          V                           V:外形同上

[*] 拆分 heads(split heads)

[*]将                                                   (                                  B                                  ,                                  S                                  ,                                             d                                                   m                                        o                                        d                                        e                                        l                                                         )                                          (B, S, d_{\mathrm{model}})                           (B,S,dmodel​) reshape + transpose 成                                                   (                                  B                                  ,                                  m                                  ,                                  S                                  ,                                  h                                  )                                          (B, m, S, h)                           (B,m,S,h)。
[*]这样每个 batch、每个序列位置上就可以拆出                                                   m                                          m                           m 个“头”,每个头维度为                                                   h                                          h                           h。
[*]拆分后:
                                                      Q                                     ,                                     K                                     ,                                     V                                                                                         →                                                       split                                                                                                      (                                     B                                     ,                                      m                                     ,                                      S                                     ,                                      h                                     )                                     .                                              Q, K, V ~\xrightarrow{\text{split}}~ (B,\, m,\, S,\, h).                              Q,K,V split                ​ (B,m,S,h).

[*] 计算留意力分数(scores)

[*]使用 scaled dot-product:
                                                      scores                                     =                                                                  Q                                           ×                                                         K                                              T                                                                                    h                                                                            外形                                     =                                     (                                     B                                     ,                                      m                                     ,                                      S                                     ,                                      S                                     )                                     .                                              \text{scores} = \frac{Q \times K^T}{\sqrt{h}} \quad\text{外形} = (B,\, m,\, S,\, S).                              scores=h                     ​Q×KT​外形=(B,m,S,S).
[*]此时会应用「下三角 mask」(causal mask)以保证自回归:只关注「过去和当前」位置,屏蔽「未来」位置。
[*]对                                                   scores                                          \text{scores}                           scores 做                                                   s                                  o                                  f                                  t                                  m                                  a                                  x                                          \mathrm{softmax}                           softmax 得到留意力权重                                                   a                                  t                                  t                                  n                                  _                                  w                                  e                                  i                                  g                                  h                                  t                                  s                                          \mathrm{attn\_weights}                           attn_weights。

[*] 加权求和(attended values)
                                                   attended_values                                  =                                             a                                     t                                     t                                     n                                     _                                     w                                     e                                     i                                     g                                     h                                     t                                     s                                              ×                                  V                                  ,                                             外形                                  =                                  (                                  B                                  ,                               m                                  ,                               S                                  ,                               h                                  )                                  .                                          \text{attended\_values} = \mathrm{attn\_weights} \times V, \quad \text{外形} = (B,\, m,\, S,\, h).                           attended_values=attn_weights×V,外形=(B,m,S,h).
这样就得到每个 head 对原值向量的加权结果。
[*] 合并 heads(merge heads)

[*]将                                                   (                                  B                                  ,                                  m                                  ,                                  S                                  ,                                  h                                  )                                          (B, m, S, h)                           (B,m,S,h) 还原到                                                   (                                  B                                  ,                                  S                                  ,                                  m                                  ×                                  h                                  )                                          (B, S, m \times h)                           (B,S,m×h),即                                                   (                                  B                                  ,                                  S                                  ,                                             d                                                   m                                        o                                        d                                        e                                        l                                                         )                                          (B, S, d_{\mathrm{model}})                           (B,S,dmodel​)。
[*]合并之后,相当于将全部 head 的信息拼接到最后一个维度上。

[*] 可选的最终线性映射                                                                W                                     O                                                      \mathbf{W}_O                           WO​

[*]多数实现会继承用一个线性层                                                                W                                     O                                                      \mathbf{W}_O                           WO​(同样是                                                   (                                             d                                                   m                                        o                                        d                                        e                                        l                                                         ,                                             d                                                   m                                        o                                        d                                        e                                        l                                                         )                                          (d_{\mathrm{model}}, d_{\mathrm{model}})                           (dmodel​,dmodel​))把拼接后的多头输出再次投影,外形保持                                                   (                                  B                                  ,                                  S                                  ,                                             d                                                   m                                        o                                        d                                        e                                        l                                                         )                                          (B, S, d_{\mathrm{model}})                           (B,S,dmodel​)。

通过以上步骤,多头留意力便可将序列的上下文信息捕捉到差别的 head(差别的子空间),再合并形成新的隐层表示。
cao!

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: 深度学习|表示学习|多头留意力在计算时常见的张量维度变换总结|28