概述
因果注意力(Causal Attention)是一种自注意力机制,广泛应用于自回归模子中,尤其是在自然语言处置惩罚和时间序列预测等使掷中。它的核心头脑是在生成每个时间步的输出时,只关注当前时间步及之前的时间步,确保生成过程的因果性,从而避免模子在预测时依赖未来的信息。
工作原理
因果注意力的工作原理是通过掩码矩阵限制模子在计算每个时间步的注意力时,只关注当前时间步及之前的内容。具体地,掩码矩阵是一个上三角矩阵,其上三角部分为0,其余部分为1。这样,在计算注意力分布时,掩码矩阵将未来时间步的注意力得分设置为非常大的负值(-inf),使得这些位置在 softmax 操纵后靠近于零,从而不会对最终的输出产生影响。
掩码矩阵示例
掩码矩阵的结构如下:
- [
- [1, 0, 0, 0],
- [1, 1, 0, 0],
- [1, 1, 1, 0],
- [1, 1, 1, 1]
- ]
复制代码 该掩码矩阵确保每个时间步仅关注当前时间步及之前的时间步,维持因果性。
NumPy实现
以下是基于NumPy的因果注意力机制实现代码:
- import numpy as np
- def softmax(x):
- """Compute the softmax of vector x in a numerically stable way."""
- shift_x = x - np.max(x, axis=-1, keepdims=True)
- exp_x = np.exp(shift_x)
- softmax_x = exp_x / np.sum(exp_x, axis=-1, keepdims=True)
- return softmax_x
- def causal_self_attention(Q, K, V, mask):
- """
- 计算因果自注意力
- :param Q: 查询矩阵
- :param K: 键矩阵
- :param V: 值矩阵
- :param mask: 因果掩码矩阵,上三角为0,其余为1
- :return: 自注意力的输出
- """
- dim_key = K.shape[-1]
-
- # 计算未掩码的注意力得分
- attention_scores = np.matmul(Q, K.transpose(0, 2, 1)) / (np.sqrt(dim_key) + 1e-9)
-
- # 应用因果掩码,将mask为0的位置设置为非常大的负值
- attention_scores = np.where(mask == 0, -np.inf, attention_scores)
-
- # 使用数值稳定的softmax
- attention_weights = softmax(attention_scores)
-
- # 确保无效值处理后不会影响计算结果
- attention_weights = np.nan_to_num(attention_weights, nan=0.0, posinf=0.0, neginf=0.0)
-
- # 加权求和得到输出
- output = np.matmul(attention_weights, V)
- return output
- # 示例用法
- batch_size = 2
- seq_length = 4
- dim = 8
- Q = np.random.rand(batch_size, seq_length, dim)
- K = np.random.rand(batch_size, seq_length, dim)
- V = np.random.rand(batch_size, seq_length, dim)
- # 创建一个上三角掩码矩阵
- mask = np.triu(np.ones((seq_length, seq_length)), k=1)[np.newaxis, np.newaxis, :, :]
- # 调用causal_self_attention函数
- output = causal_self_attention(Q, K, V, mask)
- print(output)
复制代码 关键点
- 掩码矩阵:通过上三角掩码矩阵实现因果性,确保模子在生成每个时间步时只能关注当前及之前的时间步。
- 数值稳定性:在 softmax 计算中,通过减去最大值来进步数值稳定性,避免溢出题目。
- 无效值处置惩罚:在计算注意力权重时,使用 np.nan_to_num 处置惩罚无效值,确保结果的有效性。
应用场景
- 自回归语言模子:如GPT系列,在生成下一个词时,只能依赖已生成的词。
- 语音生成:如WaveNet,在生成下一帧语音数据时,只能依赖之前的帧。
- 时间序列预测:在预测过程中,不依赖未来时间步,确保预测的因果性。
Code
代码已上传至:AI_With_NumPy
此项目汇集了更多AI相干的算法实现,供大家学习参考使用,接待点赞收藏 |