ToB企服应用市场:ToB评测及商务社交产业平台

标题: 办理Stable Diffusion TensorRT转换模型报错cpu and cuda:0! (when checkin [打印本页]

作者: 尚未崩坏    时间: 2024-7-24 09:09
标题: 办理Stable Diffusion TensorRT转换模型报错cpu and cuda:0! (when checkin
纪录Stable Diffusion webUI TensorRT插件使用过程的报错:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
拷贝下面的代码覆盖extensions\stable-diffusion-webui-tensorrt里的export_onnx.py文件,将模型和相关的张量移动到GPU,即可办理。
  1. import os
  2. from modules import sd_hijack, sd_unet
  3. from modules import shared, devices
  4. import torch
  5. def export_current_unet_to_onnx(filename, opset_version=17):
  6.     if torch.cuda.is_available():  
  7.        print("CUDA is available")  
  8.     else:  
  9.         print("CUDA is not available")
  10.     device = 'cuda' if torch.cuda.is_available() else 'cpu'  # 根据CUDA是否可用选择设备  
  11.     shared.sd_model.model.diffusion_model.to(device)
  12.     x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
  13.     timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
  14.     context = torch.randn(1, 77, 768).to(devices.device, devices.dtype)
  15.     x = x.to(device)  
  16.     timesteps = timesteps.to(device)  
  17.     context = context.to(device)  
  18.     print(x.device, timesteps.device, context.device)
  19.     def disable_checkpoint(self):
  20.         if getattr(self, 'use_checkpoint', False) == True:
  21.             self.use_checkpoint = False
  22.         if getattr(self, 'checkpoint', False) == True:
  23.             self.checkpoint = False
  24.     shared.sd_model.model.diffusion_model.apply(disable_checkpoint)
  25.     sd_unet.apply_unet("None")
  26.     sd_hijack.model_hijack.apply_optimizations('None')
  27.     os.makedirs(os.path.dirname(filename), exist_ok=True)
  28.     with devices.autocast():
  29.         torch.onnx.export(
  30.             shared.sd_model.model.diffusion_model,
  31.             (x, timesteps, context),
  32.             filename,
  33.             export_params=True,
  34.             opset_version=opset_version,
  35.             do_constant_folding=True,
  36.             input_names=['x', 'timesteps', 'context'],
  37.             output_names=['output'],
  38.             dynamic_axes={
  39.                 'x': {0: 'batch_size', 2: 'height', 3: 'width'},
  40.                 'timesteps': {0: 'batch_size'},
  41.                 'context': {0: 'batch_size', 1: 'sequence_length'},
  42.                 'output': {0: 'batch_size'},
  43.             },
  44.         )
  45.     sd_hijack.model_hijack.apply_optimizations()
  46.     sd_unet.apply_unet()
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4