【llm对话体系】大模型源码分析之 LLaMA 模型的 Masked Attention ...

打印 上一主题 下一主题

主题 1919|帖子 1919|积分 5757

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们须要对注意力举行屏蔽(Masking),以防止模型“偷看”将来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。
1. 什么是 Masked Attention

1.1 为什么须要 Mask

在自回归模型中,模型的目标是根据已有的输入序列猜测下一个词。在练习阶段,模型会吸取整个输入序列,但在猜测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽将来词对当前词的影响,确保模型只能依赖于已往的信息举行猜测。
1.2 Mask 的类型

Mask 主要分为两种类型:

  • Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  • Causal Mask: 用于自回归模型,屏蔽将来位置的信息,防止模型偷看将来。
2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask
2.1 LLaMA 的实现逻辑

LLaMA 使用尺度的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。详细流程如下:

  • 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  • 计算注意力分数: 计算查询向量和键向量的点积,并举行缩放。
  • 应用 Mask: 使用 Causal Mask 屏蔽将来位置的注意力分数。
  • 计算注意力权重: 对屏蔽后的注意力分数举行 Softmax 归一化。
  • 计算加权值向量: 使用注意力权重对值向量举行加权求和。
2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class LlamaAttention(nn.Module):
  5.     def __init__(self, d_model, num_heads):
  6.         super().__init__()
  7.         self.d_model = d_model
  8.         self.num_heads = num_heads
  9.         self.head_dim = d_model // num_heads
  10.         # 线性变换
  11.         self
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

石小疯

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表