Llama开源代码详细解读(2)

打印 上一主题 下一主题

主题 524|帖子 524|积分 1576

FlashAttention

  1. if is_flash_attn_available(): # 检查flashattention的可用性
  2.     from flash_attn import flash_attn_func, flash_attn_varlen_func
  3.     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
复制代码
FlashAttention是Tranformer模子中用于改进留意力机制的技术,重要目的是淘汰计算复杂度和内存占用。


  • flash_attn_func用于标准的flashattention计算。
  • flash_attn_varlen_func用于处置处罚变长序列(长度未能确定)的flashattention计算。
  • index_first_axis用于处置处罚第一个索引轴。
  • pad_input将数据进行添补处置处罚,从而确定长度。
  • unpad_input将添补后的输入还原为原始形态。
Logging模块

  1. logger = logging.get_logger(__name__)
  2. _CONFIG_FOR_DOC = "LlamaConfig"
复制代码
创建了名为logger的日志记录器对象,__name__用于生存模块的名称,确保每个模块都有本身的日志记录器。
_CONFIG_FOR_DOC前面带有下划线,因此可以看出其代表一个模块的内部变量。
get_unpad_data模块

  1. def _get_unpad_data(padding_mask):
  2.     seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
  3.     indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
  4.     max_seqlen_in_batch = seqlens_in_batch.max().item()
  5.     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
  6.     return (
  7.         indices,
  8.         cu_seqlens,
  9.         max_seqlen_in_batch,
  10.     )
复制代码
该模块的作用是padding_mask提取非添补的数据,分为以下几步:

  • seqlens_in_batch计算量每个张量的有效长度,sum()函数计算每个张量的有效长度。dim等于-1意味着以最按照最后一个维度进行求和,如果是二维,就可以理解为对跨列操作,即计算了每一行非添补元素的个数。
  • indices获取了非添补元素的索引。flatten()函数将张量睁开成一维,as_tuple为flase意味着返回不是元组形式而是二维矩阵形式,由于返回的是二维矩阵,因此我们需要flatten()再次展平成一维。
    ——为什么不返回元组呢?
    如果返回元组,那么返回的格式是包罗一个一维张量的元组,然后还需要从元组中取出这个一维张量,类似:
  1. torch.nonzero(padding_mask.flatten(), as_tuple=True)[0]
复制代码
这样比较麻烦,不如直接返回二维数组再展平。
3. max_seqlen_in_batch获取了在seqlens_in_batch中的最大值并返回(即长度最长的那一个),然后 item()函数的作用是将一个元素的张量转换为python对应的标量,即一个数。
4. cu_seqlens计算累计长度并进行添补。cumsum()函数用于计算指定维度的累计和,(1,0)意味着只在左边添加一个元素,右边不添加。F.pad()是为张量进行添补的函数。这对于处置处罚变长序列非常有用,因为即获得了每个序列的开始索引,容易确定起始和竣事位置。
5. 最终返回的包括:非零元素的索引,左边添补过了的累计长度,最长序列的长度。
从而,到达了提取非添补数据的目的。
_make_causal_mask模块

  1. def _make_causal_mask(
  2.     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
  3. ):
  4.     """
  5.     Make causal mask used for bi-directional self-attention.
  6.     """
  7.     bsz, tgt_len = input_ids_shape
  8.     mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  9.     mask_cond = torch.arange(mask.size(-1), device=device)
  10.     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  11.     mask = mask.to(dtype)
  12.     if past_key_values_length > 0:
  13.         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  14.     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
复制代码
该模块用于生成因果掩码,通常用于双向自留意力机制。具体来说,该模块包管在计算留意力时,只能看到当前时间步之前的信息,而看不到未来的,来保持因果关系。
输入参数

input_ids_shape: torch.Size:输入张量的形状,通常为(batch_size, target_length)。
dtype: torch.dtype:用于生成掩码的张量范例。
device: torch.device:指定装备是GPU还是CPU。
past_key_values_length: int = 0:已往的键值对长度,用于增量计算。
步骤



  • 获取批大小(bsz)和目的长度(tgt_len)。
  • 创建初始掩码:形状为(tgt_len,tgt_len),全部的dtpye设置为最小,通常为负无穷。torch.full()函数创建一个均为min的矩阵
  • 设置掩码条件:mask_cond生成一个mask最后一个维度大小-1长度的序列,并放置在指定装备
  • mask.masked_fill:将下三角矩阵设置为0。这里用到了pytorch的广播机制,将一个行为1和列为1的向量扩充进行比较,从而将下三角都变为了0。
  • 将掩码转换为指定的数据范例。
  • 如果past_key_values_length>0,那么就在最后一个维度拼接上一个(tgt_len, past_key_values_length)的张量,这是为了在处置处罚增量计算时,可以或许考虑已往的键值对。
    其中,zeros()函数创建了(tgt_len,past_key_value_length)的全零矩阵,用cat()在mask前添加了一个全0块。
  • 最终将遮罩的维度扩展为四维形状(bsz, 1, tgt_len, tgt_len + past_key_values_length),并返回。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

篮之新喜

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表