ToB企服应用市场:ToB评测及商务社交产业平台

标题: 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现 [打印本页]

作者: 冬雨财经    时间: 昨天 18:36
标题: 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现
背景

LlamaFactory 的 LoRA 微调功能非常便捷,微调后的模型,没有直接支持批量的 vllm 推理。
LlamaFactory 现在支持通过 VLLM API 举行部署,调用 API 时的相应速率,仍旧没有vllm批量推理的速率快。
如果模型是通过 LlamaFactory 微调的,为了确保数据集的一致性,建议在推理时也使用 LlamaFactory 提供的封装数据集。
简介

我把下述代码贡献到了 LlamaFactory 的 scripts/vllm_infer.py下,如今已经举行了更新,建议大家使用最新的scripts/vllm_infer.py,会更方便一点。
下述的代码是最初的版本。
在上述的背景下,我们使用 LlamaFactory 原生数据集,支持 lora的 vllm 批量推理。
完整代码如下:
  1. import json
  2. import os
  3. from typing import List
  4. from vllm import LLM, SamplingParams
  5. from vllm.lora.request import LoRARequest
  6. from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
  7. from llamafactory.extras.constants import IGNORE_INDEX
  8. from llamafactory.hparams import get_train_args
  9. from llamafactory.model import load_tokenizer
  10. def vllm_infer():
  11.     model_args, data_args, training_args, finetuning_args, generating_args = (
  12.         get_train_args()
  13.     )
  14.     tokenizer = load_tokenizer(model_args)["tokenizer"]
  15.     template = get_template_and_fix_tokenizer(tokenizer, data_args)
  16.     eval_dataset = get_dataset(
  17.         template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
  18.     )["eval_dataset"]
  19.     prompts = [item["input_ids"] for item in eval_dataset]
  20.     prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
  21.     labels = [
  22.         list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
  23.         for item in eval_dataset
  24.     ]
  25.     labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  26.     sampling_params = SamplingParams(
  27.         temperature=generating_args.temperature,
  28.         top_k=generating_args.top_k,
  29.         top_p=generating_args.top_p,
  30.         max_tokens=2048,
  31.     )
  32.     if model_args.adapter_name_or_path:
  33.         if isinstance(model_args.adapter_name_or_path, list):
  34.             lora_requests = []
  35.             for i, _lora_path in enumerate(model_args.adapter_name_or_path):
  36.                 lora_requests.append(
  37.                     LoRARequest(f"lora_adapter_{i}", i, lora_path=_lora_path)
  38.                 )
  39.         else:
  40.             lora_requests = LoRARequest(
  41.                 "lora_adapter_0", 0, lora_path=model_args.adapter_name_or_path
  42.             )
  43.         enable_lora = True
  44.     else:
  45.         lora_requests = None
  46.         enable_lora = False
  47.     llm = LLM(
  48.         model=model_args.model_name_or_path,
  49.         trust_remote_code=True,
  50.         tokenizer=model_args.model_name_or_path,
  51.         enable_lora=enable_lora,
  52.     )
  53.     outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
  54.     if not os.path.exists(training_args.output_dir):
  55.         os.makedirs(training_args.output_dir, exist_ok=True)
  56.     output_prediction_file = os.path.join(
  57.         training_args.output_dir, "generated_predictions.jsonl"
  58.     )
  59.     with open(output_prediction_file, "w", encoding="utf-8") as writer:
  60.         res: List[str] = []
  61.         for text, pred, label in zip(prompts, outputs, labels):
  62.             res.append(
  63.                 json.dumps(
  64.                     {"prompt": text, "predict": pred.outputs[0].text, "label": label},
  65.                     ensure_ascii=False,
  66.                 )
  67.             )
  68.         writer.write("\n".join(res))
复制代码
vllm.yaml 示例:
  1. ## model
  2. model_name_or_path: qwen/Qwen2.5-7B-Instruct
  3. # adapter_name_or_path: lora模型
  4. ### method
  5. stage: sft
  6. do_predict: true
  7. finetuning_type: lora
  8. ### dataset
  9. dataset_dir: 数据集路径
  10. eval_dataset: 数据集
  11. template: qwen
  12. cutoff_len: 1024
  13. max_samples: 1000
  14. overwrite_cache: true
  15. preprocessing_num_workers: 16
  16. ### output
  17. output_dir: output/
  18. overwrite_output_dir: true
  19. ### eval
  20. predict_with_generate: true
复制代码
步调调用:
  1. python vllm_infer.py vllm.yaml
复制代码
步调运行速率:
  1. Processed prompts: 100%|█| 1000/1000 [01:56<00:00,  8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57
复制代码
总结

本方案在原生 LlamaFactory 数据集的底子上,支持 LoRA 的 vllm 批量推理,能提升了推理效率。
进一步阅读

如果微调模型后,发现前文使用vllm模型批量结果不太好,可以参考与使用下述文章:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4