【AIGC】大模子口试高频考点-注意力(Attention)篇

打印 上一主题 下一主题

主题 915|帖子 915|积分 2749

(一)手撕单头注意力机制(ScaledDotProductAttention)函数

输入是query和 key-value,注意力机制起首计算query与每个key的关联性(compatibility),每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出。

  1. class ScaledDotProductAttention(nn.Module):
  2.     """ Scaled Dot-Product Attention """
  3.     def __init__(self, scale):
  4.         super().__init__()
  5.         self.scale = scale
  6.         self.softmax = nn.Softmax(dim=2)
  7.     def forward(self, q, k, v, mask=None):
  8.         u = torch.bmm(q, k.transpose(1, 2)) # 1.Matmul
  9.         u = u / self.scale # 2.Scale
  10.         if mask is not None:
  11.             u = u.masked_fill(mask, -np.inf) # 3.Mask
  12.         attn = self.softmax(u) # 4.Softmax
  13.         output = torch.bmm(attn, v) # 5.Output
  14.         return attn, output
  15. if __name__ == "__main__":
  16.     n_q, n_k, n_v = 2, 4, 4
  17.     d_q, d_k, d_v = 128, 128, 64
  18.     q = torch.randn(batch, n_q, d_q)
  19.     k = torch.randn(batch, n_k, d_k)
  20.     v = torch.randn(batch, n_v, d_v)
  21.     mask = torch.zeros(batch, n_q, n_k).bool()
  22.     attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
  23.     attn, output = attention(q, k, v, mask=mask)
  24.     print(attn)
  25.     print(output)
复制代码
(二)手撕多头注意力(MultiHeadAttention)

  1. class MultiHeadAttention(nn.Module):
  2.     """ Multi-Head Attention """
  3.     def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
  4.         super().__init__()
  5.         self.n_head = n_head
  6.         self.d_k = d_k
  7.         self.d_v = d_v
  8.         self.fc_q = nn.Linear(d_k_, n_head * d_k)
  9.         self.fc_k = nn.Linear(d_k_, n_head * d_k)
  10.         self.fc_v = nn.Linear(d_v_, n_head * d_v)
  11.         self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
  12.         self.fc_o = nn.Linear(n_head * d_v, d_o)
  13.     def forward(self, q, k, v, mask=None):
  14.         n_head, d_q, d_k, d_v = self.n_head, self.d_k, self.d_k, self.d_v
  15.         batch, n_q, d_q_ = q.size()
  16.         batch, n_k, d_k_ = k.size()
  17.         batch, n_v, d_v_ = v.size()
  18.         q = self.fc_q(q) # 1.单头变多头
  19.         k = self.fc_k(k)
  20.         v = self.fc_v(v)
  21.         q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)
  22.         k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)
  23.         v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, d_v)
  24.         if mask is not None:
  25.             mask = mask.repeat(n_head, 1, 1)
  26.         attn, output = self.attention(q, k, v, mask=mask) # 2.当成单头注意力求输出
  27.         output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1) # 3.Concat
  28.         output = self.fc_o(output) # 4.仿射变换得到最终输出
  29.         return attn, output
  30. if __name__ == "__main__":
  31.     n_q, n_k, n_v = 2, 4, 4
  32.     d_q_, d_k_, d_v_ = 128, 128, 64
  33.     q = torch.randn(batch, n_q, d_q_)
  34.     k = torch.randn(batch, n_k, d_k_)
  35.     v = torch.randn(batch, n_v, d_v_)   
  36.     mask = torch.zeros(batch, n_q, n_k).bool()
  37.     mha = MultiHeadAttention(n_head=8, d_k_=128, d_v_=64, d_k=256, d_v=128, d_o=128)
  38.     attn, output = mha(q, k, v, mask=mask)
  39.     print(attn.size())
  40.     print(output.size())
复制代码
(三)手撕自注意力机制函数(SelfAttention)

