关于 c10::Half 类型和float不匹配

打印 上一主题 下一主题

主题 904|帖子 904|积分 2712

相关错误

  1. # error-1 ; (all-no-half) self-attn RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::Half
  2. # error-2 : (embed-half) self-attn RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
  3. # error-3 : model(half) embed(no-half) conv RuntimeError: Input type (float) and bias type (c10::Half) should be the same
  4. # error-4 : model(half) embed(half) conv RuntimeError: Input type (float) and bias type (c10::Half) should be the same
  5. # cuda + half all : RuntimeError: Input type (float) and bias type (c10::Half) should be the same
复制代码
我在跑大模子推理的时间,碰到了上面的错误。
起首有一个题目需要考虑:

  • 我希望模子以半精度的方式推理,所以在from_pretrained的时间,是以float16的方式加载的
  1. self.llama_model = LlamaForCausalLM.from_pretrained(
  2.                 args.llama_model,  torch_dtype=torch.float16, )
复制代码

  • 我希望模子可以在gpu上面推理,但是我默认了模子会自动加载到gpu上面。。。
办理方法


  • 检查llama模子是不是正确加载到gpu,一半出现 c10:Half 这个类型,模子很大概率是加载到CPU上面去推理的,所以只要修改到gpu上就不会报错了
  • 模子推理的时间,记得加上autocase
  1. with torch.cuda.amp.autocast():
  2. ....
复制代码
末了代码

由于是修改R2genGPT的,所以代码如下:
  1. class Generator:
  2.     def __init__(self):
  3.         pass
  4.     def generate(self, input_conv, img_list):
  5.         raise NotImplementedError
  6.    
  7.    
  8. class R2genGPT_shallow(Generator):
  9.     def __init__(self):
  10.         super().__init__()
  11.         args = parser.parse_args()
  12.         # args.precision = "fp16"
  13.         args.delta_file = "../checkpoints/R2genGPT/shallow_checkpoint_step14102.pth"
  14.         args.vision_model = "microsoft/swin-base-patch4-window7-224"
  15.         args.llama_model = "../checkpoints/Llama-2-7b-chat-hf"
  16.         self.filed_parser = FieldParser(args)
  17.         self.model = R2GenGPT(args)
  18.         self.model.eval()  
  19.         self.model.cuda()
  20.         print("device : ", self.model.device)
  21.     def adapt(self, query):
  22.         query = query.replace("<image>", " ")
  23.         return query
  24.     def get_image_tensor(self, img_file):
  25.         with Image.open(img_file) as pil:
  26.             array = np.array(pil, dtype=np.uint8)
  27.             if array.shape[-1] != 3 or len(array.shape) != 3:
  28.                 array = np.array(pil.convert("RGB"), dtype=np.uint8)
  29.             image = self.filed_parser._parse_image(array)
  30.             image = image.to(self.model.device)
  31.         return image
  32.    
  33.     def generate(self, query, img_list):
  34.         self.model.llama_tokenizer.padding_side = "right"
  35.         images = []
  36.         for img_file in img_list:
  37.             image = self.get_image_tensor(img_file)
  38.             images.append(image.unsqueeze(0))
  39.         
  40.         self.model.prompt = self.adapt(query)
  41.         img_embeds, atts_img = self.model.encode_img(images)
  42.         img_embeds = self.model.layer_norm(img_embeds)
  43.         img_embeds, atts_img = self.model.prompt_wrap(img_embeds, atts_img)
  44.         batch_size = img_embeds.shape[0]
  45.         bos = torch.ones([batch_size, 1],
  46.                          dtype=atts_img.dtype,
  47.                          device=atts_img.device) * self.model.llama_tokenizer.bos_token_id
  48.         bos_embeds = self.model.embed_tokens(bos)
  49.         atts_bos = atts_img[:, :1]
  50.         inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
  51.         attention_mask = torch.cat([atts_bos, atts_img], dim=1)
  52.         with torch.inference_mode():
  53.             with torch.cuda.amp.autocast():
  54.                 outputs = self.model.llama_model.generate(
  55.                     inputs_embeds=inputs_embeds,
  56.                     num_beams=self.model.hparams.beam_size,
  57.                     do_sample=self.model.hparams.do_sample,
  58.                     min_new_tokens=self.model.hparams.min_new_tokens,
  59.                     max_new_tokens=self.model.hparams.max_new_tokens,
  60.                     repetition_penalty=self.model.hparams.repetition_penalty,
  61.                     length_penalty=self.model.hparams.length_penalty,
  62.                     temperature=self.model.hparams.temperature,
  63.                 )
  64.             
  65.         answer = self.model.decode(outputs[0])
  66.         return answer
  67.    
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

科技颠覆者

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表