【呆板学习】Qwen2大模型原理、训练及推理摆设实战

打印 上一主题 下一主题

主题 660|帖子 660|积分 1980



目录​​​​​​​
一、引言
二、模型简介
2.1 Qwen2 模型概述
2.2 Qwen2 模型架构
三、训练与推理
3.1 Qwen2 模型训练
3.2 Qwen2 模型推理
四、总结



一、引言

刚刚写完【呆板学习】Qwen1.5-14B-Chat大模型训练与推理实战 ,阿里Qwen就推出了Qwen2,相较于Qwen1.5中0.5B、1.8B、4B、7B、14B、32B、72B、110B等8个Dense模型以及1个14B(A2.7B)MoE模型共计9个模型,Qwen2包罗了0.5B、1.5B、7B、57B-A14B和72B共计5个尺寸模型。从尺寸上来讲,最关键的就是推出了57B-A14B这个更大尺寸的MoE模型,有人问为什么删除了14B这个针对32G显存的常用尺寸,其实对于57B-A14B剪枝一下就可以得到。
二、模型简介

2.1 Qwen2 模型概述

Qwen2对比Qwen1.5


  • 模型尺寸:将Qwen2-7B和Qwen2-72B的模型尺寸有32K提拔为128K



  • GQA(分组查询注意力):在Qwen1.5系列中,只有32B和110B的模型利用了GQA。这一次,所有尺寸的模型都利用了GQA,提供GQA加速推理和降低显存占用 

   分组查询注意力 (Grouped Query Attention) 是一种在大型语言模型中的多查询注意力 (MQA) 和多头注意力 (MHA) 之间举行插值的方法,它的目标是在保持 MQA 速度的同时实现 MHA 的质量 
  

  • tie embedding:针对小模型,由于embedding参数量较大,利用了tie embedding的方法让输入和输出层共享参数,增加非embedding参数的占比 
 结果对比
Qwen2-72B全方位围剿Llama3-70B,同时对比更大尺寸的Qwen1.5-110B也有很大提拔,官方表示来自于“预训练数据及训练方法的优化”。

2.2 Qwen2 模型架构

Qwen2仍然是一个典型decoder-only的transformers大模型结构,告急包括文本输入层embedding层decoder层输出层损失函数
​​​​​​​

通过AutoModelForCausalLM检察Qwen1.5-7B-Chat和Qwen2-7B-Instruct的模型结构,对比config.json发现:
  
 
   

  • 网络结构:无明显变革
  • 焦点网络Qwen2DecoderLayer层:由32层减少为28层(72B是80层)
  • Q、K、V、O隐层尺寸:由4096减少为3584(72B是8192)
  • attention heads:由32减少为28(72B是64)
  • kv head:由32减少为4(72B是8)
  • 滑动窗口(模型尺寸):由32768(32K)增长为131072(128K)(72B一样)
  • 词表:由151936增长为152064(72B一样)
  • intermediate_size(MLP交叉层):由11008增长为18944(72B是29568)
  可以看到此中有的参数增加有的参数减少,猜想是:
   

  • 减少的参数,并不会降低模型结果,反而能增加训练和推理效率,
  • 增大的参数:好比MLP中的intermediate_size,参数越多,模型表达本领越明显。
  
三、训练与推理

3.1 Qwen2 模型训练

在【呆板学习】Qwen1.5-14B-Chat大模型训练与推理实战 中,我们接纳LLaMA-Factory的webui举行训练,今天我们换成命令行的方式,对于LLaMA-Factory框架的摆设,可以参考我之前的文章:
AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络情况下的安装、摆设及利用
该文在百度“LLaMA Factory 摆设”词条排行第一:
​​
假设你已经基于上文摆设了llama_factory的container,运行进入到container中
  1. docker exec -it llama_factory bash
