一文详解LLaMa系列模型:原理先容、代码解读

[复制链接]
发表于 2025-4-28 08:24:44 | 显示全部楼层 |阅读模式
LLaMA详解

LLaMA(Large Language Model Meta AI)是由Meta(前身为Facebook)开发的一种大规模语言模型,旨在进步自然语言处理(NLP)任务的性能。LLaMA基于变动器(Transformer)架构,并经过大规模数据训练,以便在多种语言任务中表现出色。
Meta AI认为:对于给定的计算预算,最佳性能不是通过最大的模型实现的,而是通过在更多数据上训练的较小模型实现的。
前排提示,文末有大模型AGI-CSDN独家资料包哦!
模型布局

与GPT等生成模型类似,LLaMA也只使用了Transformer的解码器,但基于Transformer进行了三个改进:

  • 使用了GPT3的预标准化。为了进步训练稳定性,对每个Transformer子层的输入进行归一化,而不是对输出进行归一化。使用由RMSNorm 归一化函数。
  • 用 SwiGLU 激活函数替换 ReLU 非线性,以进步性能。使用 2 3 4 d \frac{2}{3}4d 32​4d的维度代替PaLM中的 4 d 4d 4d。
  • 类似GPTNeo,删除了绝对位置嵌入,而是添加了旋转位置嵌入(RoPE)。
下面逐一先容这三个改进:
RMSNorm

RMSNorm(Root Mean Square Normalization)是一种归一化技术,用于稳定和加速神经网络的训练过程。与其他归一化方法(如BatchNorm和LayerNorm)不同,RMSNorm通过计算输入张量的均方根(RMS)来进行归一化。RMSNorm公式如下:
RMSNorm ( x ) = x 1 d ∑ i = 1 d x i 2 + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma RMSNorm(x)=d1​∑i=1d​xi2​+ϵ ​x​⋅γ
此中 x x x是输入向量, d d d 是输入向量的维度, ϵ \epsilon ϵ是一个小常数,用于避免除零错误, γ \gamma γ是一个可学习的缩放参数。
LLaMa中的实现如下:
  1. class RMSNorm(torch.nn.Module):  
  2.     def __init__(self, dim: int, eps: float = 1e-6):  
  3.         super().__init__()  
  4.         self.eps = eps  
  5.         self.weight = nn.Parameter(torch.ones(dim))  
  6.   
  7.     def _norm(self, x):  
  8.         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  
  9.   
  10.     def forward(self, x):  
  11.         output = self._norm(x.float()).type_as(x)  
  12.         return output * self.weight
复制代码
SwiGLU激活函数

SwiGLU (Swish-Gated Linear Unit) 是一种用于神经网络的激活函数,它结合了Swish激活函数和门控机制,可以或许有用地增强模型的表达本领和性能。公式如下:
SwiGLU ( x ) = Swish ( x ) ⋅ ( Gated Linear Unit ( x ) ) \text{SwiGLU}(x) = \text{Swish}(x) \cdot (\text{Gated Linear Unit}(x)) SwiGLU(x)=Swish(x)⋅(Gated Linear Unit(x))
Swish ( x ) = x ⋅ σ ( x ) \text{Swish}(x) = x \cdot \sigma(x) Swish(x)=x⋅σ(x)
Gated Linear Unit ( x ) = Linear 1 ( x ) ⋅ σ ( Linear 2 ( x ) ) \text{Gated Linear Unit}(x) = \text{Linear}_1(x) \cdot \sigma(\text{Linear}_2(x)) Gated Linear Unit(x)=Linear1​(x)⋅σ(Linear2​(x))
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+e−x1​
Linear 1 \text{Linear}_1 Linear1​和 Linear 2 \text{Linear}_2 Linear2​是两个单独的线性变动。
LLaMa代码中使用 F . s i l u ( x ) F.silu(x) F.silu(x)添加SwiGLU激活函数
RoPE

