DB算法原理与构建

打印 上一主题 下一主题

主题 1660|帖子 1660|积分 4980

参考:
https://aistudio.baidu.com/projectdetail/4483048
https://www.bilibili.com/video/BV1xf4y1p7Gf/
Real-Time Scene Text Detection with Differentiable Binarization
如何读论文-by 李沐
DB (Real-Time Scene Text Detection with Differentiable Binarization)
原理

DB是一个基于分割的文本检测算法,其提出的可微分阈值,接纳动态的阈值区分文本区域与背景

基于分割的平凡文本检测算法,流程如上图蓝色箭头所示,得到分割结果后接纳固定的阈值(标准二值化不可微,导致网络无法端到端训练)得到二值化的分割图,之后接纳诸如像素聚类的开导式算法得到文本区域。
DB算法的流程如图中红色箭头所示,最大的差别在于DB有一个阈值图,通过网络去推测图片每个位置处的阈值,threshold map相当于自适应threshold,每个像素点都有一个阈值,而不是接纳一个固定的值,结合segment map与threshold map得到近似二值化特征图

上风:
1.算法结构简单,无需繁琐的后处理
2.开源数据上拥有精良的精度和性能
输入的图像经过网络Backbone和FPN提取特征,提取后的特征级联在一起,得到原图四分之一大小的特征,然后使用卷积层分别得到文本区域推测概率图和阈值图,进而通过DB的后处理得到文本困绕曲线。
FPN-Convolution
deformable 卷积作用是为了增大感受野



如何通过两个map得到近似二值图

k是一个参数,论文里是50
DB算法提出了可微二值化,可微二值化将标准二值化中的阶跃函数举行了近似,使用如下公式举行代替:



白色区域是probalibitymap
黄色区域是threshold map

使用这种设置,是为了可微分

L是周长 A是面积

i,j 是点到红色四条边的距离
归一化



正样本是笔墨区域 负样本是背景区域





九个方位。每个方位x,y以是是18















DB文本检测模型构建

DB文本检测模型可以分为三个部分:
Backbone网络,负责提取图像的特征
FPN网络,特征金字塔结构加强特征
Head网络,计算文本区域概率图
backbone网络:论文中使用了ResNet50,本节实验中,为了加速训练速率,接纳MobileNetV3 large结构作为backbone。
DB的Backbone用于提取图像的多标准特征,如下代码所示,假设输入的形状为[640, 640],backbone网络的输出有四个特征,其形状分别是 [1, 16, 160, 160],[1, 24, 80, 80], [1, 56, 40, 40],[1, 480, 20, 20]。 这些特征将输入给特征金字塔FPN网络进一步的加强特征。
  1. import paddle
  2. from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3
  3. fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
  4. # 1. 声明Backbone
  5. model_backbone = MobileNetV3()
  6. model_backbone.eval()
  7. # 2. 执行预测
  8. outs = model_backbone(fake_inputs)
  9. # 3. 打印网络结构
  10. # print(model_backbone)
  11. # 4. 打印输出特征形状
  12. for idx, out in enumerate(outs):
  13.     print("The index is ", idx, "and the shape of output is ", out.shape)
复制代码
FPN网络
特征金字塔结构FPN是一种卷积网络来高效提取图片中各维度特征的常用方法。
FPN网络的输入为Backbone部分的输出,输出特征图的高度和宽度为原图的四分之一。假设输入图像的形状为[1, 3, 640, 640],FPN输出特征的高度和宽度为[160, 160]
  1. import paddle
  2. from paddle import nn
  3. import paddle.nn.functional as F
  4. from paddle import ParamAttr
  5. class DBFPN(nn.Layer):
  6.     def __init__(self, in_channels, out_channels, **kwargs):
  7.         super(DBFPN, self).__init__()
  8.         self.out_channels = out_channels
  9.         # DBFPN详细实现参考: https://github.com/PaddlePaddle/PaddleOCRblob/release%2F2.4/ppocr/modeling/necks/db_fpn.py
  10.     def forward(self, x):
  11.         c2, c3, c4, c5 = x
  12.         in5 = self.in5_conv(c5)
  13.         in4 = self.in4_conv(c4)
  14.         in3 = self.in3_conv(c3)
  15.         in2 = self.in2_conv(c2)
  16.         # 特征上采样
  17.         out4 = in4 + F.upsample(
  18.             in5, scale_factor=2, mode="nearest", align_mode=1)  # 1/16
  19.         out3 = in3 + F.upsample(
  20.             out4, scale_factor=2, mode="nearest", align_mode=1)  # 1/8
  21.         out2 = in2 + F.upsample(
  22.             out3, scale_factor=2, mode="nearest", align_mode=1)  # 1/4
  23.         p5 = self.p5_conv(in5)
  24.         p4 = self.p4_conv(out4)
  25.         p3 = self.p3_conv(out3)
  26.         p2 = self.p2_conv(out2)
  27.         # 特征上采样
  28.         p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
  29.         p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
  30.         p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
  31.         fuse = paddle.concat([p5, p4, p3, p2], axis=1)
  32.         return fuse