复制代码
在app/目录下创建run_train.sh。
  
  1. CUDA_VISIBLE_DEVICES=2 llamafactory-cli train \
  2.     --stage sft \
  3.     --do_train True \
  4.     --model_name_or_path qwen/Qwen2-7B-Instruct \
  5.     --finetuning_type lora \
  6.     --template qwen \
  7.     --flash_attn auto \
  8.     --dataset_dir data \
  9.     --dataset alpaca_zh \
  10.     --cutoff_len 4096 \
  11.     --learning_rate 5e-05 \
  12.     --num_train_epochs 5.0 \
  13.     --max_samples 100000 \
  14.     --per_device_train_batch_size 4 \
  15.     --gradient_accumulation_steps 4 \
  16.     --lr_scheduler_type cosine \
  17.     --max_grad_norm 1.0 \
  18.     --logging_steps 10 \
  19.     --save_steps 1000 \
  20.     --warmup_steps 0 \
  21.     --optim adamw_torch \
  22.     --packing False \
  23.     --report_to none \
  24.     --output_dir saves/Qwen2-7B-Instruct/lora/train_2024-06-09-23-00 \
  25.     --fp16 True \
  26.     --lora_rank 32 \
  27.     --lora_alpha 16 \
  28.     --lora_dropout 0 \
  29.     --lora_target q_proj,v_proj \
  30.     --val_size 0.1 \
  31.     --evaluation_strategy steps \
  32.     --eval_steps 1000 \
  33.     --per_device_eval_batch_size 2 \
  34.     --load_best_model_at_end True \
  35.     --plot_loss True
复制代码
因为之前文章中重点讲的就是国内网络情况的LLaMA-Factory摆设,焦点就是接纳modelscope模型源取代huggingface模型源,这里脚本启动后,就会主动从modelscope下载指定的模型,这里是"qwen/Qwen2-7B-Instruct",下载完后启动训练

训练数据可以通过LLaMA-Factory/data/dataset_info.json文件举行配置,格式参考data目录下的其他数据文件。 好比构建成类型LLaMA-Factory/data/alpaca_zh_demo.json的格式

在LLaMA-Factory/data/dataset_info.json中复制一份举行配置: 

3.2 Qwen2 模型推理

Qwen2的官方文档中介绍了多种优化推理摆设的方式,包括基于hf transformers、vllm、llama.cpp、Ollama以及AWQ、GPTQ、GGUF等量化方式,告急因为Qwen2开源的Qwen2-72B、Qwen1.5-110B,宏大于GLM4、Baichuan等开源的10B量级小尺寸模型。必要考虑量化、分布式推理问题。​​今天重点介绍Qwen2-7B-Instruct在国内网络情况下的hf transformers推理测试,其他方法单开篇幅举行过细讲解。
呈上一份glm-4-9b-chat、qwen/Qwen2-7B-Instruct通用的极简代码:
  1. from modelscope import snapshot_download
  2. from transformers import AutoTokenizer, AutoModelForCausalLM
  3. #model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat')
  4. model_dir = snapshot_download('qwen/Qwen2-7B-Instruct')
  5. #model_dir = snapshot_download('baichuan-inc/Baichuan2-13B-Chat')
  6. import torch
  7. device = "auto" # 也可以通过"coda:2"指定GPU
  8. tokenizer = AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)
  9. model = AutoModelForCausalLM.from_pretrained(model_dir,device_map=device,trust_remote_code=True)
  10. print(model)
  11. prompt = "介绍一下大语言模型"
  12. messages = [
  13.     {"role": "system", "content": "你是一个智能助理."},
  14.     {"role": "user", "content": prompt}
  15. ]
  16. text = tokenizer.apply_chat_template(
  17.     messages,
  18.     tokenize=False,
  19.     add_generation_prompt=True
  20. )
  21. model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
  22. """
  23. gen_kwargs = {"max_length": 512, "do_sample": True, "top_k": 1}
  24. with torch.no_grad():
  25.     outputs = model.generate(**model_inputs, **gen_kwargs)
  26.     #print(tokenizer.decode(outputs[0],skip_special_tokens=True))
  27.     outputs = outputs[:, model_inputs['input_ids'].shape[1]:] #切除system、user等对话前缀
  28.     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  29. """
  30. generated_ids = model.generate(
  31.     model_inputs.input_ids,
  32.     max_new_tokens=512
  33. )
  34. generated_ids = [
  35.     output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
  36. ]
  37. response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  38. print(response)
