Multimodal LLMs for Object Detection in Practice: Possibly the Only Shikra Deployment and Testing Tutorial on the Web, Including ...

Paper:
Shikra: Unleashing Multimodal LLM’s Referential Dialogue Magic
Code:
https://github.com/shikras/shikra
Model:
https://huggingface.co/shikras/shikra-7b-delta-v1
https://huggingface.co/shikras/shikra7b-delta-v1-0708
The first is the checkpoint used in the paper; the second receives ongoing updates.
My own Shikra paper walkthrough goes through it line by line and is very detailed:
Multimodal LLM Object Detection, Close Reading: Shikra
Deployment:

  • Download the GitHub project and the shikras model weights; note that you also need the LLaMA-7B weights (a download sketch follows the environment setup below).
  • Create the environment:

conda create -n shikra python=3.10
conda activate shikra
pip install -r requirements.txt
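For reference, a minimal sketch of these downloads, assuming a recent huggingface_hub CLI is installed; the --local-dir paths are placeholders you can change freely:

git clone https://github.com/shikras/shikra.git
cd shikra
# Shikra delta weights (the paper version linked above; merged with LLaMA-7B in a later step)
huggingface-cli download shikras/shikra-7b-delta-v1 --local-dir ./shikra-7b-delta-v1
# LLaMA-7B base weights (here the huggyllama substitute discussed in the merge step below)
huggingface-cli download huggyllama/llama-7b --local-dir ./llama-7b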
When I later ran the code it complained about missing packages, so I additionally pip installed the following (your environment may differ):

pip install uvicorn
pip install mmengine
After that you may still hit this error:

File "/usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py", line 171, in <module>
    LayerId = cv2.dnn.DictValue
AttributeError: module 'cv2.dnn' has no attribute 'DictValue'
Solution:
Edit /usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py
and comment out the line LayerId = cv2.dnn.DictValue.
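The same fix as a one-liner, in case you prefer not to open an editor; a sketch only, since the dist-packages path may differ in your environment:

sed -i 's/^\( *\)LayerId = cv2\.dnn\.DictValue/\1# LayerId = cv2.dnn.DictValue/' \
    /usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py   # comments out the offending line in place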

  • Download and merge the weights
    The model weights released by the Shikra authors are deltas and only become usable after being merged with LLaMA-1-7B. LLaMA-1, however, requires an access application, which is a hassle; I found a drop-in substitute on Hugging Face (this step took me quite a while QwQ):
    https://huggingface.co/huggyllama/llama-7b
    Download it yourself, then run the merge script provided in the repo:

python mllm/models/shikra/apply_delta.py \
    --base /path/to/llama-7b \
    --target /output/path/to/shikra-7b-merge \
    --delta shikras/shikra-7b-delta-v1
This yields the usable model weights shikra-7b-merge.
Note that you need to change the model path in the config inside the weights folder to the merged version.
You also need to download the CLIP weights:
https://huggingface.co/openai/clip-vit-large-patch14
The code and config files reference /openai/clip-vit-large-patch14 in several places; change these to your local copy. If you skip the pre-download, the weights should be fetched automatically at runtime, so choose whichever suits your network situation.
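One way to do the path replacement in bulk; this is only a sketch, and /path/to/clip-vit-large-patch14 and /output/path/to/shikra-7b-merge are placeholders for wherever you put the downloads:

# List every hard-coded reference under the repo
grep -rn "openai/clip-vit-large-patch14" . --include="*.py" --include="*.json"
# Rewrite the reference in the merged checkpoint's config (keeps a .bak backup)
sed -i.bak 's#openai/clip-vit-large-patch14#/path/to/clip-vit-large-patch14#g' \
    /output/path/to/shikra-7b-merge/config.json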

  • My own demo script for testing the model from the command line, mainly so that I don't have to deal with Gradio, FastAPI, and the like; a usage example follows the script.
import argparse
import os
import sys
import base64
import logging
import time
from pathlib import Path
from io import BytesIO
import torch
import uvicorn
import transformers
from PIL import Image
from mmengine import Config
from transformers import BitsAndBytesConfig

# Make the repo root importable (the script is expected to live three levels below it)
sys.path.append(str(Path(__file__).parent.parent.parent))

from mllm.dataset.process_function import PlainBoxFormatter
from mllm.dataset.builder import prepare_interactive
from mllm.models.builder.build_shikra import load_pretrained_shikra
from mllm.dataset.utils.transform import expand2square, box_xyxy_expand2square

# Set up logging
log_level = logging.DEBUG
transformers.logging.set_verbosity(log_level)
transformers.logging.enable_default_handler()
transformers.logging.enable_explicit_format()

# Argument parsing
parser = argparse.ArgumentParser("Shikra Local Demo")
parser.add_argument('--model_path', default="xxx/shikra-merge", help="Path to the model")
parser.add_argument('--load_in_8bit', action='store_true', help="Load model in 8-bit precision")
parser.add_argument('--image_path', default="xxx/shikra-main/mllm/demo/assets/ball.jpg", help="Path to the image file")
parser.add_argument('--text', default="What do you see in this image? Please mention the objects and their locations using the format [x1,y1,x2,y2].", help="Text prompt")
parser.add_argument('--boxes_value', nargs='+', type=int, default=[], help="Bounding box values (x1, y1, x2, y2)")
parser.add_argument('--boxes_seq', nargs='+', type=int, default=[], help="Sequence of bounding boxes")
parser.add_argument('--do_sample', action='store_true', help="Use sampling during generation")
parser.add_argument('--max_length', type=int, default=512, help="Maximum length of the output")
parser.add_argument('--top_p', type=float, default=1.0, help="Top-p value for sampling")
parser.add_argument('--temperature', type=float, default=1.0, help="Temperature for sampling")
args = parser.parse_args()

model_name_or_path = args.model_path

# Model initialization
model_args = Config(dict(
    type='shikra',
    version='v1',
    # checkpoint config
    cache_dir=None,
    model_name_or_path=model_name_or_path,
    vision_tower=r'xxx/clip-vit-large-patch14',
    pretrain_mm_mlp_adapter=None,
    # model config
    mm_vision_select_layer=-2,
    model_max_length=2048,
    # finetune config
    freeze_backbone=False,
    tune_mm_mlp_adapter=False,
    freeze_mm_mlp_adapter=False,
    # data process config
    is_multimodal=True,
    sep_image_conv_front=False,
    image_token_len=256,
    mm_use_im_start_end=True,
    target_processor=dict(
        boxes=dict(type='PlainBoxFormatter'),
    ),
    process_func_args=dict(
        conv=dict(type='ShikraConvProcess'),
        target=dict(type='BoxFormatProcess'),
        text=dict(type='ShikraTextProcess'),
        image=dict(type='ShikraImageProcessor'),
    ),
    conv_args=dict(
        conv_template='vicuna_v1.1',
        transforms=dict(type='Expand2square'),
        tokenize_kwargs=dict(truncation_size=None),
    ),
    gen_kwargs_set_pad_token_id=True,
    gen_kwargs_set_bos_token_id=True,
    gen_kwargs_set_eos_token_id=True,
))
training_args = Config(dict(
    bf16=False,
    fp16=True,
    device='cuda',
    fsdp=None,
))
quantization_kwargs = dict(
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=args.load_in_8bit,
    )
) if args.load_in_8bit else dict()

model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs)

# Convert the model and vision tower to float16
if not getattr(model, 'is_quantized', False):
    model.to(dtype=torch.float16, device=torch.device('cuda'))
if not getattr(model.model.vision_tower[0], 'is_quantized', False):
    model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda'))

preprocessor['target'] = {'boxes': PlainBoxFormatter()}
tokenizer = preprocessor['text']

# Load and preprocess the image
pil_image = Image.open(args.image_path).convert("RGB")
ds = prepare_interactive(model_args, preprocessor)
image = expand2square(pil_image)
# Group --boxes_value into (x1, y1, x2, y2) tuples (4 values per box) and map them
# into the padded-square coordinate space produced by expand2square
boxes_value = [
    box_xyxy_expand2square(box, w=pil_image.width, h=pil_image.height)
    for box in zip(args.boxes_value[0::4], args.boxes_value[1::4],
                   args.boxes_value[2::4], args.boxes_value[3::4])
]
ds.set_image(image)
ds.append_message(role=ds.roles[0], message=args.text, boxes=boxes_value, boxes_seq=args.boxes_seq)
model_inputs = ds.to_model_input()
model_inputs['images'] = model_inputs['images'].to(torch.float16)

# Generate
gen_kwargs = dict(
    use_cache=True,
    do_sample=args.do_sample,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=args.max_length,
    top_p=args.top_p,
    temperature=args.temperature,
)
input_ids = model_inputs['input_ids']
st_time = time.time()
with torch.inference_mode():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output_ids = model.generate(**model_inputs, **gen_kwargs)
print(f"Generated in {time.time() - st_time} seconds")
input_token_len = input_ids.shape[-1]
response = tokenizer.batch_decode(output_ids[:, input_token_len:])[0]
print(f"Response: {response}")
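A usage sketch: here the script is assumed to be saved as mllm/demo/cli_demo.py (the file name is arbitrary, but it should sit three levels below the repo root, e.g. under mllm/demo/, so that the sys.path line at the top resolves to the repo root), and the model path is a placeholder for your merged weights:

# Paths are placeholders; --load_in_8bit is optional and requires bitsandbytes
python mllm/demo/cli_demo.py \
    --model_path /output/path/to/shikra-7b-merge \
    --image_path mllm/demo/assets/ball.jpg \
    --text "What do you see in this image? Please mention the objects and their locations using the format [x1,y1,x2,y2]." \
    --load_in_8bit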
If you found this helpful, please give me a follow; I will keep posting content on multimodal LLMs.
