【大模子】微调实战—利用 ORPO 微调 Llama 3

打印 上一主题 下一主题

主题 532|帖子 532|积分 1596

ORPO 是一种新颖微调(fine-tuning)技能,它将传统的监督微调(supervised fine-tuning)和偏好对齐(preference alignment)阶段归并为一个过程。这减少了练习所需的盘算资源和时间。此外,实证效果表明,ORPO 在各种模子规模和基准测试(benchmarks)上优于其他对齐方法。
在本文中,我们将利用 ORPO 和 TRL 库对新的 Llama 3 8B 模子进行微调。
ORPO

指令微调(instruction tuning)和偏好对齐(preference alignment)是使LLM顺应特定任务的基本技能。传统上,这涉及一个多阶段的过程:1/ 在指令上进行监督微调(Supervised Fine-Tuning, SFT),以使模子顺应目标领域,然后 2/ 利用偏好对齐方法,如基于人类反馈的强化学习(Reinforcement Learning with Human Feedback, RLHF)或直接偏好优化(Direct Preference Optimization, DPO),以增长生成首选相应而非被拒绝相应的可能性。

然而,研究人员发现了这种方法的范围性。虽然 SFT 有效地使模子顺应所需的领域,但它无意中增长了在首选答案的同时生成不须要的答案的可能性。这就是为什么偏好调整阶段对于扩大首选输出和拒绝输出的可能性之间的差距是须要的。
ORPO 由 Hong 和 Lee (2024) 提出,通过将指令调整和偏好对齐联合到一个单一的整体练习过程中,为这个问题提供了一个优雅的解决方案。 ORPO 修改了尺度语言建模目标,将负对数似然丧失与上风比 (OR) 项相联合。这种 OR 丧失对被拒绝的相应进行弱处罚,同时对首选相应进行强烈奖励,从而使模子能够同时学习目标任务并与人类偏好保持同等。

ORPO 已在紧张微调库中实现,如 TRL、Axolotl 和 LLaMA-Factory。在下一节中,我们将相识怎样与 TRL 一起利用。
利用 ORPO 微调 Llama 3

Llama 3 是Meta开发的最新大型语言模子(LLM)家族。该模子在一个包含15万亿个标记的数据集上进行了练习(相比之下,Llama 2 的练习数据集为2万亿个标记)。现在已经发布了两种模子尺寸:一个是拥有70B参数的模子,另一个是较小的8B参数模子。70B参数的模子已经展示了令人印象深刻的性能,在MMLU基准测试中得分为82,在HumanEval基准测试中得分为81.7。
Llama 3 模子还将上下文长度增长到了8,192个标记(相比之下,Llama 2 为4,096个标记),并且有可能通过RoPE扩展到32k。此外,这些模子利用了一种新的分词器,具有128K标记的词汇量,从而减少了编码文本所需的标记数目15%。这种词汇量的增长也解释了参数从70亿增长到80亿。
ORPO 须要一个偏好数据集,包括提示、选择的答案和拒绝的答案。在此示例中,我们将利用 mlabonne/orpo-dpo-mix-40k ,它是以下高质量 DPO 数据集的组合:


  • argilla/distilabel-capybara-dpo-7k-binarized: highly scored chosen answers >=5 (2,882 samples)
  • argilla/distilabel-intel-orca-dpo-pairs: highly scored chosen answers>=9, not in GSM8K (2,299 samples)
  • argilla/ultrafeedback-binarized-preferences-cleaned: highly scoredchosen answers >=5 (22,799 samples)
  • argilla/distilabel-math-preference-dpo: highly scored chosen answers>=9 (2,181 samples)
  • unalignment/toxic-dpo-v0.2 (541 samples)
  • M4-ai/prm_dpo_pairs_cleaned (7,958 samples)
  • jondurbin/truthy-dpo-v0.1 (1,016 samples)
首先安装所需的库:
  1. pip install -U transformers datasets accelerate peft trl bitsandbytes wandb