复制代码
  该代码有几个特点:
  

  • 网络:从modelscope下载模型文件,解决通过AutoModelForCausalLM模型头下载hf模型慢的问题
  • 通用:实用于glm-4-9b-chat、qwen/Qwen2-7B-Instruct
  • apply_chat_template() :注意!接纳 generate()替代旧方法中的chat() 。这里利用了 apply_chat_template() 函数将消息转换为模型能够明确的格式。此中的 add_generation_prompt 参数用于在输入中添加生成提示,该提示指向 <|im_start|>assistant\n 。
  •  tokenizer.batch_decode() :通过 tokenizer.batch_decode() 函数对相应举行解码。
  运行结果:
 
除了该极简代码,我针对网络情况对官方git提供的demo代码举行了改造:
cli_demo:
接纳modelscope的AutoModelForCausalLM, AutoTokenizer模型头取代transformers对应的模型头,举行模型主动下载
  1. # Copyright (c) Alibaba Cloud.
  2. #
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. """A simple command-line interactive chat demo."""
  6. import argparse
  7. import os
  8. import platform
  9. import shutil
  10. from copy import deepcopy
  11. from threading import Thread
  12. import torch
  13. from modelscope import AutoModelForCausalLM, AutoTokenizer
  14. from transformers import TextIteratorStreamer
  15. from transformers.trainer_utils import set_seed
  16. DEFAULT_CKPT_PATH = 'Qwen/Qwen2-7B-Instruct'
  17. _WELCOME_MSG = '''\
  18. Welcome to use Qwen2-Instruct model, type text to start chat, type :h to show command help.
  19. (欢迎使用 Qwen2-Instruct 模型,输入内容即可进行对话,:h 显示命令帮助。)
  20. Note: This demo is governed by the original license of Qwen2.
  21. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
  22. (注:本演示受Qwen2的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
  23. '''
  24. _HELP_MSG = '''\
  25. Commands:
  26.     :help / :h              Show this help message              显示帮助信息
  27.     :exit / :quit / :q      Exit the demo                       退出Demo
  28.     :clear / :cl            Clear screen                        清屏
  29.     :clear-history / :clh   Clear history                       清除对话历史
  30.     :history / :his         Show history                        显示对话历史
  31.     :seed                   Show current random seed            显示当前随机种子
  32.     :seed <N>               Set random seed to <N>              设置随机种子
  33.     :conf                   Show current generation config      显示生成配置
  34.     :conf <key>=<value>     Change generation config            修改生成配置
  35.     :reset-conf             Reset generation config             重置生成配置
  36. '''
  37. _ALL_COMMAND_NAMES = [
  38.     'help', 'h', 'exit', 'quit', 'q', 'clear', 'cl', 'clear-history', 'clh', 'history', 'his',
  39.     'seed', 'conf', 'reset-conf',
  40. ]
  41. def _setup_readline():
  42.     try:
  43.         import readline
  44.     except ImportError:
  45.         return
  46.     _matches = []
  47.     def _completer(text, state):
  48.         nonlocal _matches
  49.         if state == 0:
  50.             _matches = [cmd_name for cmd_name in _ALL_COMMAND_NAMES if cmd_name.startswith(text)]
  51.         if 0 <= state < len(_matches):
  52.             return _matches[state]
  53.         return None
  54.     readline.set_completer(_completer)
  55.     readline.parse_and_bind('tab: complete')
  56. def _load_model_tokenizer(args):
  57.     tokenizer = AutoTokenizer.from_pretrained(
  58.         args.checkpoint_path, resume_download=True,
  59.     )
  60.     if args.cpu_only:
  61.         device_map = "cpu"
  62.     else:
  63.         device_map = "auto"
  64.     model = AutoModelForCausalLM.from_pretrained(
  65.         args.checkpoint_path,
  66.         torch_dtype="auto",
  67.         device_map=device_map,
  68.         resume_download=True,
  69.     ).eval()
  70.     model.generation_config.max_new_tokens = 2048    # For chat.
  71.     return model, tokenizer
  72. def _gc():
  73.     import gc
  74.     gc.collect()
  75.     if torch.cuda.is_available():
  76.         torch.cuda.empty_cache()
  77. def _clear_screen():
  78.     if platform.system() == "Windows":
  79.         os.system("cls")
  80.     else:
  81.         os.system("clear")
  82. def _print_history(history):
  83.     terminal_width = shutil.get_terminal_size()[0]
  84.     print(f'History ({len(history)})'.center(terminal_width, '='))
  85.     for index, (query, response) in enumerate(history):
  86.         print(f'User[{index}]: {query}')
  87.         print(f'QWen[{index}]: {response}')
  88.     print('=' * terminal_width)
  89. def _get_input() -> str:
  90.     while True:
  91.         try:
  92.             message = input('User> ').strip()
  93.         except UnicodeDecodeError:
  94.             print('[ERROR] Encoding error in input')
  95.             continue
  96.         except KeyboardInterrupt:
  97.             exit(1)
  98.         if message:
  99.             return message
  100.         print('[ERROR] Query is empty')
  101. def _chat_stream(model, tokenizer, query, history):
  102.     conversation = [
  103.         {'role': 'system', 'content': 'You are a helpful assistant.'},
  104.     ]
  105.     for query_h, response_h in history:
  106.         conversation.append({'role': 'user', 'content': query_h})
  107.         conversation.append({'role': 'assistant', 'content': response_h})
  108.     conversation.append({'role': 'user', 'content': query})
  109.     inputs = tokenizer.apply_chat_template(
  110.         conversation,
  111.         add_generation_prompt=True,
  112.         return_tensors='pt',
  113.     )
  114.     inputs = inputs.to(model.device)
  115.     streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True)
  116.     generation_kwargs = dict(
  117.         input_ids=inputs,
  118.         streamer=streamer,
  119.     )
  120.     thread = Thread(target=model.generate, kwargs=generation_kwargs)
  121.     thread.start()
  122.     for new_text in streamer:
  123.         yield new_text
  124. def main():
  125.     parser = argparse.ArgumentParser(
  126.         description='QWen2-Instruct command-line interactive chat demo.')
  127.     parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
  128.                         help="Checkpoint name or path, default to %(default)r")
  129.     parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
  130.     parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
  131.     args = parser.parse_args()
  132.     history, response = [], ''
  133.     model, tokenizer = _load_model_tokenizer(args)
  134.     orig_gen_config = deepcopy(model.generation_config)
  135.     _setup_readline()
  136.     _clear_screen()
  137.     print(_WELCOME_MSG)
  138.     seed = args.seed
  139.     while True:
  140.         query = _get_input()
  141.         # Process commands.
  142.         if query.startswith(':'):
  143.             command_words = query[1:].strip().split()
  144.             if not command_words:
  145.                 command = ''
  146.             else:
  147.                 command = command_words[0]
  148.             if command in ['exit', 'quit', 'q']:
  149.                 break
  150.             elif command in ['clear', 'cl']:
  151.                 _clear_screen()
  152.                 print(_WELCOME_MSG)
  153.                 _gc()
  154.                 continue
  155.             elif command in ['clear-history', 'clh']:
  156.                 print(f'[INFO] All {len(history)} history cleared')
  157.                 history.clear()
  158.                 _gc()
  159.                 continue
  160.             elif command in ['help', 'h']:
  161.                 print(_HELP_MSG)
  162.                 continue
  163.             elif command in ['history', 'his']:
  164.                 _print_history(history)
  165.                 continue
  166.             elif command in ['seed']:
  167.                 if len(command_words) == 1:
  168.                     print(f'[INFO] Current random seed: {seed}')
  169.                     continue
  170.                 else:
  171.                     new_seed_s = command_words[1]
  172.                     try:
  173.                         new_seed = int(new_seed_s)
  174.                     except ValueError:
  175.                         print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
  176.                     else:
  177.                         print(f'[INFO] Random seed changed to {new_seed}')
  178.                         seed = new_seed
  179.                     continue
  180.             elif command in ['conf']:
  181.                 if len(command_words) == 1:
  182.                     print(model.generation_config)
  183.                 else:
  184.                     for key_value_pairs_str in command_words[1:]:
  185.                         eq_idx = key_value_pairs_str.find('=')
  186.                         if eq_idx == -1:
  187.                             print('[WARNING] format: <key>=<value>')
  188.                             continue
  189.                         conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
  190.                         try:
  191.                             conf_value = eval(conf_value_str)
  192.                         except Exception as e:
  193.                             print(e)
  194.                             continue
  195.                         else:
  196.                             print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
  197.                             setattr(model.generation_config, conf_key, conf_value)
  198.                 continue
  199.             elif command in ['reset-conf']:
  200.                 print('[INFO] Reset generation config')
  201.                 model.generation_config = deepcopy(orig_gen_config)
  202.                 print(model.generation_config)
  203.                 continue
  204.             else:
  205.                 # As normal query.
  206.                 pass
  207.         # Run chat.
  208.         set_seed(seed)
  209.         _clear_screen()
  210.         print(f"\nUser: {query}")
  211.         print(f"\nQwen2-Instruct: ", end="")
  212.         try:
  213.             partial_text = ''
  214.             for new_text in _chat_stream(model, tokenizer, query, history):
  215.                 print(new_text, end='', flush=True)
  216.                 partial_text += new_text
  217.             response = partial_text
  218.             print()
  219.         except KeyboardInterrupt:
  220.             print('[WARNING] Generation interrupted')
  221.             continue
  222.         history.append((query, response))
  223. if __name__ == "__main__":
  224.     main()