复制代码
Head网络
计算文本区域概率图,文本区域阈值图以及文本区域二值图。
DB Head网络会在FPN特征的基础上作上采样,将FPN特征由原图的四分之一大小映射到原图大小。
  1. import math
  2. import paddle
  3. from paddle import nn
  4. import paddle.nn.functional as F
  5. from paddle import ParamAttr
  6. class DBHead(nn.Layer):
  7.     """
  8.     Differentiable Binarization (DB) for text detection:
  9.         see https://arxiv.org/abs/1911.08947
  10.     args:
  11.         params(dict): super parameters for build DB network
  12.     """
  13.     def __init__(self, in_channels, k=50, **kwargs):
  14.         super(DBHead, self).__init__()
  15.         self.k = k
  16.         # DBHead详细实现参考 https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/heads/det_db_head.py
  17.     def step_function(self, x, y):
  18.         # 可微二值化实现,通过概率图和阈值图计算文本分割二值图
  19.         return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
  20.     def forward(self, x, targets=None):
  21.         shrink_maps = self.binarize(x)
  22.         if not self.training:
  23.             return {'maps': shrink_maps}
  24.         threshold_maps = self.thresh(x)
  25.         binary_maps = self.step_function(shrink_maps, threshold_maps)
  26.         y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
  27.         return {'maps': y}
复制代码
  1. # 1. 从PaddleOCR中imort DBHead
  2. from ppocr.modeling.heads.det_db_head import DBHead
  3. import paddle
  4. # 2. 计算DBFPN网络输出结果
  5. fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
  6. model_backbone = MobileNetV3()
  7. in_channles = model_backbone.out_channels
  8. model_fpn = DBFPN(in_channels=in_channles, out_channels=256)
  9. outs = model_backbone(fake_inputs)
  10. fpn_outs = model_fpn(outs)
  11. # 3. 声明Head网络
  12. model_db_head = DBHead(in_channels=256)
  13. # 4. 打印DBhead网络
  14. print(model_db_head)
  15. # 5. 计算Head网络的输出
  16. db_head_outs = model_db_head(fpn_outs)
  17. print(f"The shape of fpn outs {fpn_outs.shape}")
  18. print(f"The shape of DB head outs {db_head_outs['maps'].shape}")
复制代码