Self-Attention。和Attention雷同,他们都是一种注意力机制。差别的是Attention是source对target,输入的source和输出的target内容差别。比方英译中,输入英文,输出中文。而Self-Attention是source对source,是source内部元素之间大概target内部元素之间发生的Attention机制,也可以理解为Target=Source这种特殊情况下的注意力机制。
  1. class SelfAttention(nn.Module):
  2.     """ Self-Attention """
  3.     def __init__(self, n_head, d_k, d_v, d_x, d_o):
  4.         self.wq = nn.Parameter(torch.Tensor(d_x, d_k))
  5.         self.wk = nn.Parameter(torch.Tensor(d_x, d_k))
  6.         self.wv = nn.Parameter(torch.Tensor(d_x, d_v))
  7.         self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v, d_k=d_k, d_v=d_v, d_o=d_o)
  8.         self.init_parameters()
  9.     def init_parameters(self):
  10.         for param in self.parameters():
  11.             stdv = 1. / np.power(param.size(-1), 0.5)
  12.             param.data.uniform_(-stdv, stdv)
  13.     def forward(self, x, mask=None):
  14.         q = torch.matmul(x, self.wq)   
  15.         k = torch.matmul(x, self.wk)
  16.         v = torch.matmul(x, self.wv)
  17.         attn, output = self.mha(q, k, v, mask=mask)
  18.         return attn, output
  19. if __name__ == "__main__":
  20.     n_x = 4
  21.     d_x = 80
  22.     x = torch.randn(batch, n_x, d_x)
  23.     mask = torch.zeros(batch, n_x, n_x).bool()
  24.     selfattn = SelfAttention(n_head=8, d_k=128, d_v=64, d_x=80, d_o=80)
  25.     attn, output = selfattn(x, mask=mask)
  26.     print(attn.size())
  27.     print(output.size())
复制代码
(四)GPT2 解码中的KV Cache

无论是Encoder-Decoder布局,还是现在我们最靠近AGI的decoder-only的LLM,解码生成时都是自回归auto-regressive的方式。
也就是,解码的时候,先根据当前输入input ,生成下一个 token,然后把新生成的token拼接在input反面,获得新的输入input,再用input生成token,依此迭代,直到生成竣事。
我们可以注意到,下一个step的输入实在包含了上一个step的内容,而且只在最反面多了一点点(一个token)。那么下一个step的计算应该也包含了上一个step的计算。
但是模子在推理的时候可不管这些,无论你是不是只要末了一个字的输出,它都把所有输入计算一遍,给出所有输出结果。
也就是说中心有许多我们用不到的计算,这样就造成了浪费。
而且随着生成的结果越来越多,输入的长度也越来越长,上面这个例子里,输入长度就从step0的10个,每步增长1,直到step5的15个。假如输入的instruction是让模子写作文,那大概就有800个step。这个情况下,step0被算了800次,step1被算了799次…这样浪费的计算资源确实不容忽视。
有没有什么办法可以重复利用上一个step里已经计算过的结果,减少浪费呢?
答案就是KV Cache,利用一个缓存,把必要重复利用的中心计算结果存下来,减少重复计算。
而 k 和 v 就是我要缓存的对象。
想象一下,在上面的例子中,假设我们一开始的输入就是3个字,我们第一次预测就是预测第4个字,那么由于一开始没有任何缓存,所有我们每一层还是要诚实地计算一遍。然后把 k 、 v 值缓存起来。
则有

kv cache的下标l表示模子层数。
在举行第二次预测,也就是预测第5个字的时候,在第l层的时候,由于前面我们缓存了每层的ku 值,那本层就只必要算新的 o3,而不消算 o0、o1、o2。
由于第l层的 o0、o1、o2本来会经过FNN层之后进到 l十1 层,再经过新的投影变更,成为 l + 1 层的 k、v值,但是l十 1 层的 k、v 值我们已经缓存过了!
然后我们把本次新增算出来的 k、υ 值也存入缓存。