旋转位置嵌入(Rotary Position Embedding, RoPE)是一种为序列模型(如Transformer)提供位置编码的方法。RoPE通过将输入向量在复数域进行旋转变动,来编码序列中位置的信息。与传统的位置编码方法(如正弦-余弦位置编码)相比,RoPE可以或许更好地捕捉序列中的相对位置信息,进步模型的表现力。
旋转位置嵌入(RoPE)是一种为序列模型提供位置编码的方法。其通过将输入向量在复数域进行旋转变动来编码位置信息。以下是RoPE的具体实现步骤:

  • 频率向量的计算:
    f i = 1 θ 2 i d f_i = \frac{1}{\theta^{\frac{2i}{d}}} fi​=θd2i​1​
    此中 θ \theta θ是一个常数(通常取 10000), i i i是向量维度的索引。
  • 旋转角度的计算:
    angle ( t ) = t ⋅ f i \text{angle}(t) = t \cdot f_i angle(t)=t⋅fi​
    此中 t t t是位置索引。
  • 应用旋转变动:
    对每个位置 t t t的输入向量 x t x_t xt​,在复数域进行旋转变动:
    x t ′ = x t ⋅ e j ⋅ angle ( t ) x_t’ = x_t \cdot e^{j \cdot \text{angle}(t)} xt′​=xt​⋅ej⋅angle(t)
    对于位置编码,通例的做法是在计算 query,key 和 value 向量之前,管帐算一个位置编码向量 加到词嵌入上,位置编码向量同样也是维向量,然后再乘以对应的变动矩阵。
RoPE 的 self-attention 操作的流程是:对于 token 序列中的每个词嵌入向量,起首计算其对应的 query 和 key 向量,然后对每个 token 位置都计算对应的旋转位置编码,接着对每个 token 位置的 query 和 key 向量的元素按照两两一组应用旋转变动,最后再计算 query 和 key 之间的内积得到 self-attention 的计算效果。
下图很直观的展示了旋转变动的过程:

旋转编码 RoPE 可以有用地保持位置信息的相对关系,即相邻位置的编码之间有一定的相似性,而远离位置的编码之间有一定的差别性。 这样可以增强模型对位置信息的感知和利用。这一点是其他绝对位置编码方式(如正弦位置编码、学习的位置编码等)所不具备的,由于它们只能表示绝对位置,而不能表示相对位置。
   为什么旋转位置嵌入有用?
  

  • 捕捉相对位置信息:传统的位置嵌入方法通常仅编码绝对位置,这可能在处理长序列或须要捕捉相对位置信息的任务中表现不佳。而RoPE通过旋转变动自然地引入了相对位置信息,使得模型可以或许更好地理解序列中各个位置之间的相对关系。
  • 由于RoPE通过复数域的旋转变动来编码位置,这种变动可以或许捕捉更加丰富的位置信息。相比于简朴的线性变动,旋转变动提供了更强的非线性表达本领,使得模型在处理复杂任务时具有更好的表现力。
  • RoPE的计算相对简朴,不须要复杂的矩阵运算。预计算频率向量和应用旋转变动的过程可以高效地实现,适合在实际应用中大规模摆设。
  • RoPE可以或许无缝集成到现有的Transformer架构中,不须要对模型布局进行大的修改。这种兼容性使得RoPE成为一种易于应用和推广的位置编码方法。
  • 在长序列处理任务中,传统的位置编码方法可能会碰到信息稀释或计算复杂度增长的题目。RoPE通过引入旋转变动,可以更好地保持长序列中的位置信息,使得模型在长序列任务中表现更加稳定和高效。
  • (这一点是我的料想)在高维向量中,方向是比模长更重要的量,通例位置编码直接在词嵌入上加上位置编码,相当于改变了模长,旋转位置编码改变了方向,实际上比通例位置编码多获得了一部分信息。
  下面这篇文章给出了公式原理和推导,解说非常详细:点击此处
在LLaMA中,RoPE使用下面的方式实现:
  1. def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
  2.     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
  3.     t = torch.arange(end, device=freqs.device)  # type: ignore  
  4.     freqs = torch.outer(t, freqs).float()  # type: ignore  
  5.     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
  6.     return freqs_cis  
  7.   
  8.   
  9. def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
  10.     ndim = x.ndim  
  11.     assert 0 <= 1 < ndim  
  12.     assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
  13.     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
  14.     return freqs_cis.view(*shape)  
  15.   
  16.   
  17. def apply_rotary_emb(  
  18.     xq: torch.Tensor,  
  19.     xk: torch.Tensor,  
  20.     freqs_cis: torch.Tensor,  
  21. ) -> Tuple[torch.Tensor, torch.Tensor]:  
  22.     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
  23.     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
  24.     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
  25.     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
  26.     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
  27.     return xq_out.type_as(xq), xk_out.type_as(xk)
复制代码
下面的代码给出了参加旋转位置嵌入的留意力机制:
  1. class Attention(nn.Module):  
  2.     def __init__(self, args: ModelArgs):  
  3.         super().__init__()  
  4.   
  5.         self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()  
  6.         self.head_dim = args.dim // args.n_heads  
  7.   
  8.         self.wq = ColumnParallelLinear(  
  9.             args.dim,  
  10.             args.n_heads * self.head_dim,  
  11.             bias=False,  
  12.             gather_output=False,  
  13.             init_method=lambda x: x,  
  14.         )  
  15.         self.wk = ColumnParallelLinear(  
  16.             args.dim,  
  17.             args.n_heads * self.head_dim,  
  18.             bias=False,  
  19.             gather_output=False,  
  20.             init_method=lambda x: x,  
  21.         )  
  22.         self.wv = ColumnParallelLinear(  
  23.             args.dim,  
  24.             args.n_heads * self.head_dim,  
  25.             bias=False,  
  26.             gather_output=False,  
  27.             init_method=lambda x: x,  
  28.         )  
  29.         self.wo = RowParallelLinear(  
  30.             args.n_heads * self.head_dim,  
  31.             args.dim,  
  32.             bias=False,  
  33.             input_is_parallel=True,  
  34.             init_method=lambda x: x,  
  35.         )  
  36.   
  37.         self.cache_k = torch.zeros(  
  38.             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
  39.         ).cuda()  
  40.         self.cache_v = torch.zeros(  
  41.             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
  42.         ).cuda()  
  43.   
  44.     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
  45.         bsz, seqlen, _ = x.shape  
  46.         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  
  47.   
  48.         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  49.         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  50.         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  51.   
  52.         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  
  53.   
  54.         self.cache_k = self.cache_k.to(xq)  
  55.         self.cache_v = self.cache_v.to(xq)  
  56.   
  57.         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  
  58.         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  
  59.   
  60.         keys = self.cache_k[:bsz, : start_pos + seqlen]  
  61.         values = self.cache_v[:bsz, : start_pos + seqlen]  
  62.   
  63.         xq = xq.transpose(1, 2)  
  64.         keys = keys.transpose(1, 2)  
  65.         values = values.transpose(1, 2)  
  66.         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  
  67.         if mask is not None:  
  68.             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)  
  69.         scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
  70.         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)  
  71.         output = output.transpose(  
  72.             1, 2  
  73.         ).contiguous().view(bsz, seqlen, -1)  
  74.   
  75.         return self.wo(output)
复制代码
接下来给出LLaMA实现的全部代码:
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.  # This software may be used and distributed according to the terms of the GNU General Public License version 3.    from typing import Optional, Tuple  from dataclasses import dataclass  import math    import torch  from torch import nn  import torch.nn.functional as F    import fairscale.nn.model_parallel.initialize as fs_init  from fairscale.nn.model_parallel.layers import (      ParallelEmbedding,      RowParallelLinear,      ColumnParallelLinear,  )      @dataclass  class ModelArgs:      dim: int = 512      n_layers: int = 8      n_heads: int = 8      vocab_size: int = -1  # defined later by tokenizer      multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2      norm_eps: float = 1e-5        max_batch_size: int = 32      max_seq_len: int = 2048      class RMSNorm(torch.nn.Module):  
  2.     def __init__(self, dim: int, eps: float = 1e-6):  
  3.         super().__init__()  
  4.         self.eps = eps  
  5.         self.weight = nn.Parameter(torch.ones(dim))  
  6.   
  7.     def _norm(self, x):  
  8.         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  
  9.   
  10.     def forward(self, x):  
  11.         output = self._norm(x.float()).type_as(x)  
  12.         return output * self.weight
  13.       def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
  14.     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
  15.     t = torch.arange(end, device=freqs.device)  # type: ignore  
  16.     freqs = torch.outer(t, freqs).float()  # type: ignore  
  17.     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
  18.     return freqs_cis  
  19.   
  20.   
  21. def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
  22.     ndim = x.ndim  
  23.     assert 0 <= 1 < ndim  
  24.     assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
  25.     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
  26.     return freqs_cis.view(*shape)  
  27.   
  28.   
  29. def apply_rotary_emb(  
  30.     xq: torch.Tensor,  
  31.     xk: torch.Tensor,  
  32.     freqs_cis: torch.Tensor,  
  33. ) -> Tuple[torch.Tensor, torch.Tensor]:  
  34.     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
  35.     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
  36.     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
  37.     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
  38.     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
  39.     return xq_out.type_as(xq), xk_out.type_as(xk)
  40.       class Attention(nn.Module):  
  41.     def __init__(self, args: ModelArgs):  
  42.         super().__init__()  
  43.   
  44.         self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()  
  45.         self.head_dim = args.dim // args.n_heads  
  46.   
  47.         self.wq = ColumnParallelLinear(  
  48.             args.dim,  
  49.             args.n_heads * self.head_dim,  
  50.             bias=False,  
  51.             gather_output=False,  
  52.             init_method=lambda x: x,  
  53.         )  
  54.         self.wk = ColumnParallelLinear(  
  55.             args.dim,  
  56.             args.n_heads * self.head_dim,  
  57.             bias=False,  
  58.             gather_output=False,  
  59.             init_method=lambda x: x,  
  60.         )  
  61.         self.wv = ColumnParallelLinear(  
  62.             args.dim,  
  63.             args.n_heads * self.head_dim,  
  64.             bias=False,  
  65.             gather_output=False,  
  66.             init_method=lambda x: x,  
  67.         )  
  68.         self.wo = RowParallelLinear(  
  69.             args.n_heads * self.head_dim,  
  70.             args.dim,  
  71.             bias=False,  
  72.             input_is_parallel=True,  
  73.             init_method=lambda x: x,  
  74.         )  
  75.   
  76.         self.cache_k = torch.zeros(  
  77.             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
  78.         ).cuda()  
  79.         self.cache_v = torch.zeros(  
  80.             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
  81.         ).cuda()  
  82.   
  83.     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
  84.         bsz, seqlen, _ = x.shape  
  85.         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  
  86.   
  87.         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  88.         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  89.         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  90.   
  91.         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  
  92.   
  93.         self.cache_k = self.cache_k.to(xq)  
  94.         self.cache_v = self.cache_v.to(xq)  
  95.   
  96.         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  
  97.         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  
  98.   
  99.         keys = self.cache_k[:bsz, : start_pos + seqlen]  
  100.         values = self.cache_v[:bsz, : start_pos + seqlen]  
  101.   
  102.         xq = xq.transpose(1, 2)  
  103.         keys = keys.transpose(1, 2)  
  104.         values = values.transpose(1, 2)  
  105.         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  
  106.         if mask is not None:  
  107.             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)  
  108.         scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
  109.         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)  
  110.         output = output.transpose(  
  111.             1, 2  
  112.         ).contiguous().view(bsz, seqlen, -1)  
  113.   
  114.         return self.wo(output)
  115.       class FeedForward(nn.Module):      def __init__(          self,          dim: int,          hidden_dim: int,          multiple_of: int,      ):          super().__init__()          hidden_dim = int(2 * hidden_dim / 3)          hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)            self.w1 = ColumnParallelLinear(              dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x          )          self.w2 = RowParallelLinear(              hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x          )          self.w3 = ColumnParallelLinear(              dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x          )        def forward(self, x):          return self.w2(F.silu(self.w1(x)) * self.w3(x))      class TransformerBlock(nn.Module):      def __init__(self, layer_id: int, args: ModelArgs):          super().__init__()          self.n_heads = args.n_heads          self.dim = args.dim          self.head_dim = args.dim // args.n_heads          self.attention = Attention(args)          self.feed_forward = FeedForward(              dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of          )          self.layer_id = layer_id          self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)          self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)        def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):          h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)          out = h + self.feed_forward.forward(self.ffn_norm(h))          return out      class Transformer(nn.Module):      def __init__(self, params: ModelArgs):          super().__init__()          self.params = params          self.vocab_size = params.vocab_size          self.n_layers = params.n_layers            self.tok_embeddings = ParallelEmbedding(              params.vocab_size, params.dim, init_method=lambda x: x          )            self.layers = torch.nn.ModuleList()          for layer_id in range(params.n_layers):              self.layers.append(TransformerBlock(layer_id, params))            self.norm = RMSNorm(params.dim, eps=params.norm_eps)          self.output = ColumnParallelLinear(              params.dim, params.vocab_size, bias=False, init_method=lambda x: x          )            self.freqs_cis = precompute_freqs_cis(              self.params.dim // self.params.n_heads, self.params.max_seq_len * 2          )        @torch.inference_mode()      def forward(self, tokens: torch.Tensor, start_pos: int):          _bsz, seqlen = tokens.shape          h = self.tok_embeddings(tokens)          self.freqs_cis = self.freqs_cis.to(h.device)          freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]            mask = None          if seqlen > 1:              mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)              mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)            for layer in self.layers:              h = layer(h, start_pos, freqs_cis, mask)          h = self.norm(h)          output = self.output(h[:, -1, :])  # only compute last logits          return output.float()
复制代码

CSDN独家福利

最后,感谢每一个认真阅读我文章的人,礼尚往来总是要有的,下面资料虽然不是什么很值钱的东西,如果你用得到的话可以直接拿走:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

×
登录参与点评抽奖,加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表