详解如何从零构建Llama 3(含代码)!

打印 上一主题 下一主题

主题 576|帖子 576|积分 1728

大家好,本文将详细指导如何从零开始构建完整的Llama 3模型架构,并在自定义数据集上执行训练和推理。

[图1]:Llama 3架构展示训练和推理流程。因为官方Llama 3论文中未提供相干图表。所以此图为大概架构图,阅读本文后你应能绘制出更为精确的架构图。
本文目标

通过本文。你可以相识到:

  • 深入理解Llama 3模型各组件的底层工作原理。
  • 编写代码构建Llama 3的每个组件,并将它们组装成一个功能完整的Llama 3模型。
  • 编写代码利用新的自定义数据集训练模型。
  • 编写代码执行推理,使Llama 3模型可以或许根据输入提示生成新文本。
1、输入模块

如图1所示,输入模块包罗三个组件:文本/提示、分词器和嵌入
输入模块内部工作流程
让我们通过下图相识输入模块内的工作流程。

[图2]:输入模块流程图,展示提示、分词器和嵌入流程。
起首,单个或批量文本/提示被输入模型。例如:图中的"Hello World"。
输入模型的必须是数字格式,因为模型无法直接处理文本。分词器将这些文本/提示转换为标记ID(词汇表中标记的索引号表现)。我们将利用Tiny Shakespeare数据集构建词汇表并训练模型。Llama 3模型利用TikToken作为分词器,这是一种子词分词器。但是我们这个实现将利用字符级分词器。这样做的主要缘故原由是让我们可以或许自行构建词汇表和分词器,包罗编码和解码函数,这样可以深入理解底层工作原理并完全掌控代码。
每个标记ID将被转换为128维的嵌入向量(原始Llama 3 8B中为4096维)。然后这些嵌入将被通报到下一个解码器模块。
输入模块代码实现:
  1. \# 导入必要的库   
  2. import torch   
  3. from torch import nn   
  4. from torch.nn import functional as F   
  5.    
  6. import math   
  7. import numpy as np   
  8. import time   
  9. from dataclasses import dataclass   
  10. from typing import Optional, Tuple, List   
  11. import pandas as pd   
  12. from matplotlib import pyplot as plt  
  13.    
  14. \### 步骤1: 输入模块 ###   
  15.    
  16. \# 使用Tiny Shakespeare数据集实现字符级分词器。部分字符级分词器代码参考自Andrej Karpathy的GitHub仓库  
  17. \# (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare\_char/prepare.py)  
  18. \# 加载tiny\_shakespeare数据文件 (https://github.com/tamangmilan/llama3/blob/main/tiny\_shakespeare.txt)   
  19.    
  20. device: str \= 'cuda' if torch.cuda.is\_available() else 'cpu'   \# 根据可用性分配设备为cuda或cpu   
  21.    
  22. \# 加载tiny\_shakespeare数据文件   
  23. with open('tiny\_shakespeare.txt', 'r') as f:   
  24.    data \= f.read()   
  25.    
  26. \# 通过提取tiny\_shakespeare数据中的所有唯一字符准备词汇表   
  27. vocab \= sorted(list(set(data)))   
  28.    
  29. \# 训练Llama 3模型需要额外的标记,如<|begin\_of\_text|>、<|end\_of\_text|>和<|pad\_id|>,将它们添加到词汇表中   
  30. vocab.extend(\['<|begin\_of\_text|>','<|end\_of\_text|>','<|pad\_id|>'\])   
  31. vocab\_size \= len(vocab)   
  32.    
  33. \# 创建字符与词汇表中对应整数索引之间的映射。  
  34. \# 这对于构建分词器的编码和解码函数至关重要。  
  35. itos \= {i:ch for i, ch in enumerate(vocab)}   
  36. stoi \= {ch:i for i, ch in enumerate(vocab)}   
  37.    
  38. \# 分词器编码函数:输入字符串,输出整数列表   
  39. def encode(s):   
  40.    return \[stoi\[ch\] for ch in s\]   
  41.    
  42. \# 分词器解码函数:输入整数列表,输出字符串   
  43. def decode(l):   
  44.    return ''.join(itos\[i\] for i in l)   
  45.    
  46. \# 定义稍后在模型训练中使用的张量标记变量   
  47. token\_bos \= torch.tensor(\[stoi\['<|begin\_of\_text|>'\]\], dtype\=torch.int, device\=device)   
  48. token\_eos \= torch.tensor(\[stoi\['<|end\_of\_text|>'\]\], dtype\=torch.int, device\=device)   
  49. token\_pad \= torch.tensor(\[stoi\['<|pad\_id|>'\]\], dtype\=torch.int, device\=device)   
  50.    
  51. prompts \= "Hello World"   
  52. encoded\_tokens \= encode(prompts)   
  53. decoded\_text \= decode(encoded\_tokens)   
  54.    
  55. \### 输入模块代码测试 ###   
  56. \# 取消下面的三重引号来执行测试   
  57. """   
  58. print(f"Shakespeare文本字符长度: {len(data)}")   
  59. print(f"词汇表内容: {''.join(vocab)}\\n")   
  60. print(f"词汇表大小: {vocab\_size}")   
  61. print(f"编码后的标记: {encoded\_tokens}")   
  62. print(f"解码后的文本: {decoded\_text}")   
  63. """   
  64. \### 测试结果: ###   
  65. """   
  66. Shakespeare文本字符长度: 1115394   
  67. 词汇表内容:     
  68.   !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz<|begin\_of\_text|><|end\_of\_text|><|pad\_id|>   
  69.    
  70. 词汇表大小: 68   
  71. 编码后的标记: \[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42\]   
  72. 解码后的文本: Hello World   
  73. """
复制代码
2、解码器模块

参照图1的架构图,解码器模块包罗以下子组件:


  • RMS归一化
  • 旋转位置编码
  • KV缓存
  • 分组查询注意力
  • 前馈网络
  • 解码器块
RMS归一化(Root Mean Square Normalization)

RMSNorm的须要性
从图1可以看出,输入模块的输出(嵌入向量)经过RMSNorm模块。这是因为嵌入向量具有多个维度(Llama3-8b中为4096维),大概出现差别范围的值。这会导致模型梯度爆炸或消失,从而导致收敛迟钝以致发散。而RMSNorm将这些值归一化到肯定范围,有助于稳定和加快训练过程。这使得梯度具有更一致的幅度,从而加快模型收敛。
RMSNorm的工作原理

[图3]:对外形为[3,3]的输入嵌入应用RMSNorm
类似于层归一化,RMSNorm沿嵌入特征或维度应用。上图中的嵌入外形为[3,3],意味着每个标记有3个维度。
示例:对第一个标记X1的嵌入应用RMSNorm:
X1标记在每个维度上的值(x11、x12和x13)分别除以所有这些值的均方根。公式如图3所示。
为制止除以零并包管数值稳定性,在均方根中参加一个小常数E(Epsilon)。乘以一个缩放参数Gamma (Y)。每个特征都有一个独特的Gamma参数(如图中d1维度的Y1、d2维度的Y2和d3维度的Y3),这是一个学习参数,可以向上或向下缩放以进一步稳定归一化。gamma参数初始化为1(如上面的计算所示)。
如示例所示,嵌入值本来较大且分布范围宽。应用RMSNorm后,值变小且范围缩小。计算利用现实的RMSNorm函数完成。
RMSNorm相比层归一化的优势
如上例所示没有计算任何均值或方差,而这在层归一化中是必须的。所以RMSNorm通过制止计算均值和方差减少了计算开销。根据作者的研究,RMSNorm在不影响准确性的同时提供了性能优势。
RMSNorm代码实现:
  1. \# 步骤2: 解码器模块   
  2. \# 注:由于Llama 3模型由Meta开发,为了与他们的代码库保持一致并考虑未来兼容性,  
  3. \# 我将使用Meta GitHub上的大部分代码,并进行必要的修改以实现我们的目标。  
  4.    
  5. \# 定义参数数据类:我们将在模型构建、训练和推理过程中使用这些参数。  
  6. \# 注:为了更快地看到训练和推理结果,而不是专注于高准确性,我们对大多数参数采用较低的值,  
  7. \# 这些值在Llama 3模型中设置得更高。  
  8.    
  9. @dataclass   
  10. class ModelArgs:   
  11.      dim: int \= 512              \# 嵌入维度   
  12.      n\_layers: int \= 8           \# 模型解码器块的数量   
  13.      n\_heads: int \= 8            \# 查询嵌入的头数   
  14.      n\_kv\_heads: int \= 4         \# 键和值嵌入的头数   
  15.      vocab\_size: int \= len(vocab) \# 词汇表长度   
  16.      multiple\_of: int \= 256        \# 用于计算前馈网络维度   
  17.      ffn\_dim\_multiplier: Optional\[float\] \= None  \# 用于计算前馈网络维度   
  18.      norm\_eps: float \= 1e-5                       \# RMSNorm计算的默认Epsilon值   
  19.      rope\_theta: float \= 10000.0   \# RePE计算的默认theta值   
  20.    
  21.      max\_batch\_size: int \= 10     \# 最大批量大小   
  22.      max\_seq\_len: int \= 256         \# 最大序列长度   
  23.    
  24.      epochs: int \= 2500             \# 总训练迭代次数   
  25.      log\_interval: int \= 10        \# 打印日志和损失值的间隔数      
  26.      device: str \= 'cuda' if torch.cuda.is\_available() else 'cpu'   \# 根据可用性分配设备为cuda或cpu   
  27.    
  28. \## 步骤2a: RMSNorm   
  29.    
  30. class RMSNorm(nn.Module):   
  31.    def \_\_init\_\_(self, dim: int, eps: float \= 1e-6):   
  32.      super().\_\_init\_\_()   
  33.      device \= ModelArgs.device   
  34.      self.eps \= eps   
  35.      \# 缩放参数gamma,初始化为1,参数数量等于dim的大小   
  36.      self.weight \= nn.Parameter(torch.ones(dim).to(device))   
  37.    
  38.    def \_norm(self, x):   
  39.      return x \* torch.rsqrt(x.pow(2).mean(dim\=-1, keepdim\=True) + self.eps).to(device)   
  40.    
  41.    def forward(self, x):   
  42.      #形状: x\[bs,seq,dim\]   
  43.      output \= self.\_norm(x.float()).type\_as(x)   
  44.    
  45.      #形状: x\[bs,seq,dim\] -> x\_norm\[bs,seq,dim\]   
  46.      return output \* self.weight   
  47.    
  48. \### RMSNorm代码测试 ###   
  49. \# 取消下面的三重引号来执行测试   
  50. """   
  51. x = torch.randn((ModelArgs.max\_batch\_size, ModelArgs.max\_seq\_len, ModelArgs.dim), device=device)   
  52. rms\_norm = RMSNorm(dim=ModelArgs.dim)   
  53. x\_norm = rms\_norm(x)   
  54.    
  55. print(f"x的形状: {x.shape}")   
  56. print(f"x\_norm的形状: {x\_norm.shape}")   
  57. """   
  58. \### 测试结果: ###   
  59. """   
  60. x的形状: torch.Size(\[10, 256, 512\])   
  61. x\_norm的形状: torch.Size(\[10, 256, 512\])   
  62. """
复制代码
旋转位置编码(Rotary Positional Encoding, RoPE)

回首之前的步调,我们已将输入文本转换为嵌入,并对嵌入应用了RMSNorm。然而,这里存在一个问题:假设输入文本是"I love apple"或"apple love I",模型会将两个句子视为相同并以相同方式学习。这是因为嵌入中没有为模型定义顺序信息。因此对于任何语言模型来说,保持标记的顺序至关重要。在Llama 3模型架构中,引入了旋转位置编码(RoPE)来定义句子中每个标记的位置,这不仅维护了顺序,还保存了句子中标记的相对位置信息。
旋转位置编码的工作原理
RoPE是一种位置编码方法,它通过添加绝对位置信息以及包罗标记之间的相对位置信息来编码嵌入,从而维护句子中标记的顺序。它通过利用一个特殊的旋转矩阵来旋转给定的嵌入来执行编码操作。这种利用旋转矩阵的简洁而强大的数学推导是RoPE的焦点。

[图4]:应用于2维向量的旋转矩阵
上图展示了旋转矩阵应用于2维向量的环境。Llama 3模型中的维度数是4096,远高于此。我们详细先容如何对更高维度的嵌入应用旋转。

