Paper:
Shikra: Unleashing Multimodal LLM's Referential Dialogue Magic
Code:
https://github.com/shikras/shikra
Models:
https://huggingface.co/shikras/shikra-7b-delta-v1
https://huggingface.co/shikras/shikra7b-delta-v1-0708
The first is the checkpoint used in the paper; the second receives ongoing updates.
My own line-by-line walkthrough of the Shikra paper, very detailed:
Multimodal LLM Object Detection, a Close Reading of Shikra
Deployment:
- Download the GitHub project and the shikras model weights; note that you also need the LLaMA-7B weights.
- Create the environment:
- conda create -n shikra python=3.10
- conda activate shikra
- pip install -r requirements.txt
Later, when I ran it, missing packages were reported, so I also pip installed the packages below (your environment may differ):
- pip install uvicorn
- pip install mmengine
Then another error appears:
- File "/usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py", line 171, in <module>
- LayerId = cv2.dnn.DictValue
- AttributeError: module 'cv2.dnn' has no attribute 'DictValue'
Solution:
Edit /usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py and comment out the line LayerId = cv2.dnn.DictValue.
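If you'd rather not edit the file by hand, the same one-line change can be applied with a throwaway Python snippet (the path is the one from the traceback above; adjust it for your environment):

from pathlib import Path

# Comment out the stub line that references the removed cv2.dnn.DictValue.
p = Path("/usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py")
src = p.read_text()
p.write_text(src.replace("LayerId = cv2.dnn.DictValue",
                         "# LayerId = cv2.dnn.DictValue"))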
- Weight download and merging
The officially released shikra weights are deltas that must be merged with llama-1-7b before use. LLaMA-1, however, requires an access application, which is a hassle; I found a drop-in substitute on HF (this step took me a long time QwQ):
https://huggingface.co/huggyllama/llama-7b
Download it yourself, then run the official merge script:
- python mllm/models/shikra/apply_delta.py \
- --base /path/to/llama-7b \
- --target /output/path/to/shikra-7b-merge \
- --delta shikras/shikra-7b-delta-v1
This produces usable model weights, shikra-7b-merge.
Note: change the model path in the config file inside the weights folder to the merged version.
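For intuition about what apply_delta.py is doing: the delta checkpoint stores (finetuned − base) weights, so merging is essentially elementwise addition. Below is a minimal sketch of this LLaVA-style merge; it assumes both checkpoints load via AutoModelForCausalLM, while the real script uses the Shikra model classes and also saves the tokenizer:

import torch
from transformers import AutoModelForCausalLM

# Sketch only: paths are placeholders, and the real apply_delta.py uses
# the Shikra model classes rather than AutoModelForCausalLM.
base = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b", torch_dtype=torch.float16)
delta = AutoModelForCausalLM.from_pretrained("/path/to/shikra-7b-delta-v1", torch_dtype=torch.float16)

base_state = base.state_dict()
for name, param in delta.state_dict().items():
    if name not in base_state:
        continue  # delta-only modules (e.g. the multimodal projector) are kept as-is
    if param.shape == base_state[name].shape:
        param.data += base_state[name]  # base + (finetuned - base) = finetuned
    else:
        # embedding/lm_head matrices grew with new special tokens;
        # add the base weights into the overlapping slice only
        bparam = base_state[name]
        param.data[:bparam.shape[0], :bparam.shape[1]] += bparam

delta.save_pretrained("/output/path/to/shikra-7b-merge")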
You also need to download the CLIP model weights:
https://huggingface.co/openai/clip-vit-large-patch14
The code and config files reference /openai/clip-vit-large-patch14 in several places; change these to your local copy. If you don't download it in advance, it should be fetched automatically at runtime, so choose based on your network situation.
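Before rewiring the paths, a quick way to confirm the local copy loads offline (plain transformers calls; the path is a placeholder for wherever you put the download):

from transformers import CLIPVisionModel, CLIPImageProcessor

# Both should load from the local folder with no network access; if they
# do, pointing the shikra config's vision_tower at this path will work too.
path = "/path/to/clip-vit-large-patch14"
vision_tower = CLIPVisionModel.from_pretrained(path)
processor = CLIPImageProcessor.from_pretrained(path)
print(vision_tower.config.hidden_size)  # 1024 for ViT-L/14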
- My demo script for testing the model from the command line, mainly to avoid pulling in gradio, fastapi, and the like. A usage example follows the code.
import argparse
import logging
import sys
import time
from pathlib import Path

import torch
import transformers
from PIL import Image
from mmengine import Config
from transformers import BitsAndBytesConfig

# Make the repo's mllm package importable when this file lives two
# folders below the repo root (e.g. mllm/demo/).
sys.path.append(str(Path(__file__).parent.parent.parent))
from mllm.dataset.process_function import PlainBoxFormatter
from mllm.dataset.builder import prepare_interactive
from mllm.models.builder.build_shikra import load_pretrained_shikra
from mllm.dataset.utils.transform import expand2square, box_xyxy_expand2square

# Set up logging
log_level = logging.DEBUG
transformers.logging.set_verbosity(log_level)
transformers.logging.enable_default_handler()
transformers.logging.enable_explicit_format()

# Argument parsing; the default prompt asks for COCO-style [x1,y1,x2,y2] boxes
parser = argparse.ArgumentParser("Shikra Local Demo")
parser.add_argument('--model_path', default="xxx/shikra-merge", help="Path to the merged model")
parser.add_argument('--load_in_8bit', action='store_true', help="Load model in 8-bit precision")
parser.add_argument('--image_path', default="xxx/shikra-main/mllm/demo/assets/ball.jpg", help="Path to the image file")
parser.add_argument('--text', default="What do you see in this image? Please mention the objects and their locations using the format [x1,y1,x2,y2].", help="Text prompt")
parser.add_argument('--boxes_value', nargs='+', type=int, default=[], help="Flattened bounding box values (x1 y1 x2 y2 per box)")
parser.add_argument('--boxes_seq', nargs='+', type=int, default=[], help="Sequence of bounding boxes")
parser.add_argument('--do_sample', action='store_true', help="Use sampling during generation")
parser.add_argument('--max_length', type=int, default=512, help="Maximum number of new tokens")
parser.add_argument('--top_p', type=float, default=1.0, help="Top-p value for sampling")
parser.add_argument('--temperature', type=float, default=1.0, help="Temperature for sampling")
args = parser.parse_args()

model_name_or_path = args.model_path

# Model initialization
model_args = Config(dict(
    type='shikra',
    version='v1',
    # checkpoint config
    cache_dir=None,
    model_name_or_path=model_name_or_path,
    vision_tower=r'xxx/clip-vit-large-patch14',
    pretrain_mm_mlp_adapter=None,
    # model config
    mm_vision_select_layer=-2,
    model_max_length=2048,
    # finetune config
    freeze_backbone=False,
    tune_mm_mlp_adapter=False,
    freeze_mm_mlp_adapter=False,
    # data process config
    is_multimodal=True,
    sep_image_conv_front=False,
    image_token_len=256,
    mm_use_im_start_end=True,
    target_processor=dict(
        boxes=dict(type='PlainBoxFormatter'),
    ),
    process_func_args=dict(
        conv=dict(type='ShikraConvProcess'),
        target=dict(type='BoxFormatProcess'),
        text=dict(type='ShikraTextProcess'),
        image=dict(type='ShikraImageProcessor'),
    ),
    conv_args=dict(
        conv_template='vicuna_v1.1',
        transforms=dict(type='Expand2square'),
        tokenize_kwargs=dict(truncation_size=None),
    ),
    gen_kwargs_set_pad_token_id=True,
    gen_kwargs_set_bos_token_id=True,
    gen_kwargs_set_eos_token_id=True,
))
training_args = Config(dict(
    bf16=False,
    fp16=True,
    device='cuda',
    fsdp=None,
))

quantization_kwargs = dict(
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=args.load_in_8bit,
    )
) if args.load_in_8bit else dict()

model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs)

# Convert the model and vision tower to float16 unless they are quantized
if not getattr(model, 'is_quantized', False):
    model.to(dtype=torch.float16, device=torch.device('cuda'))
if not getattr(model.model.vision_tower[0], 'is_quantized', False):
    model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda'))

preprocessor['target'] = {'boxes': PlainBoxFormatter()}
tokenizer = preprocessor['text']

# Load and preprocess the image
pil_image = Image.open(args.image_path).convert("RGB")
ds = prepare_interactive(model_args, preprocessor)
image = expand2square(pil_image)
# Group the flat --boxes_value list into (x1, y1, x2, y2) tuples, then map
# each box into the expand2square coordinate space.
boxes_value = [
    box_xyxy_expand2square(box, w=pil_image.width, h=pil_image.height)
    for box in zip(args.boxes_value[0::4], args.boxes_value[1::4],
                   args.boxes_value[2::4], args.boxes_value[3::4])
]
ds.set_image(image)
ds.append_message(role=ds.roles[0], message=args.text, boxes=boxes_value, boxes_seq=args.boxes_seq)
model_inputs = ds.to_model_input()
model_inputs['images'] = model_inputs['images'].to(torch.float16)

# Generate
gen_kwargs = dict(
    use_cache=True,
    do_sample=args.do_sample,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=args.max_length,
    top_p=args.top_p,
    temperature=args.temperature,
)
input_ids = model_inputs['input_ids']
st_time = time.time()
with torch.inference_mode():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output_ids = model.generate(**model_inputs, **gen_kwargs)
print(f"Generated in {time.time() - st_time} seconds")

# Decode only the newly generated tokens, skipping the prompt
input_token_len = input_ids.shape[-1]
response = tokenizer.batch_decode(output_ids[:, input_token_len:])[0]
print(f"Response: {response}")
If you found this helpful, consider following; I'll keep posting multimodal-LLM content.