ToB企服应用市场:ToB评测及商务社交产业平台
标题:
基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现
[打印本页]
作者:
冬雨财经
时间:
昨天 18:36
标题:
基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现
背景
LlamaFactory 的 LoRA 微调功能非常便捷,微调后的模型,没有直接支持批量的 vllm 推理。
LlamaFactory 现在支持通过 VLLM API 举行部署,调用 API 时的相应速率,仍旧没有vllm批量推理的速率快。
如果模型是通过 LlamaFactory 微调的,为了确保数据集的一致性,建议在推理时也使用 LlamaFactory 提供的封装数据集。
简介
我把下述代码贡献到了 LlamaFactory 的 scripts/vllm_infer.py下,如今已经举行了更新,建议大家使用最新的scripts/vllm_infer.py,会更方便一点。
下述的代码是最初的版本。
在上述的背景下,我们使用 LlamaFactory 原生数据集,支持 lora的 vllm 批量推理。
完整代码如下:
import json
import os
from typing import List
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
def vllm_infer():
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args()
)
tokenizer = load_tokenizer(model_args)["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
eval_dataset = get_dataset(
template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
)["eval_dataset"]
prompts = [item["input_ids"] for item in eval_dataset]
prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
labels = [
list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
for item in eval_dataset
]
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
sampling_params = SamplingParams(
temperature=generating_args.temperature,
top_k=generating_args.top_k,
top_p=generating_args.top_p,
max_tokens=2048,
)
if model_args.adapter_name_or_path:
if isinstance(model_args.adapter_name_or_path, list):
lora_requests = []
for i, _lora_path in enumerate(model_args.adapter_name_or_path):
lora_requests.append(
LoRARequest(f"lora_adapter_{i}", i, lora_path=_lora_path)
)
else:
lora_requests = LoRARequest(
"lora_adapter_0", 0, lora_path=model_args.adapter_name_or_path
)
enable_lora = True
else:
lora_requests = None
enable_lora = False
llm = LLM(
model=model_args.model_name_or_path,
trust_remote_code=True,
tokenizer=model_args.model_name_or_path,
enable_lora=enable_lora,
)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
if not os.path.exists(training_args.output_dir):
os.makedirs(training_args.output_dir, exist_ok=True)
output_prediction_file = os.path.join(
training_args.output_dir, "generated_predictions.jsonl"
)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, pred, label in zip(prompts, outputs, labels):
res.append(
json.dumps(
{"prompt": text, "predict": pred.outputs[0].text, "label": label},
ensure_ascii=False,
)
)
writer.write("\n".join(res))
复制代码
vllm.yaml 示例:
## model
model_name_or_path: qwen/Qwen2.5-7B-Instruct
# adapter_name_or_path: lora模型
### method
stage: sft
do_predict: true
finetuning_type: lora
### dataset
dataset_dir: 数据集路径
eval_dataset: 数据集
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: output/
overwrite_output_dir: true
### eval
predict_with_generate: true
复制代码
步调调用:
python vllm_infer.py vllm.yaml
复制代码
步调运行速率:
Processed prompts: 100%|█| 1000/1000 [01:56<00:00, 8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57
复制代码
总结
本方案在原生 LlamaFactory 数据集的底子上,支持 LoRA 的 vllm 批量推理,能提升了推理效率。
进一步阅读
如果微调模型后,发现前文使用vllm模型批量结果不太好,可以参考与使用下述文章:
基于 LLamafactory 的异步API高效调用实现与速率对比.https://blog.csdn.net/sjxgghg/article/details/144176645
亲测,上述文章的 LLamafactory 部署模型,然后使用 Async API 调用后举行评估发现结果会好一些。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/)
Powered by Discuz! X3.4