背景
LlamaFactory 的 LoRA 微调功能非常便捷,微调后的模型,没有直接支持批量的 vllm 推理。
LlamaFactory 现在支持通过 VLLM API 举行部署,调用 API 时的相应速率,仍旧没有vllm批量推理的速率快。
如果模型是通过 LlamaFactory 微调的,为了确保数据集的一致性,建议在推理时也使用 LlamaFactory 提供的封装数据集。
简介
我把下述代码贡献到了 LlamaFactory 的 scripts/vllm_infer.py下,如今已经举行了更新,建议大家使用最新的scripts/vllm_infer.py,会更方便一点。
下述的代码是最初的版本。
在上述的背景下,我们使用 LlamaFactory 原生数据集,支持 lora的 vllm 批量推理。
完整代码如下:
- import json
- import os
- from typing import List
- from vllm import LLM, SamplingParams
- from vllm.lora.request import LoRARequest
- from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
- from llamafactory.extras.constants import IGNORE_INDEX
- from llamafactory.hparams import get_train_args
- from llamafactory.model import load_tokenizer
- def vllm_infer():
- model_args, data_args, training_args, finetuning_args, generating_args = (
- get_train_args()
- )
- tokenizer = load_tokenizer(model_args)["tokenizer"]
- template = get_template_and_fix_tokenizer(tokenizer, data_args)
- eval_dataset = get_dataset(
- template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
- )["eval_dataset"]
- prompts = [item["input_ids"] for item in eval_dataset]
- prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
- labels = [
- list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
- for item in eval_dataset
- ]
- labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
- sampling_params = SamplingParams(
- temperature=generating_args.temperature,
- top_k=generating_args.top_k,
- top_p=generating_args.top_p,
- max_tokens=2048,
- )
- if model_args.adapter_name_or_path:
- if isinstance(model_args.adapter_name_or_path, list):
- lora_requests = []
- for i, _lora_path in enumerate(model_args.adapter_name_or_path):
- lora_requests.append(
- LoRARequest(f"lora_adapter_{i}", i, lora_path=_lora_path)
- )
- else:
- lora_requests = LoRARequest(
- "lora_adapter_0", 0, lora_path=model_args.adapter_name_or_path
- )
- enable_lora = True
- else:
- lora_requests = None
- enable_lora = False
- llm = LLM(
- model=model_args.model_name_or_path,
- trust_remote_code=True,
- tokenizer=model_args.model_name_or_path,
- enable_lora=enable_lora,
- )
- outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
- if not os.path.exists(training_args.output_dir):
- os.makedirs(training_args.output_dir, exist_ok=True)
- output_prediction_file = os.path.join(
- training_args.output_dir, "generated_predictions.jsonl"
- )
- with open(output_prediction_file, "w", encoding="utf-8") as writer:
- res: List[str] = []
- for text, pred, label in zip(prompts, outputs, labels):
- res.append(
- json.dumps(
- {"prompt": text, "predict": pred.outputs[0].text, "label": label},
- ensure_ascii=False,
- )
- )
- writer.write("\n".join(res))
复制代码 vllm.yaml 示例:
- ## model
- model_name_or_path: qwen/Qwen2.5-7B-Instruct
- # adapter_name_or_path: lora模型
- ### method
- stage: sft
- do_predict: true
- finetuning_type: lora
- ### dataset
- dataset_dir: 数据集路径
- eval_dataset: 数据集
- template: qwen
- cutoff_len: 1024
- max_samples: 1000
- overwrite_cache: true
- preprocessing_num_workers: 16
- ### output
- output_dir: output/
- overwrite_output_dir: true
- ### eval
- predict_with_generate: true
复制代码 步调调用:
- python vllm_infer.py vllm.yaml
复制代码 步调运行速率:
- Processed prompts: 100%|█| 1000/1000 [01:56<00:00, 8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57
复制代码 总结
本方案在原生 LlamaFactory 数据集的底子上,支持 LoRA 的 vllm 批量推理,能提升了推理效率。
进一步阅读
如果微调模型后,发现前文使用vllm模型批量结果不太好,可以参考与使用下述文章:
- 基于 LLamafactory 的异步API高效调用实现与速率对比.https://blog.csdn.net/sjxgghg/article/details/144176645
亲测,上述文章的 LLamafactory 部署模型,然后使用 Async API 调用后举行评估发现结果会好一些。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |