【深度学习】OCR中的Shrink操纵详解

打印 上一主题 下一主题

主题 1712|帖子 1712|积分 5136

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
OCR中的Shrink操纵详解

在光学字符辨认(OCR)中,shrink操纵用于对文本框多边形进行缩放,以生成用于训练和检测的特征图。本文将介绍shrink操纵的背景、实现方法及其应用。以下是用户提供的代码,详细展示了怎样实现这一过程。
背景介绍

在OCR任务中,文本通常以多边形的形式标注于图像中。为了更好地训练检测模型,通常需要将这些多边形进行肯定比例的缩放(shrink),以生成不同大小的特征图,从而提高模型的泛化能力和精度。shrink操纵的目标是将文本框缩小,以淘汰噪声对检测结果的影响。
代码实现

以下是实现shrink操纵的详细代码:
  1. import numpy as np
  2. import cv2
  3. import pyclipper
  4. from shapely.geometry import Polygon
  5. def shrink_polygon_py(polygon, shrink_ratio):
  6.     """
  7.     对框进行缩放,返回去的比例为1/shrink_ratio 即可
  8.     """
  9.     cx = polygon[:, 0].mean()
  10.     cy = polygon[:, 1].mean()
  11.     polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
  12.     polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
  13.     return polygon
  14. def shrink_polygon_pyclipper(polygon, shrink_ratio):
  15.     polygon_shape = Polygon(polygon)
  16.     distance = (
  17.         polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
  18.     )
  19.     subject = [tuple(l) for l in polygon]
  20.     padding = pyclipper.PyclipperOffset()
  21.     padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  22.     shrinked = padding.Execute(-distance)
  23.     if shrinked == []:
  24.         shrinked = np.array(shrinked)
  25.     else:
  26.         shrinked = np.array(shrinked[0]).reshape(-1, 2)
  27.     return shrinked
  28. class MakeShrinkMap:
  29.     def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type="pyclipper"):
  30.         shrink_func_dict = {
  31.             "py": shrink_polygon_py,
  32.             "pyclipper": shrink_polygon_pyclipper,
  33.         }
  34.         self.shrink_func = shrink_func_dict[shrink_type]
  35.         self.min_text_size = min_text_size
  36.         self.shrink_ratio = shrink_ratio
  37.     def __call__(self, data: dict) -> dict:
  38.         image = data["img"]
  39.         text_polys = data["text_polys"]
  40.         ignore_tags = data["ignore_tags"]
  41.         h, w = image.shape[:2]
  42.         text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
  43.         gt = np.zeros((h, w), dtype=np.float32)
  44.         mask = np.ones((h, w), dtype=np.float32)
  45.         shrinked_polygons = []
  46.         for i in range(len(text_polys)):
  47.             polygon = text_polys[i]
  48.             height = max(polygon[:, 1]) - min(polygon[:, 1])
  49.             width = max(polygon[:, 0]) - min(polygon[:, 0])
  50.             if ignore_tags[i] or min(height, width) < self.min_text_size:
  51.                 cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
  52.                 ignore_tags[i] = True
  53.             else:
  54.                 shrinked = self.shrink_func(polygon, self.shrink_ratio)
  55.                 shrinked_polygons.append(shrinked)
  56.                 if shrinked.size == 0:
  57.                     cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
  58.                     ignore_tags[i] = True
  59.                     continue
  60.                 cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
  61.         data["shrink_map"] = gt
  62.         data["shrink_mask"] = mask
  63.         data["shrinked_polygons"] = shrinked_polygons
  64.         return data
  65.     def validate_polygons(self, polygons, ignore_tags, h, w):
  66.         if len(polygons) == 0:
  67.             return polygons, ignore_tags
  68.         assert len(polygons) == len(ignore_tags)
  69.         for polygon in polygons:
  70.             polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
  71.             polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
  72.         for i in range(len(polygons)):
  73.             area = self.polygon_area(polygons[i])
  74.             if abs(area) < 1:
  75.                 ignore_tags[i] = True
  76.             if area > 0:
  77.                 polygons[i] = polygons[i][::-1, :]
  78.         return polygons, ignore_tags
  79.     def polygon_area(self, polygon):
  80.         return cv2.contourArea(polygon)
  81. if __name__ == "__main__":
  82.     # 示例图像
  83.     image = np.ones((200, 200, 3), dtype=np.uint8) * 255
  84.     # 示例文本框多边形
  85.     text_polys = [
  86.         np.array([[50, 50], [150, 50], [150, 100], [50, 100]]),
  87.         np.array([[60, 120], [140, 120], [140, 160], [60, 160]])
  88.     ]
  89.     # 示例忽略标志
  90.     ignore_tags = [False, False]
  91.     # 构建输入数据字典
  92.     data = {
  93.         "img": image,
  94.         "text_polys": text_polys,
  95.         "ignore_tags": ignore_tags
  96.     }
  97.     # 初始化 MakeShrinkMap 类
  98.     make_shrink_map = MakeShrinkMap(min_text_size=8, shrink_ratio=0.4, shrink_type="pyclipper")
  99.     # 调用类处理数据
  100.     result = make_shrink_map(data)
  101.     # 获取生成的shrink_map和shrink_mask
  102.     shrink_map = result["shrink_map"]
  103.     shrink_mask = result["shrink_mask"]
  104.     shrinked_polygons = result["shrinked_polygons"]
  105.     # 在原图上绘制shrink前的多边形
  106.     original_image = image.copy()
  107.     for polygon in text_polys:
  108.         cv2.polylines(original_image, [polygon.astype(np.int32)], True, (0, 0, 255), 2)
  109.     # 在原图上绘制shrink后的多边形
  110.     shrinked_image = image.copy()
  111.     for polygon in shrinked_polygons:
  112.         cv2.polylines(shrinked_image, [polygon.astype(np.int32)], True, (0, 255, 0), 2)
  113.     # 保存结果图像
  114.     cv2.imwrite("original_image.png", original_image)
  115.     cv2.imwrite("shrinked_image.png", shrinked_image)
  116.     cv2.imwrite("shrink_map.png", shrink_map * 255)  # 将shrink_map转换为图像
  117.     cv2.imwrite("shrink_mask.png", shrink_mask * 255)  # 将shrink_mask转换为图像
  118.     # 显示结果
  119.     # cv2.imshow("Original Image", original_image)
  120.     # cv2.imshow("Shrinked Image", shrinked_image)
  121.     # cv2.imshow("Shrink Map", shrink_map)
  122.     # cv2.imshow("Shrink Mask", shrink_mask)
  123.     # cv2.waitKey(0)
  124.     # cv2.destroyAllWindows()
复制代码
代码详解


  • Shrink算法实现
    代码中实现了两种不同的shrink算法:shrink_polygon_py和shrink_polygon_pyclipper。

    • shrink_polygon_py:通过计算多边形的中央点,然后将多边形的每个点按照缩放比例向中央点收缩。
    • shrink_polygon_pyclipper:利用pyclipper库进行多边形缩放,计算更为精确,实用于复杂多边形。

  • MakeShrinkMap类
    MakeShrinkMap类用于将图像中的文本多边形进行shrink操纵。类的构造函数接受最小文本尺寸、缩放比例和缩放类型作为参数。__call__方法处理输入数据字典,并生成缩放后的特征图和掩码。
  • 代码示例
    在__main__部分,创建了一个示例图像和文本多边形,并利用MakeShrinkMap类进行shrink操纵。结果图像包括原始多边形和缩放后的多边形,并将生成的特征图和掩码生存为图像文件。
应用

Shrink操纵在OCR中有广泛的应用,如:


  • 文本检测:通过缩放文本框生成特征图,可以提高文本检测模型的准确性和鲁棒性。
  • 噪声过滤:缩小多边形可以淘汰背景噪声对检测结果的干扰。
  • 数据增强:生成不同缩放比例的特征图,有助于提升模型的泛化能力。
总结

本文介绍了OCR中shrink操纵的实现方法和应用,通过详细的代码示例展示了怎样对文本多边形进行缩放。shrink操纵在提高OCR模型性能方面具有重要作用,是文本检测和辨认过程中不可或缺的一

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

正序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

不到断气不罢休

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表