复制代码
安装完成后,我们可以导入须要的库并登录W&B(可选)
  1. import gc
  2. import os
  3. import torch
  4. import wandb
  5. from datasets import load_dataset
  6. # from google.colab import userdata
  7. from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
  8. from transformers import (
  9.     AutoModelForCausalLM,
  10.     AutoTokenizer,
  11.     BitsAndBytesConfig,
  12.     TrainingArguments,
  13.     pipeline,
  14. )
  15. from trl import ORPOConfig, ORPOTrainer, setup_chat_format
  16. # wb_token = userdata.get('wandb')
  17. # wandb.login(key=wb_token)
复制代码
如果您有最新的 GPU,还应该能够利用 Flash Attention 库将默认的 eager Attention 实现替换为更高效的实现。
  1. if torch.cuda.get_device_capability()[0] >= 8:
  2.     #!pip install -qqq flash-attn
  3.     attn_implementation = "flash_attention_2"
  4.     torch_dtype = torch.bfloat16
  5. else:
  6.     attn_implementation = "eager"
  7.     torch_dtype = torch.float16
复制代码
接下来,我们将借助bitsandbytes 以 4 位精度加载 Llama 3 8B 模子。然后,我们利用 QLoRA 的 PEFT 设置 LoRA 配置。我还利用方便的 setup_chat_format() 函数来修改模子和标记生成器以支持 ChatML。它会自动应用此谈天模板,添加特别标记,并调整模子嵌入层的巨细以匹配新的词汇表巨细。
请注意,您须要提交访问 meta-llama/Meta-Llama-3-8B 的哀求并登录您的 Hugging Face 帐户。或者,您可以加载模子的非门控副本,比方 NousResearch/Meta–Llama-3-8B。(我选择手动从NousResearch/Meta–Llama-3-8B下载)
  1. # Model
  2. base_model = "meta-llama/Meta-Llama-3-8B"
  3. new_model = "OrpoLlama-3-8B"
  4. # QLoRA config
  5. bnb_config = BitsAndBytesConfig(
  6.     load_in_4bit=True,
  7.     bnb_4bit_quant_type="nf4",
  8.     bnb_4bit_compute_dtype=torch_dtype,
  9.     bnb_4bit_use_double_quant=True,
  10. )
  11. # LoRA config
  12. peft_config = LoraConfig(
  13.     r=16,
  14.     lora_alpha=32,
  15.     lora_dropout=0.05,
  16.     bias="none",
  17.     task_type="CAUSAL_LM",
  18.     target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
  19. )
  20. # Load tokenizer
  21. tokenizer = AutoTokenizer.from_pretrained(base_model)
  22. # Load model
  23. model = AutoModelForCausalLM.from_pretrained(
  24.     base_model,
  25.     quantization_config=bnb_config,
  26.     device_map="auto",
  27.     attn_implementation=attn_implementation
  28. )
  29. model, tokenizer = setup_chat_format(model, tokenizer)
  30. model = prepare_model_for_kbit_training(model)
复制代码
现在模子已准备好进行练习,我们可以处置惩罚数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并利用 apply_chat_template() 函数将“chosen”和“rejected”列转换为 ChatML 格式。请注意,我仅利用 1,00 个样本,而不是整个数据集,由于运行时间太长。(我选择手动下载)
  1. dataset_name = "mlabonne/orpo-dpo-mix-40k"
  2. dataset = load_dataset(dataset_name, split="all")
  3. dataset = dataset.shuffle(seed=42).select(range(100))
  4. def format_chat_template(row):
  5.     row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
  6.     row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
  7.     return row
  8. dataset = dataset.map(
  9.     format_chat_template,
  10.     num_proc= os.cpu_count(),
  11. )
  12. dataset = dataset.train_test_split(test_size=0.01)
复制代码
首先,我们须要设置一些超参数: * learning_rate :与传统的 SFT 甚至 DPO 相比,ORPO 利用非常低的学习率。 8e-6这个值来自原始论文,大致对应于SFT学习率1e-5和DPO学习率5e-6。我建议将其增长到 1e-6 左右以进行真正的微调。 * beta :即论文中的

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

美丽的神话

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表