深度学习|表示学习|多头留意力在计算时常见的张量维度变换总结|28
如是我闻: 以下是多头留意力(Multi-Headed Attention)在计算时常见的张量维度变换总结,帮助明白从输入到输出是怎样一步步处理惩罚的。为了方便,令:[*] B B B 表示 batch size(批量巨细)
[*] S S S 表示 sequence length(序列长度)
[*] m m m 表示 num_heads(留意力头数)
[*] h h h 表示 head_size(每个头的维度)
[*] d m o d e l = m × h d_{\mathrm{model}} = m \times h dmodel=m×h 表示模型隐层维度
[*] 输入(queries、keys、values)
外形 = ( B , S , d m o d e l ) . \text{外形} = (B,\, S,\, d_{\mathrm{model}}). 外形=(B,S,dmodel).
在「自留意力」(self-attention)场景下,三者通常是同一个张量;在「交叉留意力」(cross-attention)场景下, queries \texttt{queries} queries 和 keys, values \texttt{keys, values} keys, values 可能来自差别子网络。
[*] 线性映射( W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV)
[*]对 queries \texttt{queries} queries 做线性变换得到 (Q):外形仍为 ( B , S , d m o d e l ) (B, S, d_{\mathrm{model}}) (B,S,dmodel)
[*]对 keys \texttt{keys} keys 做线性变换得到 K K K:外形同上
[*]对 values \texttt{values} values 做线性变换得到 V V V:外形同上
[*] 拆分 heads(split heads)
[*]将 ( B , S , d m o d e l ) (B, S, d_{\mathrm{model}}) (B,S,dmodel) reshape + transpose 成 ( B , m , S , h ) (B, m, S, h) (B,m,S,h)。
[*]这样每个 batch、每个序列位置上就可以拆出 m m m 个“头”,每个头维度为 h h h。
[*]拆分后:
Q , K , V → split ( B , m , S , h ) . Q, K, V ~\xrightarrow{\text{split}}~ (B,\, m,\, S,\, h). Q,K,V split (B,m,S,h).
[*] 计算留意力分数(scores)
[*]使用 scaled dot-product:
scores = Q × K T h 外形 = ( B , m , S , S ) . \text{scores} = \frac{Q \times K^T}{\sqrt{h}} \quad\text{外形} = (B,\, m,\, S,\, S). scores=h Q×KT外形=(B,m,S,S).
[*]此时会应用「下三角 mask」(causal mask)以保证自回归:只关注「过去和当前」位置,屏蔽「未来」位置。
[*]对 scores \text{scores} scores 做 s o f t m a x \mathrm{softmax} softmax 得到留意力权重 a t t n _ w e i g h t s \mathrm{attn\_weights} attn_weights。
[*] 加权求和(attended values)
attended_values = a t t n _ w e i g h t s × V , 外形 = ( B , m , S , h ) . \text{attended\_values} = \mathrm{attn\_weights} \times V, \quad \text{外形} = (B,\, m,\, S,\, h). attended_values=attn_weights×V,外形=(B,m,S,h).
这样就得到每个 head 对原值向量的加权结果。
[*] 合并 heads(merge heads)
[*]将 ( B , m , S , h ) (B, m, S, h) (B,m,S,h) 还原到 ( B , S , m × h ) (B, S, m \times h) (B,S,m×h),即 ( B , S , d m o d e l ) (B, S, d_{\mathrm{model}}) (B,S,dmodel)。
[*]合并之后,相当于将全部 head 的信息拼接到最后一个维度上。
[*] 可选的最终线性映射 W O \mathbf{W}_O WO
[*]多数实现会继承用一个线性层 W O \mathbf{W}_O WO(同样是 ( d m o d e l , d m o d e l ) (d_{\mathrm{model}}, d_{\mathrm{model}}) (dmodel,dmodel))把拼接后的多头输出再次投影,外形保持 ( B , S , d m o d e l ) (B, S, d_{\mathrm{model}}) (B,S,dmodel)。
通过以上步骤,多头留意力便可将序列的上下文信息捕捉到差别的 head(差别的子空间),再合并形成新的隐层表示。
cao!
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]