OpenCV转pytorch

[复制链接]
发表于 2025-12-23 19:51:28 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
将 OpenCV 的一些操作转换为 PyTorch 实现,从而便于利用 GPU 加速,甚至导出 ONNX 并转换为 TensorRT。
需要注意:OpenCV 的输入是 numpy 数组,格式为 HW 的 2D 张量或 HWC 的 3D 张量;而 PyTorch 一般使用 NCHW 的 4D 张量或 CHW 的 3D 张量。

Dilation腐蚀与膨胀

12: 腐蚀与膨胀 | 陌上见花开
https://blog.51cto.com/u_16175442/8629546
  1. import cv2
  2. import torch.nn.functional as F
  3. def dilate_cv(img, dilate_factor=10):
  4.     """
  5.     input img is np 2D, HWC 3D
  6.     """
  7.     img = img.astype(np.uint8)
  8.     img1 = cv2.dilate(
  9.         img,
  10.         np.ones((dilate_factor, dilate_factor), np.uint8),
  11.         iterations=1
  12.     )
  13.     return img1
  14. def dilate_torch(img, dilate_factor=10):
  15.     """
  16.     input img should be 3D CHW, or 4D NCHW
  17.     """
  18.     h, w = img.shape[-2:]
  19.     img1 = F.max_pool2d(img, kernel_size=dilate_factor, stride=1, padding=dilate_factor//2)
  20.     if dilate_factor % 2 == 0:
  21.         img1 = img1[:, :, :h, :w]
  22.     return img1
复制代码
贴一个DeepSeek转换的版本
  1. import torch
  2. import numpy as np
  3. def dilate_torch(img, dilate_factor=10):
  4.     """
  5.     input img is np 2D array
  6.     """
  7.     # 转换为PyTorch张量并添加batch和channel维度
  8.     img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
  9.    
  10.     # 创建最大池化层实现膨胀效果
  11.     kernel_size = dilate_factor
  12.     padding = dilate_factor // 2  # 保持输出尺寸与输入一致
  13.     max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=padding)
  14.    
  15.     # 应用池化操作
  16.     dilated_tensor = max_pool(img_tensor)
  17.    
  18.     # 转换回numpy数组并恢复原始维度
  19.     dilated_img = dilated_tensor.squeeze().cpu().numpy()
  20.    
  21.     return dilated_img.astype(np.uint8)
复制代码
Resize

  1. import cv2
  2. from torchvision.transforms.functional import resize
  3. from torchvision.transforms import InterpolationMode
  4. img_cv = cv2.resize(img_hwc, (scale*W, scale*H), interpolation=cv2.INTER_NEAREST)
  5. img_torch = resize(img_chw, (scale*H, scale*W), interpolation=InterpolationMode.NEAREST)
复制代码
需要注意的是,由于像素对齐(align)方式不同,OpenCV 的 resize 与 torch 的 resize 结果并不完全一致。另外,cv2.resize 的尺寸参数顺序是 (宽, 高),而 torchvision 的 resize 是 (高, 宽)。
颜色转换

  1. bgr_cv = cv2.cvtColor(data_np, cv2.COLOR_RGB2BGR)
  2. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  3. def bgr2rgb_torch_nchw(bgr_nchw):
  4.     b, g, r = bgr_nchw.split(split_size=1, dim=-3)
  5.     rgb = torch.cat([r, g, b], dim=-3).numpy()
  6.     return rgb
  7. def rgb2bgr_torch_nchw(rgb_nchw):
  8.     r, g, b = rgb_nchw.split(split_size=1, dim=-3)
  9.     bgr = torch.cat([b, g, r], dim=-3)
  10.     return bgr
复制代码
Blur

  1. import torch
  2. import numpy as np
  3. import cv2
  4. img_hwc = np.random.rand(*[256, 256, 3]).astype("float32")
  5. img_chw = img_hwc.transpose([2, 0, 1])
  6. img_chw_tc = torch.from_numpy(img_chw)
  7. kernel_size = 3
  8. img_blur_cv = cv2.blur(img_hwc, (kernel_size, kernel_size))
  9. img_blur_cv_chw = img_blur_cv.transpose([2, 0, 1])
  10. def mean_blur_torch(img_chw, kernel_size):
  11.     device = img_chw.device
  12.     dtype = img_chw.dtype
  13.     pad_l = kernel_size // 2
  14.     pad_r = kernel_size // 2
  15.     if kernel_size % 2 == 0:
  16.         pad_r = pad_r-1
  17.     img_chw1 = torch.nn.functional.pad(img_chw, pad=[pad_l, pad_r, pad_l, pad_r], mode='reflect')
  18.     weight = torch.ones(*(3, 1, kernel_size, kernel_size), dtype=dtype, device=device)/kernel_size/kernel_size
  19.     img_blur_chw = torch.nn.functional.conv2d(img_chw1, weight, padding=0, groups=3)
  20.     return img_blur_chw
  21. img_blur_torch_chw = mean_blur_torch(img_chw_tc, kernel_size)
  22. img_blur_torch_chw = img_blur_torch_chw.numpy()
  23. error = np.abs(img_blur_cv_chw - img_blur_torch_chw)
  24. print("error:", np.max(error), np.mean(error))
复制代码
  1. def to_tensor_torch(tensor):
  2.     # use torch tensor but not numpy as input
  3.     # hwc to chw / nhwc to nchw transpose, dtype conversion, rescale
  4.     if len(tensor.shape) == 2:
  5.         tensor = tensor.unsqueeze(dim=-1)
  6.     # transpose
  7.     if len(tensor.shape) == 3:
  8.         tensor = tensor.permute([2, 0, 1])
  9.     elif len(tensor.shape) == 4:
  10.         tensor = tensor.permute([0, 3, 1, 2])
  11.     else:
  12.         raise ValueError("unsupported")
  13.     is_uint8 = tensor.dtype == torch.uint8
  14.     # dtype conversion
  15.     tensor = tensor.to(torch.float32)
  16.     # rescale
  17.     if is_uint8:
  18.         tensor = tensor / 255.0
  19.     return tensor
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!qidao123.com:ToB企服之家,中国第一个企服评测及软件市场,开放入驻,技术点评得现金
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表