这样就节省了attention和FFN的许多重复计算。
transformers中,生成的时候传入use_cache=True就会开启KV Cache。
也可以简单看下GPT2中的实现,中文注释的部分就是使用缓存结果和更新缓存结果
  1. Class GPT2Attention(nn.Module):
  2.     ...
  3.     ...
  4.     def forward(
  5.         self,
  6.         hidden_states: Optional[Tuple[torch.FloatTensor]],
  7.         layer_past: Optional[Tuple[torch.Tensor]] = None,
  8.         attention_mask: Optional[torch.FloatTensor] = None,
  9.         head_mask: Optional[torch.FloatTensor] = None,
  10.         encoder_hidden_states: Optional[torch.Tensor] = None,
  11.         encoder_attention_mask: Optional[torch.FloatTensor] = None,
  12.         use_cache: Optional[bool] = False,
  13.         output_attentions: Optional[bool] = False,
  14.     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
  15.         if encoder_hidden_states is not None:
  16.             if not hasattr(self, "q_attn"):
  17.                 raise ValueError(
  18.                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
  19.                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
  20.                 )
  21.             query = self.q_attn(hidden_states)
  22.             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
  23.             attention_mask = encoder_attention_mask
  24.         else:
  25.             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
  26.         query = self._split_heads(query, self.num_heads, self.head_dim)
  27.         key = self._split_heads(key, self.num_heads, self.head_dim)
  28.         value = self._split_heads(value, self.num_heads, self.head_dim)
  29.         # 过去所存的值
  30.         if layer_past is not None:
  31.             past_key, past_value = layer_past
  32.             key = torch.cat((past_key, key), dim=-2)  # 把当前新的key加入
  33.             value = torch.cat((past_value, value), dim=-2)  # 把当前新的value加入
  34.         if use_cache is True:
  35.             present = (key, value)  # 输出用于保存
  36.         else:
  37.             present = None
  38.         if self.reorder_and_upcast_attn:
  39.             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
  40.         else:
  41.             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  42.         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  43.         attn_output = self.c_proj(attn_output)
  44.         attn_output = self.resid_dropout(attn_output)
  45.         outputs = (attn_output, present)
  46.         if output_attentions:
  47.             outputs += (attn_weights,)
  48.         return outputs  # a, present, (attentions)
复制代码
总的来说,KV Cache是以空间换时间的做法,通过使用快速的缓存存取,减少了重复计算。(注意,只有decoder布局的模子可用,由于有mask attention的存在,使得前面的token可以不消关注反面的token)
(五)手撕 MQA 算法

MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
  1. class MultiQueryAttention(nn.Module):
  2.     """Multi-Query self attention.
  3.     Using torch or triton attention implemetation enables user to also use
  4.     additive bias.
  5.     """
  6.     def __init__(
  7.         self,
  8.         d_model: int,
  9.         n_heads: int,
  10.         attn_impl: str = 'triton',
  11.         clip_qkv: Optional[float] = None,
  12.         qk_ln: bool = False,
  13.         softmax_scale: Optional[float] = None,
  14.         attn_pdrop: float = 0.0,
  15.         low_precision_layernorm: bool = False,
  16.         verbose: int = 0,
  17.         device: Optional[str] = None,
  18.     ):
  19.         super().__init__()
  20.         self.attn_impl = attn_impl
  21.         self.clip_qkv = clip_qkv
  22.         self.qk_ln = qk_ln
  23.         self.d_model = d_model
  24.         self.n_heads = n_heads
  25.         self.head_dim = d_model // n_heads
  26.         self.softmax_scale = softmax_scale
  27.         if self.softmax_scale is None:
  28.             self.softmax_scale = 1 / math.sqrt(self.head_dim)
  29.         self.attn_dropout_p = attn_pdrop
  30.         self.Wqkv = nn.Linear(
  31.             d_model,
  32.             d_model + 2 * self.head_dim,
  33.             device=device,
  34.         )
  35.         fuse_splits = (d_model, d_model + self.head_dim)
  36.         self.Wqkv._fused = (0, fuse_splits)  # type: ignore
  37.         self.attn_fn = scaled_multihead_dot_product_attention
  38.         self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
  39.         self.out_proj._is_residual = True  # type: ignore
  40.     def forward(
  41.         self,
  42.         x,
  43.         past_key_value=None,
  44.         attn_bias=None,
  45.         attention_mask=None,
  46.         is_causal=True,
  47.         needs_weights=False,
  48.     ):
  49.         qkv = self.Wqkv(x)                                # (1, 512, 960)
  50.         if self.clip_qkv:
  51.             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
  52.         query, key, value = qkv.split(                         # query -> (1, 512, 768)
  53.             [self.d_model, self.head_dim, self.head_dim],      # key   -> (1, 512, 96)
  54.             dim=2                                              # value -> (1, 512, 96)
  55.         )
  56.         key_padding_mask = attention_mask
  57.         if self.qk_ln:
  58.             # Applying layernorm to qk
  59.             dtype = query.dtype
  60.             query = self.q_ln(query).to(dtype)
  61.             key = self.k_ln(key).to(dtype)
  62.         context, attn_weights, past_key_value = self.attn_fn(
  63.             query,
  64.             key,
  65.             value,
  66.             self.n_heads,
  67.             past_key_value=past_key_value,
  68.             softmax_scale=self.softmax_scale,
  69.             attn_bias=attn_bias,
  70.             key_padding_mask=key_padding_mask,
  71.             is_causal=is_causal,
  72.             dropout_p=self.attn_dropout_p,
  73.             training=self.training,
  74.             needs_weights=needs_weights,
  75.             multiquery=True,
  76.         )
  77.         return self.out_proj(context), attn_weights, past_key_value
复制代码
(六)Attention改进版

(1)Flash Attention

Paper:《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》
论文链接:https://arxiv.org/abs/2205.1413


  • FlashAttention是一种加速注意力计算方法,现在已经应用在:GPT-3、Falcon2(阿联酋大模子)、Llama2、Megatron-LM、GPT-4等着名LLM上。
  • Flash Attention已经集成到了pytorch2.0中,可以很便捷的调用。
  • FlashAttention旨在加速注意力计算并减少内存占用。FlashAttention利用底层硬件的内存层次知识,比方GPU的内存层次布局,来进步计算速率和减少内存访问开销。 FlashAttention的核心原理是通过将输入分块并在每个块上实行注意力操作,从而减少对高带宽内存(HBM)的读写操作。具体而言,FlashAttention使用平铺和重计算等经典技术,将输入块从HBM加载到SRAM(快速缓存),在SRAM上实行注意力操作,并将结果更新回HBM。FlashAttention减少了内存读写量,从而实现了2-4倍的时钟时间加速。
  • Timeline: 最新的FlashAttention-2版本进一步优化了FlashAttention算法,使用了更好的并行化和工作分区方法,使得计算速率进步了2倍。FlashAttention-2还支持更高的头维数和多查询注意力等新特性,进一步提拔了性能和灵活性。

具体原理细节请查看:
Flash Attention原理详解(含代码解说)
(2)Page Attention

论文地点:vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention | vLLM Blog
原理如下:
步骤1: 确定固定块巨细: (假设512 token) 等举行内存管理。
步骤2: 动态分配内存: 对每一个batch 起首各分配一个固定块巨细,将该固定块对应的物理地点和当前占用token 个数,在一张block table 表内记录。
步骤3:对于不断增长的token数,每凌驾512就 在runtime阶段重新分配物理内存,并同样将该块的信息记录到blocktable 表中。
步骤4: 算子库加载实现,获得每个batch 的逻辑地点即该起始block快的地点,得到block 块,加载kvcache,根据当前总的token个数,分段加载token快举行计算。
步骤5:算子库存储实现,获得每个batch 的逻辑地点即该起始block快的地点,根据当前token索引,获得要存储在哪个逻辑block块,找到对应块的fill个数,偏移后存储到对应位置。


具体原理细节请查看:PageAttention 论文解析
(3)Flash Attention2

论文地点:https://arxiv.org/pdf/2307.08691
如何扩展Transformer使之能够处置惩罚更长的序列不停是一个挑战,**由于其核心注意力层的运行时间和内存占用量随输入序列长度成二次增加。**我们希望能够冲破2k序列长度限制,从而能够训练书籍、高分辨率图像和长视频。此外,写作等应用也必要模子能够处置惩罚长序列。已往一年中,业界推出了一些远超之前长度的语言模子:GPT-4为32k,MosaicML的MPT为65k,以及Anthropic的Claude为100k。
虽然相比标准Attention,FlashAttention快了24倍,节约了1020倍内存,但是离设备理论最大throughput和flops还差了许多。本文提出了FlashAttention-2,它具有更好的并行性和工作分区。实验结果表现,FlashAttention-2在正向传递中实现了约2倍的速率提拔,到达了理论最大吞吐量的73%,在反向传递中到达了理论最大吞吐量的63%。在每个A100 GPU上的训练速率可到达225 TFLOPs/s。
本文主要贡献和创新点为:


  • 减少了non-matmul FLOPs的数量(消除了原先频仍rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的实行时间较长,这是由于GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽大概多地实行matmul FLOPs非常紧张。
  • 提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。纵然对于单个head,也在差别的thread block之间举行并行计算。
  • 在一个attention计算块内,将工作分配在一个thread block的差别warp上,以减少通信和共享内存读/写。
具体原理细节请查看:FlashAttention2详解(性能比FlashAttention提拔200%)
(4)Flash Attention3

Github地点:https://github.com/Dao-AILab/flash-attention
论文地点:https://tridao.me/publications/flash3/flash3.pdf
FlashAttention、FlashAttention-2开创了一种通过最小化内存读/写来加快 GPU 注意力的方法,现在已经成为了pytorch库的标配了,使用它来加速 Transformer 训练和推理。
使得LLM上下文长度大幅增加,从 2-4K (GPT-3, OPT) 到 128K (GPT-4),乃至 1M (Llama 3)


  • FlashAttention-2 可以在 A100 GPU 上实现高达 70% 的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能来最大限度地进步性能。
  • FlashAttention-2 在 H100 GPU 上仅实现了 35% 的理论最大 FLOP 利用率。
  • FlashAttention-3在H100 理论最大 FLOPS 的利用率为 75%,比接纳 FP16 的 FlashAttention-2 快 1.5-2.0 倍,最高可达 740 TFLOPS。使用 FP8 时,FlashAttention-3 可到达靠近 1.2 PFLOPS,毛病比基线 FP8 注意力小 2.6 倍。
具体改进如下:


  • 更高效的 GPU 利用率

    • H100 GPU 推出了WGMMA(翘曲矩阵乘法累加)功能,比A100吞吐量高3倍
    • H100 GPU 的TMA(张量影象加速器)功能,可加速全局内存和共享内存之间的数据传输,负责所有索引计算和越界预测。这样可以开释寄存器,增加图块巨细和效率的宝贵资源。这导致大型语言模子 (LLM) 的训练和运行速率比FlashAttention-2快得多(1.5-2 倍)。

  • 以更低的精度获得更好的性能

    • FlashAttention-3 可以在保持精度的同时处置惩罚较低精度的数字 (FP8)。比方,FP16 为 989 TFLOPS,FP8 为 1978 TFLOPS。这允许更快的处置惩罚速率并尽大概低落内存使用量,这大概会为运行大规模 AI 操作的客户节省成本并进步效率。
    • 具体的做法是:利用QuIP: 2-Bit Quantization of Large Language Models With Guarantees技术,通过非干系处置惩罚减少量化毛病,即将查询和键与随机正交矩阵相乘,以“分散”非常值并减少量化毛病。

  • 能够在 LLM 中使用更长的上下文

    • 通过加速注意力机制,FlashAttention-3 使 AI 模子能够更有效地处置惩罚更长的文本片段。这可以使应用步伐能够理解和生成更长、更复杂的内容,而不会减慢速率。
    • 对于 FP16,我们看到比 FlashAttention-2 加速约 1.6-2.0 倍。

    • 对于FP8,我们可以到达靠近1.2 PFLOPS


具体原理细节请查看:FlashAttention-3 比FlashAttention-2快了2倍,做了些什么?

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

王國慶

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表