运行后发现报错:
类不完整,于是重新到github paddle ocr目录下下载相应文件
db_fpn.py
det_db_head.py
完整代码:
  1. import math
  2. import paddle
  3. from paddle import nn
  4. import paddle.nn.functional as F
  5. from paddle import ParamAttr
  6. def make_divisible(v, divisor=8, min_value=None):
  7.     if min_value is None:
  8.         min_value = divisor
  9.     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  10.     if new_v < 0.9 * v:
  11.         new_v += divisor
  12.     return new_v
  13. class MobileNetV3(nn.Layer):
  14.     def __init__(self,
  15.                  in_channels=3,
  16.                  model_name='large',
  17.                  scale=0.5,
  18.                  disable_se=False,
  19.                  **kwargs):
  20.         """
  21.         the MobilenetV3 backbone network for detection module.
  22.         Args:
  23.             params(dict): the super parameters for build network
  24.         """
  25.         super(MobileNetV3, self).__init__()
  26.         self.disable_se = disable_se
  27.         if model_name == "large":
  28.             cfg = [
  29.                 # k, exp, c,  se,     nl,  s,
  30.                 [3, 16, 16, False, 'relu', 1],
  31.                 [3, 64, 24, False, 'relu', 2],
  32.                 [3, 72, 24, False, 'relu', 1],
  33.                 [5, 72, 40, True, 'relu', 2],
  34.                 [5, 120, 40, True, 'relu', 1],
  35.                 [5, 120, 40, True, 'relu', 1],
  36.                 [3, 240, 80, False, 'hardswish', 2],
  37.                 [3, 200, 80, False, 'hardswish', 1],
  38.                 [3, 184, 80, False, 'hardswish', 1],
  39.                 [3, 184, 80, False, 'hardswish', 1],
  40.                 [3, 480, 112, True, 'hardswish', 1],
  41.                 [3, 672, 112, True, 'hardswish', 1],
  42.                 [5, 672, 160, True, 'hardswish', 2],
  43.                 [5, 960, 160, True, 'hardswish', 1],
  44.                 [5, 960, 160, True, 'hardswish', 1],
  45.             ]
  46.             cls_ch_squeeze = 960
  47.         elif model_name == "small":
  48.             cfg = [
  49.                 # k, exp, c,  se,     nl,  s,
  50.                 [3, 16, 16, True, 'relu', 2],
  51.                 [3, 72, 24, False, 'relu', 2],
  52.                 [3, 88, 24, False, 'relu', 1],
  53.                 [5, 96, 40, True, 'hardswish', 2],
  54.                 [5, 240, 40, True, 'hardswish', 1],
  55.                 [5, 240, 40, True, 'hardswish', 1],
  56.                 [5, 120, 48, True, 'hardswish', 1],
  57.                 [5, 144, 48, True, 'hardswish', 1],
  58.                 [5, 288, 96, True, 'hardswish', 2],
  59.                 [5, 576, 96, True, 'hardswish', 1],
  60.                 [5, 576, 96, True, 'hardswish', 1],
  61.             ]
  62.             cls_ch_squeeze = 576
  63.         else:
  64.             raise NotImplementedError("mode[" + model_name +
  65.                                       "_model] is not implemented!")
  66.         supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
  67.         assert scale in supported_scale, \
  68.             "supported scale are {} but input scale is {}".format(supported_scale, scale)
  69.         inplanes = 16
  70.         # conv1
  71.         self.conv = ConvBNLayer(
  72.             in_channels=in_channels,
  73.             out_channels=make_divisible(inplanes * scale),
  74.             kernel_size=3,
  75.             stride=2,
  76.             padding=1,
  77.             groups=1,
  78.             if_act=True,
  79.             act='hardswish')
  80.         self.stages = []
  81.         self.out_channels = []
  82.         block_list = []
  83.         i = 0
  84.         inplanes = make_divisible(inplanes * scale)
  85.         for (k, exp, c, se, nl, s) in cfg:
  86.             se = se and not self.disable_se
  87.             start_idx = 2 if model_name == 'large' else 0
  88.             if s == 2 and i > start_idx:
  89.                 self.out_channels.append(inplanes)
  90.                 self.stages.append(nn.Sequential(*block_list))
  91.                 block_list = []
  92.             block_list.append(
  93.                 ResidualUnit(
  94.                     in_channels=inplanes,
  95.                     mid_channels=make_divisible(scale * exp),
  96.                     out_channels=make_divisible(scale * c),
  97.                     kernel_size=k,
  98.                     stride=s,
  99.                     use_se=se,
  100.                     act=nl))
  101.             inplanes = make_divisible(scale * c)
  102.             i += 1
  103.         block_list.append(
  104.             ConvBNLayer(
  105.                 in_channels=inplanes,
  106.                 out_channels=make_divisible(scale * cls_ch_squeeze),
  107.                 kernel_size=1,
  108.                 stride=1,
  109.                 padding=0,
  110.                 groups=1,
  111.                 if_act=True,
  112.                 act='hardswish'))
  113.         self.stages.append(nn.Sequential(*block_list))
  114.         self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
  115.         for i, stage in enumerate(self.stages):
  116.             self.add_sublayer(sublayer=stage, name="stage{}".format(i))
  117.     def forward(self, x):
  118.         x = self.conv(x)
  119.         out_list = []
  120.         for stage in self.stages:
  121.             x = stage(x)
  122.             out_list.append(x)
  123.         return out_list
  124. class ConvBNLayer(nn.Layer):
  125.     def __init__(self,
  126.                  in_channels,
  127.                  out_channels,
  128.                  kernel_size,
  129.                  stride,
  130.                  padding,
  131.                  groups=1,
  132.                  if_act=True,
  133.                  act=None):
  134.         super(ConvBNLayer, self).__init__()
  135.         self.if_act = if_act
  136.         self.act = act
  137.         self.conv = nn.Conv2D(
  138.             in_channels=in_channels,
  139.             out_channels=out_channels,
  140.             kernel_size=kernel_size,
  141.             stride=stride,
  142.             padding=padding,
  143.             groups=groups,
  144.             bias_attr=False)
  145.         self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
  146.     def forward(self, x):
  147.         x = self.conv(x)
  148.         x = self.bn(x)
  149.         if self.if_act:
  150.             if self.act == "relu":
  151.                 x = F.relu(x)
  152.             elif self.act == "hardswish":
  153.                 x = F.hardswish(x)
  154.             else:
  155.                 print("The activation function({}) is selected incorrectly.".
  156.                       format(self.act))
  157.                 exit()
  158.         return x
  159. class ResidualUnit(nn.Layer):
  160.     def __init__(self,
  161.                  in_channels,
  162.                  mid_channels,
  163.                  out_channels,
  164.                  kernel_size,
  165.                  stride,
  166.                  use_se,
  167.                  act=None):
  168.         super(ResidualUnit, self).__init__()
  169.         self.if_shortcut = stride == 1 and in_channels == out_channels
  170.         self.if_se = use_se
  171.         self.expand_conv = ConvBNLayer(
  172.             in_channels=in_channels,
  173.             out_channels=mid_channels,
  174.             kernel_size=1,
  175.             stride=1,
  176.             padding=0,
  177.             if_act=True,
  178.             act=act)
  179.         self.bottleneck_conv = ConvBNLayer(
  180.             in_channels=mid_channels,
  181.             out_channels=mid_channels,
  182.             kernel_size=kernel_size,
  183.             stride=stride,
  184.             padding=int((kernel_size - 1) // 2),
  185.             groups=mid_channels,
  186.             if_act=True,
  187.             act=act)
  188.         if self.if_se:
  189.             self.mid_se = SEModule(mid_channels)
  190.         self.linear_conv = ConvBNLayer(
  191.             in_channels=mid_channels,
  192.             out_channels=out_channels,
  193.             kernel_size=1,
  194.             stride=1,
  195.             padding=0,
  196.             if_act=False,
  197.             act=None)
  198.     def forward(self, inputs):
  199.         x = self.expand_conv(inputs)
  200.         x = self.bottleneck_conv(x)
  201.         if self.if_se:
  202.             x = self.mid_se(x)
  203.         x = self.linear_conv(x)
  204.         if self.if_shortcut:
  205.             x = paddle.add(inputs, x)
  206.         return x
  207. class SEModule(nn.Layer):
  208.     def __init__(self, in_channels, reduction=4):
  209.         super(SEModule, self).__init__()
  210.         self.avg_pool = nn.AdaptiveAvgPool2D(1)
  211.         self.conv1 = nn.Conv2D(
  212.             in_channels=in_channels,
  213.             out_channels=in_channels // reduction,
  214.             kernel_size=1,
  215.             stride=1,
  216.             padding=0)
  217.         self.conv2 = nn.Conv2D(
  218.             in_channels=in_channels // reduction,
  219.             out_channels=in_channels,
  220.             kernel_size=1,
  221.             stride=1,
  222.             padding=0)
  223.     def forward(self, inputs):
  224.         outputs = self.avg_pool(inputs)
  225.         outputs = self.conv1(outputs)
  226.         outputs = F.relu(outputs)
  227.         outputs = self.conv2(outputs)
  228.         outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
  229.         return inputs * outputs
  230. class DBFPN(nn.Layer):
  231.     def __init__(self, in_channels, out_channels, **kwargs):
  232.         super(DBFPN, self).__init__()
  233.         self.out_channels = out_channels
  234.         weight_attr = paddle.nn.initializer.KaimingUniform()
  235.         self.in2_conv = nn.Conv2D(
  236.             in_channels=in_channels[0],
  237.             out_channels=self.out_channels,
  238.             kernel_size=1,
  239.             weight_attr=ParamAttr(initializer=weight_attr),
  240.             bias_attr=False)
  241.         self.in3_conv = nn.Conv2D(
  242.             in_channels=in_channels[1],
  243.             out_channels=self.out_channels,
  244.             kernel_size=1,
  245.             weight_attr=ParamAttr(initializer=weight_attr),
  246.             bias_attr=False)
  247.         self.in4_conv = nn.Conv2D(
  248.             in_channels=in_channels[2],
  249.             out_channels=self.out_channels,
  250.             kernel_size=1,
  251.             weight_attr=ParamAttr(initializer=weight_attr),
  252.             bias_attr=False)
  253.         self.in5_conv = nn.Conv2D(
  254.             in_channels=in_channels[3],
  255.             out_channels=self.out_channels,
  256.             kernel_size=1,
  257.             weight_attr=ParamAttr(initializer=weight_attr),
  258.             bias_attr=False)
  259.         self.p5_conv = nn.Conv2D(
  260.             in_channels=self.out_channels,
  261.             out_channels=self.out_channels // 4,
  262.             kernel_size=3,
  263.             padding=1,
  264.             weight_attr=ParamAttr(initializer=weight_attr),
  265.             bias_attr=False)
  266.         self.p4_conv = nn.Conv2D(
  267.             in_channels=self.out_channels,
  268.             out_channels=self.out_channels // 4,
  269.             kernel_size=3,
  270.             padding=1,
  271.             weight_attr=ParamAttr(initializer=weight_attr),
  272.             bias_attr=False)
  273.         self.p3_conv = nn.Conv2D(
  274.             in_channels=self.out_channels,
  275.             out_channels=self.out_channels // 4,
  276.             kernel_size=3,
  277.             padding=1,
  278.             weight_attr=ParamAttr(initializer=weight_attr),
  279.             bias_attr=False)
  280.         self.p2_conv = nn.Conv2D(
  281.             in_channels=self.out_channels,
  282.             out_channels=self.out_channels // 4,
  283.             kernel_size=3,
  284.             padding=1,
  285.             weight_attr=ParamAttr(initializer=weight_attr),
  286.             bias_attr=False)
  287.     def forward(self, x):
  288.         c2, c3, c4, c5 = x
  289.         in5 = self.in5_conv(c5)
  290.         in4 = self.in4_conv(c4)
  291.         in3 = self.in3_conv(c3)
  292.         in2 = self.in2_conv(c2)
  293.         out4 = in4 + F.upsample(
  294.             in5, scale_factor=2, mode="nearest", align_mode=1)  # 1/16
  295.         out3 = in3 + F.upsample(
  296.             out4, scale_factor=2, mode="nearest", align_mode=1)  # 1/8
  297.         out2 = in2 + F.upsample(
  298.             out3, scale_factor=2, mode="nearest", align_mode=1)  # 1/4
  299.         p5 = self.p5_conv(in5)
  300.         p4 = self.p4_conv(out4)
  301.         p3 = self.p3_conv(out3)
  302.         p2 = self.p2_conv(out2)
  303.         p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
  304.         p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
  305.         p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
  306.         fuse = paddle.concat([p5, p4, p3, p2], axis=1)
  307.         return fuse
  308. def get_bias_attr(k):
  309.     stdv = 1.0 / math.sqrt(k * 1.0)
  310.     initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
  311.     bias_attr = ParamAttr(initializer=initializer)
  312.     return bias_attr
  313. class Head(nn.Layer):
  314.     def __init__(self, in_channels, name_list):
  315.         super(Head, self).__init__()
  316.         self.conv1 = nn.Conv2D(
  317.             in_channels=in_channels,
  318.             out_channels=in_channels // 4,
  319.             kernel_size=3,
  320.             padding=1,
  321.             weight_attr=ParamAttr(),
  322.             bias_attr=False)
  323.         self.conv_bn1 = nn.BatchNorm(
  324.             num_channels=in_channels // 4,
  325.             param_attr=ParamAttr(
  326.                 initializer=paddle.nn.initializer.Constant(value=1.0)),
  327.             bias_attr=ParamAttr(
  328.                 initializer=paddle.nn.initializer.Constant(value=1e-4)),
  329.             act='relu')
  330.         self.conv2 = nn.Conv2DTranspose(
  331.             in_channels=in_channels // 4,
  332.             out_channels=in_channels // 4,
  333.             kernel_size=2,
  334.             stride=2,
  335.             weight_attr=ParamAttr(
  336.                 initializer=paddle.nn.initializer.KaimingUniform()),
  337.             bias_attr=get_bias_attr(in_channels // 4))
  338.         self.conv_bn2 = nn.BatchNorm(
  339.             num_channels=in_channels // 4,
  340.             param_attr=ParamAttr(
  341.                 initializer=paddle.nn.initializer.Constant(value=1.0)),
  342.             bias_attr=ParamAttr(
  343.                 initializer=paddle.nn.initializer.Constant(value=1e-4)),
  344.             act="relu")
  345.         self.conv3 = nn.Conv2DTranspose(
  346.             in_channels=in_channels // 4,
  347.             out_channels=1,
  348.             kernel_size=2,
  349.             stride=2,
  350.             weight_attr=ParamAttr(
  351.                 initializer=paddle.nn.initializer.KaimingUniform()),
  352.             bias_attr=get_bias_attr(in_channels // 4), )
  353.     def forward(self, x):
  354.         x = self.conv1(x)
  355.         x = self.conv_bn1(x)
  356.         x = self.conv2(x)
  357.         x = self.conv_bn2(x)
  358.         x = self.conv3(x)
  359.         x = F.sigmoid(x)
  360.         return x
  361. class DBHead(nn.Layer):
  362.     """
  363.     Differentiable Binarization (DB) for text detection:
  364.         see https://arxiv.org/abs/1911.08947
  365.     args:
  366.         params(dict): super parameters for build DB network
  367.     """
  368.     def __init__(self, in_channels, k=50, **kwargs):
  369.         super(DBHead, self).__init__()
  370.         self.k = k
  371.         binarize_name_list = [
  372.             'conv2d_56', 'batch_norm_47', 'conv2d_transpose_0', 'batch_norm_48',
  373.             'conv2d_transpose_1', 'binarize'
  374.         ]
  375.         thresh_name_list = [
  376.             'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50',
  377.             'conv2d_transpose_3', 'thresh'
  378.         ]
  379.         self.binarize = Head(in_channels, binarize_name_list)
  380.         self.thresh = Head(in_channels, thresh_name_list)
  381.     def step_function(self, x, y):
  382.         return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
  383.     def forward(self, x, targets=None):
  384.         shrink_maps = self.binarize(x)
  385.         if not self.training:
  386.             return {'maps': shrink_maps}
  387.         threshold_maps = self.thresh(x)
  388.         binary_maps = self.step_function(shrink_maps, threshold_maps)
  389.         y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
  390.         return {'maps': y}
  391. if __name__=='__main__':
  392.     fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
  393.     #   声明Backbone
  394.     model_backbone = MobileNetV3()
  395.    
  396.     in_channles = model_backbone.out_channels
  397.     # 声明FPN网络
  398.     model_fpn = DBFPN(in_channels=in_channles, out_channels=256)
  399.     #  打印FPN网络
  400.     print(model_fpn)
  401.     # DBFPN(
  402.     #   (in2_conv): Conv2D(16, 256, kernel_size=[1, 1], data_format=NCHW)
  403.     #   (in3_conv): Conv2D(24, 256, kernel_size=[1, 1], data_format=NCHW)
  404.     #   (in4_conv): Conv2D(56, 256, kernel_size=[1, 1], data_format=NCHW)
  405.     #   (in5_conv): Conv2D(480, 256, kernel_size=[1, 1], data_format=NCHW)
  406.     #   (p5_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  407.     #   (p4_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  408.     #   (p3_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  409.     #   (p2_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  410.     # )
  411.     # 5. 计算得到FPN结果输出
  412.     outs = model_backbone(fake_inputs)
  413.     fpn_outs = model_fpn(outs)
  414.     # The shape of fpn outs [1, 256, 160, 160]
  415.     # 3. 声明Head网络
  416.     model_db_head = DBHead(in_channels=256)
  417.     # 4. 打印DBhead网络
  418.     print(model_db_head)
  419.     # DBHead(
  420.     #   (binarize): Head(
  421.     #     (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  422.     #     (conv_bn1): BatchNorm()
  423.     #     (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  424.     #     (conv_bn2): BatchNorm()
  425.     #     (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  426.     #   )
  427.     #   (thresh): Head(
  428.     #     (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  429.     #     (conv_bn1): BatchNorm()
  430.     #     (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  431.     #     (conv_bn2): BatchNorm()
  432.     #     (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  433.     #   )
  434.     # )
  435.     # 5. 计算Head网络的输出
  436.     db_head_outs = model_db_head(fpn_outs)
  437.     print(f"The shape of fpn outs {fpn_outs.shape}")
  438.     # The shape of fpn outs [1, 256, 160, 160]
  439.     print(f"The shape of DB head outs {db_head_outs['maps'].shape}")
  440.     # The shape of DB head outs [1, 3, 640, 640]
复制代码
结果:
  1. DBFPN(
  2.   (in2_conv): Conv2D(16, 256, kernel_size=[1, 1], data_format=NCHW)
  3.   (in3_conv): Conv2D(24, 256, kernel_size=[1, 1], data_format=NCHW)
  4.   (in4_conv): Conv2D(56, 256, kernel_size=[1, 1], data_format=NCHW)
  5.   (in5_conv): Conv2D(480, 256, kernel_size=[1, 1], data_format=NCHW)
  6.   (p5_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  7.   (p4_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  8.   (p3_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  9.   (p2_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  10. )
  11. DBHead(
  12.   (binarize): Head(
  13.     (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  14.     (conv_bn1): BatchNorm()
  15.     (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  16.     (conv_bn2): BatchNorm()
  17.     (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  18.   )
  19.   (thresh): Head(
  20.     (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)
  21.     (conv_bn1): BatchNorm()
  22.     (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  23.     (conv_bn2): BatchNorm()
  24.     (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)
  25.   )
  26. )
  27. The shape of fpn outs [1, 256, 160, 160]
  28. The shape of DB head outs [1, 3, 640, 640]
复制代码
简化代码
  1. from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3
  2. from ppocr.modeling.necks.db_fpn import DBFPN
  3. from ppocr.modeling.heads.det_db_head import DBHead
  4. import paddle
  5. fake_inputs = paddle.randn([1, 3, 640, 640], dtype="float32")
  6. model_backbone = MobileNetV3()
  7. print(model_backbone)
  8. in_channles = model_backbone.out_channels
  9. outs = model_backbone(fake_inputs)
  10. model_fpn = DBFPN(in_channels=in_channles, out_channels=256)
  11. print(model_fpn)
  12. fpn_outs = model_fpn(outs)
  13. model_db_head = DBHead(in_channels=256)
  14. print(model_db_head)
  15. db_head_outs = model_db_head(fpn_outs)
  16. print(f"The shape of fpn outs {fpn_outs.shape}") # [1, 256, 160, 160]
  17. print(f"The shape of DB head outs {db_head_outs['maps'].shape}")#[1, 3, 640, 640]
复制代码
DB算法优点:(有监督,backbone选ResNet50效果更好)

  • 我们的方法在五个场景文本基准数据集中均取得了连续精良的表现,涵盖了程度、多方向和弯曲文本。
    2.相比先前的领先方法,我们的方法运行速率显著更快,因为DB可以或许提供高度鲁棒的二值化图,极大地简化了后期处理步调。
    3.当搭配轻量级主干网络时,DB表现出色,特殊是在使用ResNet-18作为主干网络时,显著提升了检测性能。
    4.由于在推理阶段可以移除DB而不断送性能,因此测试阶段无需额外的内存/时间开销。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

万有斥力

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表