[图5]:RoPE应用于嵌入的示例
嵌入的旋转涉及每个嵌入位置(m)值和theta (θ)对每对嵌入维度的乘法。这就是RoPE如何通过实现旋转矩阵来捕获绝对位置和相对位置信息的方式。
注意:在执行旋转之前,需要将旋转矩阵转换为极坐标情势,并将嵌入向量转换为复数。旋转完成后,旋转后的嵌入需要转换回实数以举行注意力操作。另外RoPE仅应用于查询和键嵌入,不实用于值嵌入。
RoPE的代码实现:
  1. \## 步骤2b: RoPE实现   
  2. def precompute\_freqs\_cis(dim:int, seq\_len: int, theta: float\=10000.0):   
  3.    \# 计算每对维度的Theta值,即dim/2   
  4.    device \= ModelArgs.device   
  5.    freqs \= 1.0 / (theta \*\* (torch.arange(0, dim, 2,device\=device)\[:(dim//2)\].float()/dim))   
  6.    
  7.    \# 计算序列中位置(m)的范围   
  8.    t \= torch.arange(seq\_len, dtype\=torch.float32, device\=device)   
  9.    
  10.    \# freqs给出序列中所有标记位置的Theta值范围   
  11.    freqs \= torch.outer(t, freqs).to(device)   
  12.    
  13.    \# 这是需要转换为极坐标形式的旋转矩阵,以便对嵌入执行旋转   
  14.    freqs\_cis \= torch.polar(torch.ones\_like(freqs).to(device), freqs).to(device)   
  15.    return freqs\_cis   
  16.    
  17. def reshape\_for\_broadcast(freqs\_cis, x):   
  18.    ndim \= x.ndim   
  19.    assert 0<=1<ndim   
  20.    assert freqs\_cis.shape \== (x.shape\[1\],x.shape\[\-1\]), "freqs\_cis的最后两个维度必须与x匹配"   
  21.    shape \= \[d if i\==1 or i\==ndim\-1 else 1 for i,d in enumerate(x.shape)\]   
  22.    return freqs\_cis.view(\*shape)   
  23.    
  24. def apply\_rotary\_emb(xq: torch.Tensor, xk: torch.Tensor, freqs\_cis: torch.Tensor)\->Tuple\[torch.Tensor, torch.Tensor\]:   
  25.    device \= ModelArgs.device   
  26.    \# 同时对查询和键嵌入应用旋转位置编码   
  27.    \# 首先:xq和xk嵌入的最后一个维度需要重塑为一对。因为旋转矩阵应用于每对维度。   
  28.    \# 其次:将xq和xk转换为复数,因为旋转矩阵只适用于复数   
  29.    xq\_ \= torch.view\_as\_complex(xq.float().reshape(\*xq.shape\[:\-1\], \-1, 2)).to(device)    #xq\_:\[bsz, seq\_len, n\_heads, head\_dim/2\]   
  30.    xk\_ \= torch.view\_as\_complex(xk.float().reshape(\*xk.shape\[:\-1\], \-1, 2)).to(device)    #xk\_:\[bsz, seq\_len, n\_heads, head\_dim/2\]   
  31.    
  32.    \# 旋转矩阵(freqs\_cis)在seq\_len(dim=1)和head\_dim(dim=3)维度上应与嵌入匹配   
  33.    \# 此外,freqs\_cis的形状应与xq和xk相同,因此将freqs\_cis的形状从\[seq\_len,head\_dim\]改变为\[1,seq\_len,1,head\_dim\]   
  34.    freqs\_cis \= reshape\_for\_broadcast(freqs\_cis, xq\_)   
  35.    
  36.    \# 最后,通过与freqs\_cis相乘执行旋转操作。   
  37.    \# 旋转完成后,将xq\_out和xk\_out转换回实数并返回   
  38.    xq\_out \= torch.view\_as\_real(xq\_ \* freqs\_cis).flatten(3).to(device) #xq\_out:\[bsz, seq\_len, n\_heads, head\_dim\]   
  39.    xk\_out \= torch.view\_as\_real(xk\_ \* freqs\_cis).flatten(3).to(device) #xk\_out:\[bsz, seq\_len, n\_heads, head\_dim\]   
  40.    return xq\_out.type\_as(xq), xk\_out.type\_as(xk)   
  41.    
  42. \### RoPE代码测试 ###   
  43. \# 注:x\_norm在RMSNorm测试中计算,这里用于测试。   
  44. \# 取消下面的三重引号来执行测试   
  45. """   
  46. head\_dim = ModelArgs.dim//ModelArgs.n\_heads   
  47. wq = nn.Linear(ModelArgs.dim, ModelArgs.n\_heads \* head\_dim, bias=False, device=device)   
  48. wk = nn.Linear(ModelArgs.dim, ModelArgs.n\_kv\_heads \* head\_dim, bias=False, device=device)   
  49. xq = wq(x\_norm)   
  50. xk = wk(x\_norm)   
  51. print(f"xq.shape: {xq.shape}")   
  52. print(f"xk.shape: {xk.shape}")   
  53.    
  54. xq = xq.view(xq.shape\[0\],xq.shape\[1\],ModelArgs.n\_heads, head\_dim)   
  55. xk = xk.view(xk.shape\[0\],xk.shape\[1\],ModelArgs.n\_kv\_heads, head\_dim)   
  56. print(f"xq.re-shape: {xq.shape}")   
  57. print(f"xk.re-shape: {xk.shape}")   
  58.    
  59. freqs\_cis = precompute\_freqs\_cis(dim=head\_dim, seq\_len=ModelArgs.max\_seq\_len)   
  60. print(f"freqs\_cis.shape: {freqs\_cis.shape}")   
  61.    
  62. xq\_rotate, xk\_rotate = apply\_rotary\_emb(xq, xk, freqs\_cis)   
  63. print(f"xq\_rotate.shape: {xq\_rotate.shape}")   
  64. print(f"xk\_rotate.shape: {xk\_rotate.shape}")   
  65. """   
  66. \### 测试结果: ###   
  67. """   
  68. xq.shape: torch.Size(\[10, 256, 512\])   
  69. xk.shape: torch.Size(\[10, 256, 256\])   
  70. xq.re-shape: torch.Size(\[10, 256, 8, 64\])   
  71. xk.re-shape: torch.Size(\[10, 256, 4, 64\])   
  72. freqs\_cis.shape: torch.Size(\[256, 32\])   
  73. xq\_rotate.shape: torch.Size(\[10, 256, 8, 64\])   
  74. xk\_rotate.shape: torch.Size(\[10, 256, 4, 64\])   
  75. """
复制代码
KV缓存(仅用于推理)

在Llama 3架构中,推理阶段引入了KV缓存的概念,用于以键和值缓存的情势存储先前生成的标记。这些缓存用于计算自注意力以生成下一个标记。只缓存键和值标记,而不缓存查询标记,因此称为KV缓存。
KV缓存的须要性
让我们通过下图来理解KV缓存的重要性。

[图6]:KV缓存实现
图中的A块:在生成output3标记时,仍在计算先前的输出标记(output1, output2),这是不须要的。这在注意力计算期间导致了额外的矩阵乘法,显著增长了计算资源的利用。
图中的B块:输出标记更换了查询嵌入中的输入标记。KV缓存存储了先前生成的标记。在注意力分数计算期间,我们只需要利用查询中的1个标记,并利用键和值缓存中的先前标记。这将矩阵乘法从A块的3x3减少到B块的1x3,减少了约66%。在现实应用中,对于巨大的序列长度和批量大小,这将显著减少计算资源的利用。
分组查询注意力

分组查询注意力与之前模型(如Llama 1)中利用的多头注意力相似,唯一的区别在于为查询和键/值利用单独的头。分配给查询的头数是键和值头数的n倍。让我们通过图表来进一步理解。

[图7]:分组查询注意力和多头注意力对比
在给定的图中,多头注意力在所有查询、键和值中都有相等数目标头,即n_heads = 8。
分组查询注意力块有8个查询头(n_heads)和4个键和值头(n_kv_heads),这是查询头数目标一半。
分组查询注意力的优势
尽管多头注意力已经表现出色,引入分组查询注意力是有其特定缘故原由。我们先回首KV缓存,KV缓存确实大大减少了计算资源的利用。但是随着KV缓存存储越来越多的先前标记,内存利用会显著增长。这对模型性能和计算成本都不利。**所以引入了分组查询注意力。**减少K和V的头数会减少需要存储的参数数目,从而减少内存利用。多项测试结果表明,利用这种方法模型的准确性仍保持在相近的范围内。
注意力模块的代码实现:
  1. \## 注意力模块 \[步骤2c: KV缓存; 步骤2d: 分组查询注意力\]   
  2. \## 如前所述,命名约定遵循原始Meta LLama3 GitHub   
  3.    
  4. class Attention(nn.Module):   
  5.    def \_\_init\_\_(self, args: ModelArgs):   
  6.      super().\_\_init\_\_()   
  7.      self.args \= args   
  8.      \# 嵌入维度   
  9.      self.dim \= args.dim   
  10.      \# 分配给查询的头数   
  11.      self.n\_heads \= args.n\_heads   
  12.      \# 分配给键和值的头数。如果为"None",则数量与查询相同。   
  13.      self.n\_kv\_heads \= args.n\_heads if args.n\_kv\_heads is None else args.n\_kv\_heads   
  14.      \# 每个头相对于模型维度的维度   
  15.      self.head\_dim \= args.dim // args.n\_heads   
  16.      \# 重复次数,以使键、值头数与查询头数匹配   
  17.      self.n\_rep \= args.n\_heads // args.n\_kv\_heads   
  18.    
  19.      \# 初始化键、查询、值和输出的权重。注意q和kv的权重out\_feature值基于其头数   
  20.      self.wq \= nn.Linear(self.dim, self.n\_heads \* self.head\_dim, bias\=False, device\=device)   
  21.      self.wk \= nn.Linear(self.dim, self.n\_kv\_heads \* self.head\_dim, bias\=False, device\=device)   
  22.      self.wv \= nn.Linear(self.dim, self.n\_kv\_heads \* self.head\_dim, bias\=False, device\=device)   
  23.      self.wo \= nn.Linear(self.n\_heads \* self.head\_dim, self.dim, bias\=False, device\=device)   
  24.    
  25.      \# 初始化缓存以在开始时存储键、值 (KV缓存实现)   
  26.      self.cache\_k \= torch.zeros((args.max\_batch\_size, args.max\_seq\_len, self.n\_kv\_heads, self.head\_dim), device\=args.device)   
  27.      self.cache\_v \= torch.zeros((args.max\_batch\_size, args.max\_seq\_len, self.n\_kv\_heads, self.head\_dim), device\=args.device)   
  28.    
  29.    def forward(self, x: torch.Tensor, start\_pos, inference):   
  30.      \# 输入嵌入的形状: \[bsz,seq\_len,dim\]   
  31.      bsz, seq\_len, \_ \= x.shape   
  32.      \# 掩码将在"训练"期间使用,由于使用KV缓存,"推理"不需要掩码。  
  33.      mask \= None   
  34.    
  35.      xq \= self.wq(x)  #x\[bsz,seq\_len,dim\]\*wq\[dim,n\_heads \* head\_dim\] -> q\[bsz,seq\_len,n\_heads \* head\_dim\]   
  36.      xk \= self.wk(x)  #x\[bsz,seq\_len,dim\]\*wq\[dim,n\_kv\_heads \* head\_dim\] -> k\[bsz,seq\_len,n\_kv\_heads \* head\_dim\]   
  37.      xv \= self.wv(x)  #x\[bsz,seq\_len,dim\]\*wq\[dim,n\_kv\_heads \* head\_dim\] -> v\[bsz,seq\_len,n\_kv\_heads \* head\_dim\]   
  38.    
  39.      \# 根据头数重塑查询、键和值 (分组查询注意力实现)   
  40.      xq \= xq.view(bsz, seq\_len, self.n\_heads, self.head\_dim)      #xq\[bsz,seq\_len,n\_heads, head\_dim\]   
  41.      xk \= xk.view(bsz, seq\_len, self.n\_kv\_heads, self.head\_dim)   #xk\[bsz,seq\_len,n\_kv\_heads, head\_dim\]   
  42.      xv \= xv.view(bsz, seq\_len, self.n\_kv\_heads, self.head\_dim)   #xv\[bsz,seq\_len,n\_kv\_heads, head\_dim\]   
  43.    
  44.      \# 模型 - 推理模式: kv-cache仅在推理模式下启用   
  45.      if inference:   
  46.        \# 计算序列中每个位置的旋转矩阵   
  47.        freqs\_cis \= precompute\_freqs\_cis(dim\=self.head\_dim, seq\_len\=self.args.max\_seq\_len \* 2)   
  48.        \# 在推理过程中,我们应该只取从当前标记位置开始的旋转矩阵范围   
  49.        freqs\_cis \= freqs\_cis\[start\_pos : start\_pos + seq\_len\]   
  50.        \# 将RoPE应用于查询和键嵌入   
  51.        xq, xk \= apply\_rotary\_emb(xq, xk, freqs\_cis)   
  52.    
  53.        self.cache\_k \= self.cache\_k.to(xq)   
  54.        self.cache\_v \= self.cache\_v.to(xq)   
  55.        \# 将键和值标记嵌入存储到它们各自的缓存中 \[KV缓存实现\]   
  56.        self.cache\_k\[:bsz, start\_pos:start\_pos + seq\_len\] \= xk   
  57.        self.cache\_v\[:bsz, start\_pos:start\_pos + seq\_len\] \= xv   
  58.    
  59.        \# 为注意力计算分配所有直到当前标记位置的先前标记嵌入给键和值变量   
  60.        keys \= self.cache\_k\[:bsz, :start\_pos + seq\_len\]   
  61.        values \= self.cache\_v\[:bsz, :start\_pos + seq\_len\]   
  62.    
  63.        \# 此时,键和值的形状与查询嵌入不同,但为了计算注意力分数,它们必须相同   
  64.        \# 使用repeat\_kv函数使键、值的形状与查询形状相同   
  65.        keys \= repeat\_kv(keys, self.n\_rep)      #keys\[bsz,seq\_len,n\_heads,head\_dim\]   
  66.        values \= repeat\_kv(values, self.n\_rep)  #values\[bsz,seq\_len,n\_heads,head\_dim\]   
  67.    
  68.      \# 模式 - 训练模式: 未实现KV-Cache   
  69.      else:   
  70.        \# 计算旋转矩阵并将RoPE应用于训练的查询和键   
  71.        freqs\_cis \= precompute\_freqs\_cis(dim\=self.head\_dim, seq\_len\=self.args.max\_seq\_len)   
  72.    
  73.        #xq\[bsz,seq\_len,n\_heads, head\_dim\], xk\[bsz,seq\_len,n\_heads, head\_dim\]   
  74.        xq, xk \= apply\_rotary\_emb(xq, xk, freqs\_cis)   
  75.    
  76.        \# 使用repeat\_kv函数使键、值的形状与查询形状相同   
  77.        #keys\[bsz,seq\_len,n\_heads,head\_dim\], #values\[bsz,seq\_len,n\_heads,head\_dim\]   
  78.        keys \= repeat\_kv(xk, self.n\_rep)   
  79.        values \= repeat\_kv(xv, self.n\_rep)   
  80.    
  81.        \# 对于训练模式,我们将计算掩码并稍后应用于注意力分数   
  82.        mask \= torch.full((seq\_len, seq\_len),float("-inf"),device\=self.args.device)   
  83.        mask \= torch.triu(mask, diagonal\=1).to(self.args.device)   
  84.    
  85.      \# 为了计算注意力,我们需要执行转置操作来重塑所有查询、键和值,将头部放在维度1,序列放在维度2   
  86.      xq \= xq.transpose(1,2)                  #xq\[bsz,n\_heads,seq\_len,head\_dim\]   
  87.      keys \= keys.transpose(1,2)              #keys\[bsz,n\_heads,seq\_len,head\_dim\]   
  88.      values \= values.transpose(1,2)          #values\[bsz,n\_heads,seq\_len,head\_dim\]   
  89.    
  90.      \# 计算注意力分数   
  91.      scores \= torch.matmul(xq, keys.transpose(2,3)).to(self.args.device)/math.sqrt(self.head\_dim)   
  92.      if mask is not None:   
  93.        scores \= scores + mask   
  94.    
  95.      \# 对注意力分数应用softmax   
  96.      scores \= F.softmax(scores.float(), dim\=-1).type\_as(xq)   
  97.      \# 注意力分数与值的矩阵乘法   
  98.      output \= torch.matmul(scores, values).to(self.args.device)   
  99.    
  100.      \# 我们得到了每个头部的上下文嵌入   
  101.      \# 所有头部需要重塑回来并组合,以给出单个上下文注意力输出   
  102.      \# 形状变化: output\[bsz,n\_heads,seq\_len,head\_dim\] -> output\[bsz,seq\_len, n\_heads,head\_dim\] -> output\[bsz,seq\_len, n\_heads \* head\_dim\]   
  103.      output \= output.transpose(1,2).contiguous().view(bsz, seq\_len, \-1)   
  104.    
  105.      \# 形状: output \[bsz,seq\_len,dim\]   
  106.      return self.wo(output)   
  107.    
  108. \# 如果键/值头的数量少于查询头,此函数使用所需的重复次数扩展键/值嵌入   
  109. def repeat\_kv(x:torch.Tensor, n\_rep: int)\-> torch.Tensor:   
  110.    bsz, seq\_len, n\_kv\_heads, head\_dim \= x.shape   
  111.    if n\_rep \== 1:   
  112.      return x   
  113.    return (   
  114.        x\[:,:,:,None,:\]   
  115.        .expand(bsz,seq\_len,n\_kv\_heads,n\_rep, head\_dim)   
  116.        .reshape(bsz,seq\_len,n\_kv\_heads \* n\_rep, head\_dim)   
  117.    )   
  118.    
  119. \### 测试: Repeat\_kv函数 ###   
  120. \# 注: xk, x\_norm已在RoPE, RMSNorm测试中计算,这里用于测试   
  121. \# 取消下面的三重引号来执行测试   
  122. """   
  123. n\_rep = ModelArgs.n\_heads // ModelArgs.n\_kv\_heads   
  124. keys = repeat\_kv(xk, n\_rep)   
  125. print(f"xk.shape: {xk.shape}")   
  126. print(f"keys.shape: {keys.shape}")   
  127.    
  128. \## 测试: Attention函数   
  129. \# 取消下面的三重引号来执行测试   
  130.    
  131. attention = Attention(ModelArgs)   
  132. x\_out = attention(x\_norm,start\_pos=0, inference=False)   
  133. print(f"x\_out.shape: {x\_out.shape}")   
  134. """   
  135. \### 测试结果: ###   
  136. """   
  137. xk.shape: torch.Size(\[10, 256, 4, 64\])   
  138. keys.shape: torch.Size(\[10, 256, 8, 64\])   
  139. x\_out.shape: torch.Size(\[10, 256, 512\])   
  140. """
复制代码
前馈网络 (利用SwiGLU激活函数)

如图1所示,注意力输出起首经过RMSNorm,然后输入前馈网络。在前馈网络中,注意力输出嵌入会在其潜伏层中扩展到更高维度,学习标记的更复杂特征。
为什么选择SwiGLU而非ReLU

[图8]:带有SwiGLU函数的前馈网络
如图所示,SwiGLU函数在正轴上的行为与ReLU相似。然而,在负轴上,SwiGLU输出一些负值,这在学习较小值时大概有用,而不是像ReLU那样在负轴上为平坦的0。根据作者的研究,利用SwiGLU的性能优于ReLU,因此被选用。
前馈网络的代码实现:
  1. \## 步骤2e: 前馈网络 (SwiGLU激活)   
  2. class FeedForward(nn.Module):   
  3.    def \_\_init\_\_(self, dim:int, hidden\_dim:int, multiple\_of:int, ffn\_dim\_multiplier: Optional\[float\]):   
  4.      super().\_\_init\_\_()   
  5.      \# 模型嵌入维度   
  6.      self.dim \= dim   
  7.    
  8.      \# 我们必须使用Meta提供的隐藏维度计算方法,这是该模型的理想设置   
  9.      \# 隐藏维度的计算方式使其是256的倍数   
  10.      hidden\_dim \= int(2 \* hidden\_dim/3)   
  11.      if ffn\_dim\_multiplier is not None:   
  12.        hidden\_dim \= int(ffn\_dim\_multiplier \* hidden\_dim)   
  13.      hidden\_dim \= multiple\_of \* ((hidden\_dim + multiple\_of \- 1) // multiple\_of)   
  14.    
  15.      \# 定义隐藏层权重   
  16.      self.w1 \= nn.Linear(self.dim, hidden\_dim, bias\=False, device\=device)   
  17.      self.w2 \= nn.Linear(hidden\_dim, self.dim, bias\=False, device\=device)   
  18.      self.w3 \= nn.Linear(self.dim, hidden\_dim, bias\=False, device\=device)   
  19.    
  20.    def forward(self, x):   
  21.      \# 形状: \[bsz,seq\_len,dim\]   
  22.      return self.w2(F.silu(self.w1(x)) \* self.w3(x))   
  23.    
  24. \### 测试: 前馈模块 ###   
  25. \# 注: x\_out已在Attention测试中计算,这里用于测试   
  26. \# 取消下面的三重引号来执行测试   
  27. """   
  28. feed\_forward = FeedForward(ModelArgs.dim, 4 \* ModelArgs.dim, ModelArgs.multiple\_of, ModelArgs.ffn\_dim\_multiplier)   
  29. x\_out = rms\_norm(x\_out)   
  30. x\_out = feed\_forward(x\_out)   
  31. print(f"前馈输出: x\_out.shape: {x\_out.shape}")   
  32. """   
  33.    
  34. \### 测试结果: ###   
  35. """   
  36. 前馈输出: x\_out.shape: torch.Size(\[10, 256, 512\])   
  37. """
复制代码
解码器块

如图1所示,解码器块由多个子组件组成,我们在前面的部分中已经实现了这些组件。以下是解码器块内举行的徐徐操作:
1、来自输入模块的嵌入起首经过注意力-RMSNorm,然后输入分组查询注意力模块。
2、同时,来自输入模块的原始嵌入与注意力输出相加。
3、然后,这个结果经过前馈-RMSNorm,输入前馈网络模块。
4、前馈网络的输出再次与步调2的结果相加。
5、终极输出被称为解码器输出。这个解码器输出然后作为输入通报给下一个解码器块。这个过程在接下来的31个解码器块中重复。第32个解码器块的终极输出然后通报到输出模块。
解码器块的代码实现:
  1. \## 步骤2f: 解码器块。类名为TransformerBlock,以匹配Meta Llama 3代码库   
  2.    
  3. class TransformerBlock(nn.Module):   
  4.    def \_\_init\_\_(self, args: ModelArgs):   
  5.      super().\_\_init\_\_()   
  6.      self.args \= args   
  7.      \# 初始化注意力的RMSNorm   
  8.      self.attention\_norm \= RMSNorm(dim\=args.dim, eps \= args.norm\_eps)   
  9.      \# 初始化注意力类   
  10.      self.attention \= Attention(args)   
  11.      \# 初始化前馈网络的RMSNorm   
  12.      self.ff\_norm \= RMSNorm(dim\=args.dim, eps \= args.norm\_eps)   
  13.      \# 初始化前馈网络类   
  14.      self.feedforward \= FeedForward(args.dim, 4 \* args.dim, args.multiple\_of, args.ffn\_dim\_multiplier)   
  15.    
  16.    def forward(self, x, start\_pos, inference):   
  17.      \# start\_pos: 推理模式下的标记位置, inference: True表示推理模式,False表示训练模式   
  18.      \# 1) 将输入嵌入传递给attention\_norm,然后传递给注意力模块   
  19.      \# 2) 注意力的输出与原始输入(归一化前)相加   
  20.      h \= x + self.attention(self.attention\_norm(x), start\_pos,inference)   
  21.    
  22.      \# 1) 将注意力输出传递给ff\_norm,然后传递给前馈网络   
  23.      \# 2) 前馈网络的输出与注意力输出(ff\_norm前)相加   
  24.      out \= h + self.feedforward(self.ff\_norm(h))   
  25.      \# 形状: \[bsz,seq\_len,dim\]   
  26.      return out   
  27.    
  28. \### 测试: TransformerBlock ###   
  29. \# 取消下面的三重引号来执行测试   
  30. """   
  31. x = torch.randn((ModelArgs.max\_batch\_size, ModelArgs.max\_seq\_len, ModelArgs.dim), device=device)   
  32. transformer\_block = TransformerBlock(ModelArgs)   
  33. transformer\_block\_out = transformer\_block(x,start\_pos=0, inference=False)   
  34. print(f"transformer\_block\_out.shape: {transformer\_block\_out.shape}")   
  35. """   
  36.    
  37. \### 测试结果: ###   
  38. """   
  39. transformer\_block\_out.shape: torch.Size(\[10, 64, 128\])   
  40. """
复制代码
3、输出模块

最后一个解码器块的输出将传入输出模块。它起首经过RMSNorm处理,然后传入线性层生成logits。接下来根据模式的差别,会执行以下两种操作之一:
如果是推理模式,计算top_p概率并生成下一个标记。如果到达最大生发展度或生成的下一个标记为句子结束标记,则停止生成。
如果是训练模式,利用目标标签计算丧失,并重复训练直到到达最大epoch数。
下图展示了输出模块的流程:

[图9]:Llama 3在训练和推理模式下的输出流程图
终极的Llama 3模型实现
我们将组合三个模块(输入模块、解码器模块和输出模块)的所有组件。这就构成了我们的完整Llama 3模型。
  1. \## 步骤3: 输出模块   
  2. \# 这是Llama 3模型。类名保持为Transformer以匹配Meta Llama 3模型   
  3.    
  4. class Transformer(nn.Module):   
  5.    def \_\_init\_\_(self, params: ModelArgs):   
  6.      super().\_\_init\_\_()   
  7.      \# 设置params变量中的所有ModelArgs   
  8.      self.params \= params   
  9.      \# 从输入模块初始化嵌入类   
  10.      self.tok\_embeddings \= nn.Embedding(params.vocab\_size, params.dim)   
  11.    
  12.      \# 初始化解码器块并将其存储在ModuleList中     
  13.      \# 这是因为我们的Llama 3模型中有4个解码器块 (官方Llama 3有32个块)   
  14.      self.layers \= nn.ModuleList()   
  15.      for layer\_id in range(params.n\_layers):   
  16.        self.layers.append(TransformerBlock(args\=params))   
  17.    
  18.      \# 为输出模块初始化RMSNorm   
  19.      self.norm \= RMSNorm(params.dim, eps \= params.norm\_eps)   
  20.          
  21.      \# 在输出模块初始化线性层   
  22.      self.output \= nn.Linear(params.dim, params.vocab\_size, bias\=False)   
  23.    
  24.    def forward(self, x, start\_pos\=0, targets\=None):   
  25.          
  26.      \# start\_pos: 推理模式的标记位置, inference: True表示推理模式, False表示训练模式   
  27.      \# x是使用分词器从文本或提示生成的标记ID批次   
  28.      \# x\[bsz, seq\_len\] -> h\[bsz, seq\_len, dim\]   
  29.      h \= self.tok\_embeddings(x)   
  30.    
  31.      \# 如果目标为None,则激活推理模式并设置为"True",否则为训练模式"False"   
  32.      inference \= targets is None   
  33.    
  34.      \# 嵌入(h)然后将通过所有解码器块   
  35.      for layer in self.layers:   
  36.        h \= layer(h, start\_pos, inference)   
  37.    
  38.      \# 最后解码器块的输出将馈入RMSNorm   
  39.      h \= self.norm(h)   
  40.    
  41.      \# 归一化后,嵌入h将馈入线性层     
  42.      \# 线性层的主要任务是生成将嵌入映射到词汇表大小的logits   
  43.      \# h\[bsz, seq\_len, dim\] -> logits\[bsz, seq\_len, vocab\_size\]   
  44.      logits \= self.output(h).float()   
  45.      loss \= None   
  46.    
  47.      \# 如果目标不可用,则为推理模式   
  48.      if targets is None:   
  49.        loss \= None   
  50.      \# 如果目标可用,则为训练模式。计算损失以进行进一步的模型训练     
  51.      else:   
  52.        loss \= F.cross\_entropy(logits.view(\-1, self.params.vocab\_size), targets.view(\-1))   
  53.    
  54.      return logits, loss   
  55.    
  56. \### 测试: Transformer (Llama模型) ###   
  57. \# 取消下面的三重引号来执行测试   
  58. """   
  59. model = Transformer(ModelArgs).to(ModelArgs.device)   
  60. print(model)   
  61. """
复制代码

[图10]: Llama 3分层架构
我们刚刚构建的Llama 3模型结构看起来很完整。现在我们可以开始训练过程了。
4、训练Llama 3模型

训练流程在输出模块流程图(图9)中已经展示。在开始训练之前,让我们先实现训练代码。以下代码块中包罗了须要的解释。
  1. \## 步骤4: 训练Llama 3模型:   
  2.    
  3. \# 使用我们在输入模块部分构建的分词器的encode函数,通过对整个tiny\_shakespeare数据进行编码来创建数据集   
  4. dataset \= torch.tensor(encode(data), dtype\=torch.int).to(ModelArgs.device)   
  5. print(f"dataset-shape: {dataset.shape}")   
  6.    
  7. \# 定义函数从给定数据集生成批次   
  8. def get\_dataset\_batch(data, split, args:ModelArgs):   
  9.    seq\_len \= args.max\_seq\_len   
  10.    batch\_size \= args.max\_batch\_size   
  11.    device \= args.device   
  12.    
  13.    train \= data\[:int(0.8 \* len(data))\]   
  14.    val \= data\[int(0.8 \* len(data)): int(0.9 \* len(data))\]   
  15.    test \= data\[int(0.9 \* len(data)):\]   
  16.    
  17.    batch\_data \= train   
  18.    if split \== "val":   
  19.      batch\_data \= val   
  20.    elif split \== "test":   
  21.      batch\_data \= test   
  22.       
  23.    \# 从数据集中选择随机起点,为训练、验证和测试提供随机样本   
  24.    ix \= torch.randint(0, len(batch\_data) \- seq\_len \- 3, (batch\_size,)).to(device)   
  25.    x \= torch.stack(\[torch.cat(\[token\_bos, batch\_data\[i:i+seq\_len\-1\]\]) for i in ix\]).long().to(device)   
  26.    y \= torch.stack(\[torch.cat(\[batch\_data\[i+1:i+seq\_len\], token\_eos\]) for i in ix\]).long().to(device)   
  27.       
  28.    return x, y   
  29.    
  30. \### 测试: get\_dataset函数 ###   
  31. """   
  32. xs, ys = get\_dataset\_batch(dataset, split="train", args=ModelArgs)   
  33. print(\[(decode(xs\[i\].tolist()), decode(ys\[i\].tolist())) for i in range(len(xs))\])   
  34. """   
  35.    
  36. \# 定义evaluate\_loss函数来计算和存储训练和验证损失,用于日志记录和绘图   
  37. @torch.no\_grad()   
  38. def evaluate\_loss(model, args:ModelArgs):   
  39.    out \= {}   
  40.    model.eval()   
  41.    
  42.    for split in \["train", "val"\]:   
  43.      losses \= \[\]   
  44.      for \_ in range(10):         
  45.        xb, yb \= get\_dataset\_batch(dataset, split, args)   
  46.        \_, loss \= model(x\=xb, targets\=yb)   
  47.        losses.append(loss.item())   
  48.      out\[split\] \= np.mean(losses)   
  49.    
  50.    model.train()   
  51.    return out   
  52.    
  53. \# 定义训练函数来执行模型训练   
  54. def train(model, optimizer, args:ModelArgs):   
  55.      epochs \= args.epochs   
  56.      log\_interval \= args.log\_interval   
  57.      device \= args.device   
  58.      losses \= \[\]      
  59.      start\_time \= time.time()   
  60.    
  61.      for epoch in range(epochs):   
  62.          optimizer.zero\_grad()   
  63.             
  64.          xs, ys \= get\_dataset\_batch(dataset, 'train', args)   
  65.          xs \= xs.to(device)   
  66.          ys \= ys.to(device)   
  67.          logits, loss \= model(x\=xs, targets\=ys)   
  68.          loss.backward()   
  69.          optimizer.step()   
  70.    
  71.          if epoch % log\_interval \== 0:   
  72.              batch\_time \= time.time() \- start\_time   
  73.              x \= evaluate\_loss(model, args)   
  74.              losses.append(x)               
  75.              print(f"Epoch {epoch} | val loss {x\['val'\]:.3f} | Time {batch\_time:.3f}")   
  76.              start\_time \= time.time()   
  77.          
  78.      \# 打印最终验证损失   
  79.      print("验证损失: ", losses\[\-1\]\['val'\])   
  80.      \# 在图表中显示间隔损失     
  81.      return pd.DataFrame(losses).plot()
复制代码
定义完训练函数。就可以开始训练过程,并在训练完成后观察结果。
  1. \## 开始训练我们的Llama 3模型   
  2. model \= Transformer(ModelArgs).to(ModelArgs.device)   
  3. optimizer \= torch.optim.Adam(model.parameters())   
  4.    
  5. train(model, optimizer, ModelArgs)
复制代码

[图11] 训练与验证丧失图
上图表现了训练和验证丧失的变革。训练举行了2500个epoch。利用Google Colab的默认GPU和RAM设置,整个训练过程约莫花费了10分钟,这是相当快速的。最后一个epoch的验证丧失为2.19,考虑到我们利用的训练数据量和epoch数目,这个结果是可以担当的。要显著低落丧失,我们还需要增长训练数据的规模、提高epoch数目,并利用更强大的GPU或处理能力。
5、Llama 3模型推理

推理流程在输出模块流程图(图9)中已经展示。让我们实现推理代码。
  1. \## 步骤5: Llama 3模型推理   
  2. \# 这个函数使用我们构建和训练的Llama 3模型,基于提供的提示生成文本序列   
  3.    
  4. def generate(model, prompts: str, params: ModelArgs, max\_gen\_len: int\=500, temperature: float \= 0.6, top\_p: float \= 0.9):   
  5.    
  6.      \# prompt\_tokens: 用户输入文本或提示列表   
  7.      \# max\_gen\_len: 生成文本序列的最大长度   
  8.      \# temperature: 用于控制采样随机性的温度值。默认为0.6   
  9.      \# top\_p: 从logits采样prob输出的top-p概率阈值。默认为0.9   
  10.      bsz \= 1  \# 对于推理,通常用户只输入一个提示,我们将其作为1个批次   
  11.      prompt\_tokens \= token\_bos.tolist() + encode(prompts)   
  12.      assert len(prompt\_tokens) <= params.max\_seq\_len, "提示标记长度应小于max\_seq\_len"   
  13.      total\_len \= min(len(prompt\_tokens)+max\_gen\_len, params.max\_seq\_len)      
  14.    
  15.      \# 这个tokens矩阵用于存储输入提示和模型生成的所有输出   
  16.      \# 稍后我们将使用分词器的decode函数来解码这个token,以文本格式查看结果   
  17.      tokens \= torch.full((bsz,total\_len), fill\_value\=token\_pad.item(), dtype\=torch.long, device\=params.device)   
  18.    
  19.      \# 将提示tokens填入token矩阵   
  20.      tokens\[:,:len(prompt\_tokens)\] \= torch.tensor(prompt\_tokens, dtype\=torch.long, device\=params.device)   
  21.    
  22.      \# 创建一个prompt\_mask\_token,用于稍后识别token是提示token还是填充token   
  23.      \# 如果是提示token则为True,如果是填充token则为False   
  24.      input\_text\_mask \= tokens != token\_pad.item()   
  25.    
  26.      \# 现在我们可以从第一个位置开始,一次使用一个token从prompt\_tokens列表开始推理   
  27.      prev\_pos \= 0   
  28.      for cur\_pos in range(1, total\_len):   
  29.        with torch.no\_grad():   
  30.          logits, \_ \= model(x\=tokens\[:,prev\_pos:cur\_pos\], start\_pos\=prev\_pos)   
  31.        if temperature \> 0:         
  32.          probs \= torch.softmax(logits\[:, \-1\]/temperature, dim\=-1)   
  33.          next\_token \= sample\_top\_p(probs, top\_p)            
  34.        else:   
  35.          next\_token \= torch.argmax(logits\[:, \-1\], dim\=-1)            
  36.    
  37.        next\_token \= next\_token.reshape(\-1)   
  38.    
  39.        \# 只有在是填充token时才替换token   
  40.        next\_token \= torch.where(input\_text\_mask\[:, cur\_pos\], tokens\[:, cur\_pos\], next\_token)   
  41.        tokens\[:, cur\_pos\] \= next\_token   
  42.    
  43.        prev\_pos \= cur\_pos   
  44.        if tokens\[:,cur\_pos\]\==token\_pad.item() and next\_token \== token\_eos.item():   
  45.          break   
  46.    
  47.      output\_tokens, output\_texts \= \[\], \[\]        
  48.    
  49.      for i, toks in enumerate(tokens.tolist()):   
  50.        if token\_eos.item() in toks:   
  51.          eos\_idx \= toks.index(token\_eos.item())   
  52.          toks \= toks\[:eos\_idx\]   
  53.    
  54.        output\_tokens.append(toks)   
  55.        output\_texts.append(decode(toks))   
  56.      return output\_tokens, output\_texts   
  57.    
  58. \# 对概率分布执行top-p (nucleus) 采样   
  59. \# probs (torch.Tensor): 由logits导出的概率分布张量   
  60. \# p: top-p采样的概率阈值   
  61. \# 根据相关研究,Top-p采样选择累积概率质量超过阈值p的最小标记集     
  62. \# 基于选定的标记重新归一化分布   
  63. def sample\_top\_p(probs, p):   
  64.      probs\_sort, prob\_idx \= torch.sort(probs, dim\=-1, descending\=True)   
  65.      probs\_sum \= torch.cumsum(probs\_sort, dim\=-1)   
  66.      mask \= probs\_sum \- probs\_sort \> p   
  67.      probs\_sort\[mask\] \= 0.0   
  68.      probs\_sort.div\_(probs\_sort.sum(dim\=-1, keepdim\=True))   
  69.      next\_token \= torch.multinomial(probs\_sort, num\_samples\=1)   
  70.      next\_token \= torch.gather(prob\_idx, \-1, next\_token)        
  71.      \# 返回从词汇表中采样的标记索引     
  72.      return next\_token
复制代码
对新的提示执行推理,并检查生成的输出:
  1. \## 对用户输入的提示执行推理   
  2. prompts \= "Consider you what services he has done"   
  3. output\_tokens, output\_texts \= generate(model, prompts, ModelArgs)   
  4. output\_texts \= output\_texts\[0\].replace("<|begin\_of\_text|>", "")   
  5. print(output\_texts)   
  6.    
  7. \## 输出 ##   
  8. """   
  9. Consider you what services he has done o eretrane   
  10. adetranytnn i eey i ade hs rcuh i eey,ad hsatsTns rpae,T   
  11. eon o i hseflns o i eee ee hs ote i ocal ersl,Bnnlnface   
  12. o i hmr a il nwye ademto nt i a ere   
  13. h i ees.   
  14. Frm oe o etrane o oregae,alh,t orede i oeral   
  15. """
复制代码
从结果可以看出,我们的Llama 3模型可以或许对新的提示执行推理并生成文本。虽然考虑到我们利用的训练数据量和训练轮数,输出质量并不是很高,但这证明了模型的根本功能是正常的。通过利用更大规模的训练数据和更多的训练轮数,我们将可以或许得到更高质量的输出。
总结

我们已经成功地从零开始构建了本身的Llama 3模型。我们不仅实现了模型的架构,还成功地举行了训练,并可以或许执行推理以生成新的文本。值得注意的是,我们在相对有限的计算资源(Google Colab Notebook提供的免费GPU和RAM)下,在较短的时间内完成了这个过程。
本文中的代码和方法主要用于教导和研究目标。在现实应用中,大概需要举行更多的优化和调整,以到达生产级别的性能和结果。
如何学习AI大模型 ?

“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。

这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少偕行后辈。帮助许多人得到了学习和发展。
我意识到有许多经验和知识值得分享给大家,故此将并将重要的AI大模型资料包罗AI大模型入门学习头脑导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。【包管100%免费】

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

梦应逍遥

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表