qidao123.com技术社区-IT企服评测·应用市场

标题: LLaMA 的模子布局 [打印本页]

作者: 花瓣小跑    时间: 2024-11-25 10:12
标题: LLaMA 的模子布局
1. LLaMA 的 Transformer 架构改进

LLaMA 基于经典的 Transformer 架构,但与原始的 Transformer 相比有几点重要改进:

   RMSNorm   在   HuggingFace Transformer   库中代码实现:
  1. class LlamaRMSNorm(nn.Module):
  2.     def __init__(self, hidden_size, eps=1e-6):
  3.         """
  4.         初始化 LlamaRMSNorm 模块。
  5.         参数:
  6.         - hidden_size: 输入隐藏状态的维度,即需要归一化的特征数。
  7.         - eps: 用于数值稳定的非常小的数,防止计算过程中分母为0(通常为1e-6)。
  8.         """
  9.         super().__init__()
  10.         self.weight = nn.Parameter(torch.ones(hidden_size))  # 权重参数初始化为1
  11.         self.variance_epsilon = eps  # 数值稳定项
  12.     def forward(self, hidden_states):
  13.         """
  14.         前向传播函数,执行归一化操作。
  15.         参数:
  16.         - hidden_states: 输入的张量,表示网络层的隐藏状态。
  17.         返回值:
  18.         - 返回归一化并且经过缩放的隐藏状态。
  19.         """
  20.         # 保存输入的原始数据类型,以便最后转换回同样的类型
  21.         input_dtype = hidden_states.dtype
  22.         
  23.         # 计算方差(或更准确地说是每个样本的特征值的均值平方)
  24.         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  25.         
  26.         # 对 variance 加上 epsilon,防止分母为0,然后取平方根的倒数,进行归一化
  27.         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  28.         
  29.         # weight 是可训练参数,用于缩放归一化后的输出
  30.         return (self.weight * hidden_states).to(input_dtype)
复制代码
2. 激活函数 SwiGLU

        LLaMA 还在全连接层中使用了 SwiGLU 激活函数。SwiGLU 是一种改进的激活函数,相比经典的 ReLU 或 Swish,它在大规模模子如 PaLM 和 LLaMA 中表现更好,能提供更高的非线性表达本领。其计算公式为:


Swish 激活函数在参数差别取值下的形状 

3. 旋转位置嵌入(RoPE)

        RoPE(Rotary Positional Embedding) 是 LLaMA 中一个重要的创新,旨在替代传统的绝对位置编码,改进位置信息在注意力机制中的作用。其核心思想是通过使用复数的多少操作(旋转)将位置编码引入查询(q)和键(k)中,实现相对位置编码的效果。公式如下:

RoPE 的上风在于它能处置惩罚更长的序列并捕捉相对位置信息,特别适合在大规模自然语言模子中应用。
   RoPE   在 HuggingFace Transformer 库中代码实现如下:   
  1. class LlamaRotaryEmbedding(torch.nn.Module):
  2.     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  3.         # 调用父类的初始化函数,确保模块正确继承 nn.Module 的功能
  4.         super().__init__()
  5.         # `inv_freq` 是逆频率的计算结果,它根据维度 `dim` 来生成。该逆频率用于生成正弦和余弦嵌入。
  6.         # `torch.arange(0, dim, 2)` 生成从 0 到 dim 的偶数序列(步长为 2),然后除以 dim,构造分布式频率。
  7.         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
  8.         
  9.         # 将逆频率 `inv_freq` 注册为模型的缓冲区,这意味着它不是可训练参数,但会被持久保存。
  10.         self.register_buffer("inv_freq", inv_freq)
  11.         # 初始化时,预先缓存最大序列长度对应的旋转嵌入,避免在每次前向传播时重复计算
  12.         self.max_seq_len_cached = max_position_embeddings
  13.         # `t` 是时间步(位置索引),从 0 到 `max_seq_len_cached - 1` 的序列,用来生成位置嵌入。
  14.         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
  15.         
  16.         # 通过 `einsum` 进行矩阵乘法,将 `t` 和 `inv_freq` 相乘生成频率矩阵 `freqs`,表示每个位置和对应的频率。
  17.         # einsum("i,j->ij") 表示进行外积操作,将 `t` 和 `inv_freq` 组合成位置-频率矩阵。
  18.         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
  19.         # 将 `freqs` 进行拼接,扩展为 [seq_len, dim],这样每个位置都有对应的频率嵌入。
  20.         emb = torch.cat((freqs, freqs), dim=-1)
  21.         # 获取当前默认的 `dtype`,以确保缓存的 `cos` 和 `sin` 的数据类型与输入一致。
  22.         dtype = torch.get_default_dtype()
  23.         # 缓存 cos 和 sin 嵌入,这里为嵌入增加了额外的维度以适配多头注意力机制的输入格式。
  24.         # `persistent=False` 表示这些缓冲区不会被持久保存到模型的状态字典中。
  25.         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
  26.         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
  27.     def forward(self, x, seq_len=None):
  28.         # `x` 是输入张量,通常形状为 [batch_size, num_attention_heads, seq_len, head_size]。
  29.         # 这部分代码检查当前序列长度是否超过缓存的最大序列长度 `max_seq_len_cached`。
  30.         if seq_len > self.max_seq_len_cached:
  31.             # 如果输入的序列长度超过了缓存的最大序列长度,重新计算 sin 和 cos 值。
  32.             self.max_seq_len_cached = seq_len
  33.             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
  34.             # 重新计算位置和频率的外积。
  35.             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
  36.             # 生成新的频率嵌入,并更新缓冲区。
  37.             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  38.             self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
  39.             self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
  40.         # 截取缓存的 cos 和 sin 值,使其匹配输入序列的长度 `seq_len`,并确保数据类型与输入一致。
  41.         return (
  42.             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
  43.             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
  44.         )
  45. # `rotate_half` 函数实现对输入张量的一半进行旋转操作,这是 RoPE 的核心机制。
  46. # 它将输入张量的后半部分取负并与前半部分交换,形成旋转效果。
  47. def rotate_half(x):
  48.     """旋转输入张量的一半维度"""
  49.     x1 = x[..., : x.shape[-1] // 2]  # 获取输入的前半部分。
  50.     x2 = x[..., x.shape[-1] // 2 :]  # 获取输入的后半部分。
  51.     return torch.cat((-x2, x1), dim=-1)  # 交换并将后半部分取负,拼接成新的张量。
  52. # `apply_rotary_pos_emb` 函数将旋转位置嵌入应用到查询 `q` 和键 `k` 上。
  53. # 它将 cos 和 sin 值乘以查询和键,然后通过旋转操作进行位置嵌入的应用。
  54. def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  55.     # `cos` 和 `sin` 的前两维始终为 1,因此可以去掉这些冗余维度。
  56.     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
  57.     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
  58.     # 通过 `position_ids` 选择相应的嵌入,并扩展维度以匹配查询和键的形状。
  59.     cos = cos[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]
  60.     sin = sin[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]
  61.     # 对查询 `q` 应用旋转位置嵌入,先乘以 cos,再加上乘以 sin 的旋转结果。
  62.     q_embed = (q * cos) + (rotate_half(q) * sin)
  63.     # 对键 `k` 应用相同的旋转位置嵌入。
  64.     k_embed = (k * cos) + (rotate_half(k) * sin)
  65.     # 返回嵌入了位置编码的查询和键。
  66.     return q_embed, k_embed
复制代码
  4. 模子团体框架

        LLaMA 模子团体架构仍然是 Transformer 的自回归解码器。LLaMA 的设计吸收了 GPT 系列模子的优点,同时在以下几个方面做了重要改进:


LLaMA 差别模子规模下的详细超参数细节

   HuggingFace Transformer   库中   LLaMA   解码器团体实现代码实现:
  1. class LlamaDecoderLayer(nn.Module):
  2.     def __init__(self, config: LlamaConfig):
  3.         # 初始化时调用父类的构造函数,确保正确继承
  4.         super().__init__()
  5.         # 从配置 `config` 中获取隐藏层大小,用于层的初始化
  6.         self.hidden_size = config.hidden_size
  7.         # 自注意力机制层的初始化,使用 LlamaAttention。它会处理输入的自注意力操作。
  8.         self.self_attn = LlamaAttention(config=config)
  9.         # 多层感知机层(MLP)用于后续的非线性变换。其包含一个隐藏层大小和中间层大小。
  10.         # `hidden_act` 是激活函数(如 ReLU),用于非线性变换。
  11.         self.mlp = LlamaMLP(
  12.             hidden_size=self.hidden_size,  # 输入的隐藏层大小
  13.             intermediate_size=config.intermediate_size,  # 中间层大小,通常是隐藏层大小的 4 倍
  14.             hidden_act=config.hidden_act  # 激活函数,如 ReLU 或 GELU
  15.         )
  16.         # 输入层的归一化层,使用 RMSNorm,防止训练中的梯度爆炸或消失
  17.         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  18.         # 注意力后的归一化层,确保经过注意力机制后的输出有稳定的分布
  19.         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  20.     def forward(
  21.         self,
  22.         hidden_states: torch.Tensor,  # 输入的隐藏状态张量,形状为 [batch_size, seq_len, hidden_size]
  23.         attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,防止模型关注到不需要的序列位置
  24.         position_ids: Optional[torch.LongTensor] = None,  # 位置编码的ID,用于生成位置信息
  25.         past_key_value: Optional[Tuple[torch.Tensor]] = None,  # 用于缓存前面层的键值对,加速推理
  26.         output_attentions: Optional[bool] = False,  # 是否输出注意力权重
  27.         use_cache: Optional[bool] = False,  # 是否使用缓存,加速推理
  28.     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  29.         # 保存输入的隐藏状态到 `residual`,用于残差连接
  30.         residual = hidden_states
  31.         # 对输入的隐藏状态应用归一化处理,确保输入的数值稳定
  32.         hidden_states = self.input_layernorm(hidden_states)
  33.         # 自注意力机制,处理输入序列中的各个位置之间的相互依赖关系
  34.         hidden_states, self_attn_weights, present_key_value = self.self_attn(
  35.             hidden_states=hidden_states,  # 当前的隐藏状态
  36.             attention_mask=attention_mask,  # 注意力掩码,防止模型关注到不相关的部分
  37.             position_ids=position_ids,  # 位置编码,帮助模型知道序列中每个词的位置
  38.             past_key_value=past_key_value,  # 用于加速推理的缓存
  39.             output_attentions=output_attentions,  # 是否输出注意力权重
  40.             use_cache=use_cache,  # 是否缓存键值对,加速下一步计算
  41.         )
  42.         # 通过残差连接,将输入(residual)与注意力层的输出相加,以避免梯度消失问题
  43.         hidden_states = residual + hidden_states
  44.         # 经过注意力机制后,保存当前的隐藏状态作为残差
  45.         residual = hidden_states
  46.         # 对经过注意力后的隐藏状态再次进行归一化处理
  47.         hidden_states = self.post_attention_layernorm(hidden_states)
  48.         # 进入多层感知机(MLP)模块,进行非线性变换,增加网络的表达能力
  49.         hidden_states = self.mlp(hidden_states)
  50.         # 再次通过残差连接,将 MLP 的输出与注意力后的隐藏状态相加
  51.         hidden_states = residual + hidden_states
  52.         # 将隐藏状态作为输出
  53.         outputs = (hidden_states,)
  54.         # 如果需要输出注意力权重,则将注意力权重也加入输出
  55.         if output_attentions:
  56.             outputs += (self_attn_weights,)
  57.         # 如果使用缓存,则将当前层的键值对缓存下来,便于后续层使用
  58.         if use_cache:
  59.             outputs += (present_key_value,)
  60.         # 返回包含隐藏状态和其他可选输出的元组
  61.         return outputs
复制代码
5. LLaMA 在 HuggingFace 中的实现



免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 qidao123.com技术社区-IT企服评测·应用市场 (https://dis.qidao123.com/) Powered by Discuz! X3.4