复制代码
web_demo.py:
同上,接纳modelscope取代transformers饮用AutoModelForCausalLM, AutoTokenizer,解决模型下载问题
输入参数:参加-g,指定运行的GPU
[code]# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple web interactive chat demo based on gradio."""

from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
from transformers import TextIteratorStreamer

DEFAULT_CKPT_PATH = 'Qwen/Qwen2-7B-Instruct'


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")

    parser.add_argument("--share", action="store_true", default=False,
                        help="Create a publicly shareable link for the interface.")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=18003,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")
    parser.add_argument("-g","--gpus",type=str,default="auto",help="set gpu numbers")

    args = parser.parse_args()
    return args


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    elif args.gpus=="auto":
        device_map = args.gpus
    else:
        device_map = "cuda:"+args.gpus

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map=device_map,
        resume_download=True,
    ).eval()
    model.generation_config.max_new_tokens = 2048   # For chat.

    return model, tokenizer


def _chat_stream(model, tokenizer, query, history):
    conversation = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
    ]
    for query_h, response_h in history:
        conversation.append({'role': 'user', 'content': query_h})
        conversation.append({'role': 'assistant', 'content': response_h})
    conversation.append({'role': 'user', 'content': query})
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors='pt',
    )
    inputs = inputs.to(model.device)
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text


def _gc():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _launch_demo(args, model, tokenizer):

    def predict(_query, _chatbot, _task_history):
        print(f"User: {_query}")
        _chatbot.append((_query, ""))
        full_response = ""
        response = ""
        for new_text in _chat_stream(model, tokenizer, _query, history=_task_history):
            response += new_text
            _chatbot[-1] = (_query, response)

            yield _chatbot
            full_response = response

        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))
        print(f"Qwen2-Instruct: {full_response}")

    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        _gc()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/logo_qwen2.png" style="height: 80px"/><p>""")
        gr.Markdown("""<center><font size=8>Qwen2 Chat Bot</center>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen2-Instruct, developed by Alibaba Cloud. \
(本WebUI基于Qwen2-Instruct打造,实现聊天机器人功能。)</center>""")
        gr.Markdown("""\
<center><font size=4>
Qwen2-7B-Instruct <a href="https://modelscope.cn/models/qwen/Qwen2-7B-Instruct/summary">
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

没腿的鸟

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表