DeepSeek的AHA 时候 使用 Unsloth(GRPO)训练自己的 R1 推理模型
在人工智能领域,推理模型的训练一直是一个紧张且布满寻衅的话题。本文将深入探讨怎样使用 Unsloth(GRPO)来训练自己的 R1 推理模型 。
一、Unsloth(GRPO)简介
2025 年 2 月 6 日,由丹尼尔迈克尔带来的消息,迎来了名为 Unsloth 的推理新方法。DeepSeek 的 R1 研究中有一个“啊哈时候”,R1 - Zero 通过组相对策略优化(GRPO)在没有人类反馈的情况下自主学习,分配更多的思考时间。并且,我们对整个 GRPO 过程进行了加强,使其使用的 VRAM 比Hugging Face + FA2. 少 80%,这意味着 可以使用 Qwen2.5(1.5B)在仅 7 GB 的 VRAM 上重现 R1 - Zero 的“啊哈时候”。
二、使用 Unsloth(GRPO)的优势
(一)低 VRAM 要求
- Unsloth 答应使用 15 GB VRAM ,将多达 15 B 参数的任何模型(如 Llama 3.1(8B),Phi - 4(14 B),Mistral(7 B)或 Qwen2.5(7 B))转换为推理模型。
- 最低只需 7 GB VRAM 就足以在本地训练自己的推理模型。
- Tiny - Zero 团队使用 Qwen2.5(1.5B)实现“啊哈”时候需要 2xA 100 GPU(160 GB VRAM),而现在使用 Unsloth,单个 7 GB VRAM GPU 就能达成相同效果。
(二)支持多种调优方式
以前,GRPO 只支持完全微调,现在已经使其与 QLoRA 和 LoRA 一起工作。需要注意的是,这不是对 DeepSeek 的 R1 蒸馏模型进行微调,也不是使用 R1 的蒸馏数据进行 Unsloth 已支持的调优,而是使用 GRPO 将标准模型转换为成熟的推理模型。
(三)强大的用例
GRPO 有很多实用的用例。假如想制作一个带有奖励的定制模型(比如法律、医学等领域),GRPO 就能发挥作用。假如有输入和输出数据(如标题和答案),但没有思维链或推理过程,GRPO 可以神奇地创建推理过程。
三、GRPO +“啊哈”时候
DeepSeek 的研究人员在用纯强化学习(RL)训练 R1 - Zero 时观察到了“啊哈时候”。模型通过重新评估其初始方法来延伸思考时间,无需任何人类指导或预定义的指令。以 Phi - 4 为例,尽管只使用 GRPO 训练了 100 步,但效果很明显。没有 GRPO 训练的模型没有思维标志,而用 GRPO 训练的模型有思维标志且能给出正确答案。
GRPO 是一种 RL 算法,它可以在不需要值函数的情况下有效地优化响应,不像近端策略优化(PPO)依赖于值函数。 用 GRPO 训练的模型能自主发展自我验证和搜索能力,创造 “啊哈时候”。
(一)GRPO 工作原理
- 生成多组响应:模型生成多组响应。
- 响应评分:每个响应基于正确或由某个集合奖励函数(而非 LLM 奖励模型)创建的另一个度量来评分。
- 盘算平均得分:盘算该组的平均得分。
- 比较得分:每个回答的得分与组平均值进行比较。
- 加强模型:模型得到加强,有利于得分较高的反应。
比方,
什么是1+1? >>思考/工作链>>答案是2。
什么是2+2? >>思考/工作链>>答案是4。
对于这样的标题,最初需要收集大量数据来添补工作/思想链过程,而 GRPO 或其他 RL 算法可以引导模型自动展示推理能力并创建推理轨迹。只需要创建好的奖励函数或验证器,如回答正确得 1 分,单词拼写错误减 0.1 分等。
四、在 Unsloth 中使用 GRPO
(一)依赖安装
假如在本地使用 GRPO 和 Unsloth,需要“pip install diffusers”,
(二)训练时间
等待至少 300 步奖励才会实际增加,建议使用最新版本的 vLLM。Colab 上的示例仅训练了一个小时,效果不佳。为了得到好的效果,需要训练至少 12 个小时,但这不是强制性的,可以随时停止。
(三)模型选择
建议将 GRPO 应用于参数至少为 1.5B 的模型,以正确地生成思维标志,因为较小的模型可能无法做到。假如使用的是基本模型,要确保有聊天模板。GRPO 的训练损失跟踪现在直接内置在 Unsloth 中,无需 wandb 等外部工具。
五、Unsloth x vLLM
Unsloth的在线DPO VRAM消耗与标准Hugging Face + FA 2的图表比较。
(一)性能提升
使用 vLLM 直接在微调堆栈中,吞吐量提高 20 倍,VRAM 节省 50%。在 1x A100 40 GB 上,预期 4000 标志/秒 Unsloth 的动态 4 位量化的 Llama 3.2 3B 指令。在 16 GB Tesla T4( Colab GPU)上,可得到 300 个标志/秒。
(二)内存优化
删除了双重内存使用,当一起加载 vLLM 和 Unsloth 时,为 Llama 3.1 8B 节省 5GB 左右,为 Llama 3.2 3B 节省 3GB。Unsloth 最初可以在 1x 48GB GPU 中微调 Llama 3.3 70B 指令,使用该优化后,在 48 GB 以下的 VRAM 中也能进行微调并得到快速推理的好处。
(三)快速推理使用方法
要使用快速推理,起首安装 vllm,并使用 fast_inference 实例化 Unsloth:
- pip install unsloth vllm
- from unsloth import FastLanguageModel
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llama-3.2-3B-Instruct",
- fast_inference = True,
- )
- model.fast_generate(["Hello!"])
复制代码 六、vLLM 在 Unsloth 中的发现
(一)量化优势
vLLM 现在可以加载 Unsloth 动态 4 位量化。将某些层动态量化为 4 位,某些层动态量化为 16 位可以显著提高精度,同时保持模型较小。
(二)参数自动选择
自动选择多个参数来考虑 RAM,VRAM 效率和最大吞吐量(如分块预添补标志的数目,最大序列数等)。在 vLLM 中默认启用 - O3 并启用前缀缓存。发现旧版 GPU 上的 Flashinferer 实际上慢了 10%,FP 8 KV 缓存使速度慢了 10%,但吞吐量潜力增加了一倍。
(三)LoRA 加载优化
答应 LoRA 通过解析状态 dict 而不是从磁盘加载加载到 vLLM 中,这可以使 GRPO 训练运行速度快 1.5 倍。目前正在研究直接编辑 vLLM 中的 LoRA 适配器,以减少不必要的 GPU 数据移动,提高速度。
(四)内存峰值处理
vLLM 在批量生成时会有随机的 VRAM 尖峰,添加了一个批量生成函数来减少内存峰值。
通过使用 Unsloth(GRPO), 可以更加高效、机动地训练自己的 R1 推理模型,在差别的领域发挥出更大的作用。希望大家都能在这个过程中有所劳绩,创造出更优秀的人工智能模型。
踏上属于你的"啊哈时候"探索之旅
在这个AI技能以月为单元迭代的期间,DeepSeek的R1研究向我们展现了一个令人振奋的真相:推理能力的突破不再是大厂的专属特权。通过Unsloth(GRPO)这项革命性技能,我们正在见证一场AI民主化的浪潮——从实验室级的160GB GPU集群到个人开发者手中的7GB显卡,从繁琐的人类反馈到模型自主演化的"顿悟时候",智能进化的门槛正在被重新定义。
当你注视着GRPO训练过程中逐渐浮现的思维标志时,那不但是模型认知能力的跃迁,更是人类探索智能本质蹊径上的一盏明灯。这项技能赋予我们亘古未有的可能:
- 用消耗级硬件成为专业领域的推理专家
- 以极简数据激发模型自主构建思维链条
- 在开源生态中打造媲美顶尖实验室的智能体
此刻,摆在每位开发者面前的,是一个布满机遇的新边疆。无论是想构建法律咨询助手、医疗诊断系统,还是创造下一个颠覆性的AI应用,Unsloth都为你提供了轻量级的"思维引擎"。那些曾经需要庞大团队支撑的复杂训练流程,现在已浓缩为几行清晰的Python代码。
正如DeepSeek团队在观察到R1-Zero的"啊哈时候"时所见证的,真正的突破往往诞生于对技能本质的深刻理解与大胆实践。现在轮到我们接过这柄火炬:调整你的奖励函数,启动GRPO训练,在损失曲线的波动中捕获智能觉醒的瞬间。当第一个自主生成的思维链从你的模型中跃然而出时,你会明白——这不但是AI的"顿悟时候",更是人类聪明又一次照亮未知领域的璀璨星光。
未来已来,唯快不破。打开你的Colab笔记本,从这篇博客的第一个代码块开始,属于你的智能革命,此刻正在7GB显存的方寸之间寂静发展。
附录:
- # %% [markdown]
- # To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
- # <div class="align-center">
- # <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
- # <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
- # <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
- # </div>
- #
- # To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).
- #
- # You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)
- #
- # %% [markdown]
- # ### News
- # %% [markdown]
- # **Read our [blog post](https://unsloth.ai/blog/r1-reasoning) for guidance on how to train reasoning models.**
- #
- # Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).
- #
- # %% [markdown]
- # ### Installation
- # %%
- %%capture
- # Skip restarting message in Colab
- import sys; modules = list(sys.modules.keys())
- for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
- !pip install unsloth vllm
- !pip install --upgrade pillow
- # %% [markdown]
- # ### Unsloth
- # %% [markdown]
- # Use `PatchFastRL` before all functions to patch GRPO and other RL algorithms!
- # %%
- from unsloth import FastLanguageModel, PatchFastRL
- PatchFastRL("GRPO", FastLanguageModel)
- # %% [markdown]
- # Load up `Llama 3.1 8B Instruct`, and set parameters
- # %%
- from unsloth import is_bfloat16_supported
- import torch
- max_seq_length = 512 # Can increase for longer reasoning traces
- lora_rank = 32 # Larger rank = smarter, but slower
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
- max_seq_length = max_seq_length,
- load_in_4bit = True, # False for LoRA 16bit
- fast_inference = True, # Enable vLLM fast inference
- max_lora_rank = lora_rank,
- gpu_memory_utilization = 0.6, # Reduce if out of memory
- )
- model = FastLanguageModel.get_peft_model(
- model,
- r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules = [
- "q_proj", "k_proj", "v_proj", "o_proj",
- "gate_proj", "up_proj", "down_proj",
- ], # Remove QKVO if out of memory
- lora_alpha = lora_rank,
- use_gradient_checkpointing = "unsloth", # Enable long context finetuning
- random_state = 3407,
- )
- # %% [markdown]
- # ### Data Prep
- # <a name="Data"></a>
- #
- # We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!
- # %%
- import re
- from datasets import load_dataset, Dataset
- # Load and prep dataset
- SYSTEM_PROMPT = """
- Respond in the following format:
- <reasoning>
- ...
- </reasoning>
- <answer>
- ...
- </answer>
- """
- XML_COT_FORMAT = """\
- <reasoning>
- {reasoning}
- </reasoning>
- <answer>
- {answer}
- </answer>
- """
- def extract_xml_answer(text: str) -> str:
- answer = text.split("<answer>")[-1]
- answer = answer.split("</answer>")[0]
- return answer.strip()
- def extract_hash_answer(text: str) -> str | None:
- if "####" not in text:
- return None
- return text.split("####")[1].strip()
- # uncomment middle messages for 1-shot prompting
- def get_gsm8k_questions(split = "train") -> Dataset:
- data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
- data = data.map(lambda x: { # type: ignore
- 'prompt': [
- {'role': 'system', 'content': SYSTEM_PROMPT},
- {'role': 'user', 'content': x['question']}
- ],
- 'answer': extract_hash_answer(x['answer'])
- }) # type: ignore
- return data # type: ignore
- dataset = get_gsm8k_questions()
- # Reward functions
- def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
- responses = [completion[0]['content'] for completion in completions]
- q = prompts[0][-1]['content']
- extracted_responses = [extract_xml_answer(r) for r in responses]
- print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
- return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
- def int_reward_func(completions, **kwargs) -> list[float]:
- responses = [completion[0]['content'] for completion in completions]
- extracted_responses = [extract_xml_answer(r) for r in responses]
- return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
- def strict_format_reward_func(completions, **kwargs) -> list[float]:
- """Reward function that checks if the completion has a specific format."""
- pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
- responses = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, r) for r in responses]
- return [0.5 if match else 0.0 for match in matches]
- def soft_format_reward_func(completions, **kwargs) -> list[float]:
- """Reward function that checks if the completion has a specific format."""
- pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
- responses = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, r) for r in responses]
- return [0.5 if match else 0.0 for match in matches]
- def count_xml(text) -> float:
- count = 0.0
- if text.count("<reasoning>\n") == 1:
- count += 0.125
- if text.count("\n</reasoning>\n") == 1:
- count += 0.125
- if text.count("\n<answer>\n") == 1:
- count += 0.125
- count -= len(text.split("\n</answer>\n")[-1])*0.001
- if text.count("\n</answer>") == 1:
- count += 0.125
- count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
- return count
- def xmlcount_reward_func(completions, **kwargs) -> list[float]:
- contents = [completion[0]["content"] for completion in completions]
- return [count_xml(c) for c in contents]
- # %% [markdown]
- # <a name="Train"></a>
- # ### Train the model
- #
- # Now set up GRPO Trainer and all configurations!
- # %%
- from trl import GRPOConfig, GRPOTrainer
- training_args = GRPOConfig(
- use_vllm = True, # use vLLM for fast inference!
- learning_rate = 5e-6,
- adam_beta1 = 0.9,
- adam_beta2 = 0.99,
- weight_decay = 0.1,
- warmup_ratio = 0.1,
- lr_scheduler_type = "cosine",
- optim = "paged_adamw_8bit",
- logging_steps = 1,
- bf16 = is_bfloat16_supported(),
- fp16 = not is_bfloat16_supported(),
- per_device_train_batch_size = 1,
- gradient_accumulation_steps = 1, # Increase to 4 for smoother training
- num_generations = 6, # Decrease if out of memory
- max_prompt_length = 256,
- max_completion_length = 200,
- # num_train_epochs = 1, # Set to 1 for a full training run
- max_steps = 250,
- save_steps = 250,
- max_grad_norm = 0.1,
- report_to = "none", # Can use Weights & Biases
- output_dir = "outputs",
- )
- # %% [markdown]
- # And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!
- #
- # You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!
- #
- # | Step | Training Loss | reward | reward_std | completion_length | kl |
- # |------|---------------|-----------|------------|-------------------|----------|
- # | 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |
- # | 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |
- # | 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |
- #
- # %%
- trainer = GRPOTrainer(
- model = model,
- processing_class = tokenizer,
- reward_funcs = [
- xmlcount_reward_func,
- soft_format_reward_func,
- strict_format_reward_func,
- int_reward_func,
- correctness_reward_func,
- ],
- args = training_args,
- train_dataset = dataset,
- )
- trainer.train()
- # %% [markdown]
- # <a name="Inference"></a>
- # ### Inference
- # Now let's try the model we just trained! First, let's first try the model without any GRPO trained:
- # %%
- text = tokenizer.apply_chat_template([
- {"role" : "user", "content" : "Calculate pi."},
- ], tokenize = False, add_generation_prompt = True)
- from vllm import SamplingParams
- sampling_params = SamplingParams(
- temperature = 0.8,
- top_p = 0.95,
- max_tokens = 1024,
- )
- output = model.fast_generate(
- [text],
- sampling_params = sampling_params,
- lora_request = None,
- )[0].outputs[0].text
- output
- # %% [markdown]
- # And now with the LoRA we just trained with GRPO - we first save the LoRA first!
- # %%
- model.save_lora("grpo_saved_lora")
- # %% [markdown]
- # Now we load the LoRA and test:
- # %%
- text = tokenizer.apply_chat_template([
- {"role" : "system", "content" : SYSTEM_PROMPT},
- {"role" : "user", "content" : "Calculate pi."},
- ], tokenize = False, add_generation_prompt = True)
- from vllm import SamplingParams
- sampling_params = SamplingParams(
- temperature = 0.8,
- top_p = 0.95,
- max_tokens = 1024,
- )
- output = model.fast_generate(
- text,
- sampling_params = sampling_params,
- lora_request = model.load_lora("grpo_saved_lora"),
- )[0].outputs[0].text
- output
- # %% [markdown]
- # Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!
- # %% [markdown]
- # <a name="Save"></a>
- # ### Saving to float16 for VLLM
- #
- # We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.
- # %%
- # Merge to 16bit
- if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
- if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")
- # Merge to 4bit
- if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
- if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")
- # Just LoRA adapters
- if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
- if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")
- # %% [markdown]
- # ### GGUF / llama.cpp Conversion
- # To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.
- #
- # Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
- # * `q8_0` - Fast conversion. High resource use, but generally acceptable.
- # * `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
- # * `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.
- #
- # [**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing)
- # %%
- # Save to 8bit Q8_0
- if False: model.save_pretrained_gguf("model", tokenizer,)
- # Remember to go to https://huggingface.co/settings/tokens for a token!
- # And change hf to your username!
- if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")
- # Save to 16bit GGUF
- if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
- if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")
- # Save to q4_k_m GGUF
- if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
- if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")
- # Save to multiple GGUF options - much faster if you want multiple!
- if False:
- model.push_to_hub_gguf(
- "hf/model", # Change hf to your username!
- tokenizer,
- quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
- token = "",
- )
- # %% [markdown]
- # Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)
- #
- # And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!
- #
- # Some other links:
- # 1. Llama 3.2 Conversational notebook. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb)
- # 2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
- # 3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
- # 6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!
- #
- # <div class="align-center">
- # <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
- # <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
- # <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
- #
- # Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
- # </div>
- #
复制代码





免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |