CVPR | CNN融合留意力机制,芜湖腾飞!

打印 上一主题 下一主题

主题 945|帖子 945|积分 2835

**标题:**On the Integration of Self-Attention and Convolution
**论文链接:**https://arxiv.org/pdf/2111.14556
**代码链接:**https://github.com/LeapLabTHU/ACmix

创新点


1. 揭示卷积和自留意力的内涵接洽

文章通过重新分解卷积和自留意力模块的操作,发现它们在第一阶段(特性投影)都依赖于 1×1 卷积操作,并且这一阶段占据了大部分的盘算复杂度(与通道数的平方成正比)。这一发现为整合两种模块提供了理论基础。
2. 提出 ACmix 模型

基于上述发现,作者提出了 ACmix 模型,它通过共享 1×1 卷积操作来同时实现卷积和自留意力的功能。具体来说:
**第一阶段:**输入特性通过 1×1 卷积投影,生成中间特性。
**第二阶段:**这些中间特性分别用于卷积路径(通过移位和聚合操作)和自留意力路径(盘算留意力权重并聚合值)。最终,两条路径的输出通过可学习的权重加权求和,得到最终输出。
3. 改进的移位和聚合操作

文章还提出了一种改进的移位操作,通过使用 固定卷积核的分组卷积 来替换传统的张量移位操作。这种方法不光提高了盘算服从,还允许卷积核的可学习性,进一步增强了模型的灵活性。
4. 顺应性路径权重

ACmix 引入了两个可学习的标量参数(α 和 β),用于动态调解卷积路径和自留意力路径的权重。这种设计不光提高了模型的灵活性,还允许模型在不同深度上自顺应地选择更适合的特性提取方式。实行表明,这种设计在模型的不同阶段表现出不同的偏好,比方在早期阶段更倾向于卷积,在后期阶段更倾向于自留意力。
整体结构


第一阶段:特性投影

在第一阶段,输入特性通过三个1×1卷积举行投影,分别生成查询(query)、键(key)和值(value)特性映射。这些特性映射随后被重塑为N块,形成一个包罗3×N特性映射的中间特性集。
第二阶段:特性聚合

在第二阶段,中间特性集被分为两个路径举行处理:


  • **自留意力路径:**将中间特性集分为N组,每组包罗三个特性映射(分别对应查询、键和值)。这些特性映射按照传统的多头自留意力机制举行处理,盘算留意力权重并聚合值。
  • **卷积路径:**通过轻量级的全连接层生成k²个特性映射(k为卷积核巨细)。这些特性映射通过移位和聚合操作,以类似传统卷积的方式处理输入特性,从局部感受野网络信息。
输出整合

最后,自留意力路径和卷积路径的输出通过两个可学习的标量参数(α和β)加权求和,得到最终的输出。
改进的移位和聚合操作

为了提高盘算服从,ACmix模型采用了改进的移位操作,通过固定卷积核的分组卷积来替换传统的张量移位操作。这种方法不光提高了盘算服从,还允许卷积核的可学习性,进一步增强了模型的灵活性。
模型的灵活性和泛化本领

ACmix模型不光适用于标准的自留意力机制,还可以与各种变体(如Patchwise Attention、Window Attention和Global Attention)结合使用。这种设计使得ACmix能够顺应不同的任务需求,具有广泛的适用性。
消融实行


1. 结合两个路径的输出

消融实行探索了卷积和自留意力输出的不同组合方式对模型性能的影响。实行结果表明:


  • **卷积和自留意力的组合优于单一起径:**使用卷积和自留意力模块的组合始终优于仅使用单一起径(如仅卷积或仅自留意力)的模型。
  • **可学习参数的灵活性:**通过引入可学习的参数(如α和β)来动态调解卷积和自留意力路径的权重,ACmix能够根据网络中不同位置的需求自顺应地调解路径强度,从而得到更高的灵活性和性能。
2. 组卷积核的选择

实行还对组卷积核的设计举行了验证,结果表明:


  • **用组卷积替换张量位移:**通过使用组卷积替换传统的张量位移操作,显著提高了模型的推理速率。
  • **可学习卷积核和初始化:**使用可学习的卷积核并结合精心设计的初始化方法,进一步增强了模型的灵活性,并有助于提升最终性能。
3. 不同路径的偏好

ACmix模型引入了两个可学习标量α和β,用于动态调解卷积和自留意力路径的权重。通过平行实行,观察到以下趋势:


  • **早期阶段偏好卷积:**在Transformer模型的早期阶段,卷积作为特性提取器表现更好。
  • **中间阶段混合使用:**在网络的中间阶段,模型倾向于混合使用两种路径,并渐渐增长对卷积的偏好。
  • **后期阶段偏好自留意力:**在网络的最后阶段,自留意力表现优于卷积。
4. 对模型性能的影响

这些消融实行结果表明,ACmix模型通过合理结合卷积和自留意力的上风,并优化盘算路径,不光在多个视觉任务上取得了显著的性能提升,还保持了较高的盘算服从
ACmix模块的作用

1. 融合卷积和自留意力的上风

ACmix模块通过结合卷积的局部特性提取本领和自留意力的全局感知本领,实现了一种高效的特性融合策略。这种设计使得模型能够同时使用卷积的局部感受野特性和自留意力的灵活性。
2. 优化盘算路径

ACmix通过优化盘算路径和减少重复盘算,提高了整体模块的盘算服从。具体来说,它通过1×1卷积对输入特性图举行投影,生成中间特性,然后根据不同的范式(卷积和自留意力)分别重用和聚合这些中间特性。这种设计不光减少了盘算开销,还提升了模型性能。
3. 改进的位移与求和操作

在卷积路径中,ACmix采用深度可分离卷积(depthwise convolution)来替换低效的张量位移操作,从而提高了现实推理服从。
4. 动态调解路径权重

ACmix引入了两个可学习的标量参数(α和β),用于动态调解卷积和自留意力路径的权重。这种设计使得模型能够根据网络中不同位置的需求自顺应地调解路径强度,从而得到更高的灵活性。
5. 广泛的应用潜力

ACmix在多个视觉任务(如图像分类、语义分割和目的检测)上均表现出优于单一机制(仅卷积或仅自留意力)的性能,展示了其广泛的应用潜力。
6. 实行验证

实行结果表明,ACmix在保持较低盘算开销的同时,能够显著提升模型的性能。比方,在ImageNet分类任务中,ACmix模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比力中取得了一连的改进。别的,ACmix在ADE20K语义分割任务和COCO目的检测任务中也表现出明显的改进
代码实现

  1. import torch
  2. import torch.nn as nn
  3. def position(H, W, is_cuda=True):
  4.     if is_cuda:
  5.         loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)
  6.         loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)
  7.     else:
  8.         loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)
  9.         loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)
  10.     loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)
  11.     return loc
  12. def stride(x, stride):
  13.     b, c, h, w = x.shape
  14.     return x[:, :, ::stride, ::stride]
  15. def init_rate_half(tensor):
  16.     if tensor is not None:
  17.         tensor.data.fill_(0.5)
  18. def init_rate_0(tensor):
  19.     if tensor is not None:
  20.         tensor.data.fill_(0.)
  21. class ACmix(nn.Module):
  22.     def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
  23.         super(ACmix, self).__init__()
  24.         self.in_planes = in_planes
  25.         self.out_planes = out_planes
  26.         self.head = head
  27.         self.kernel_att = kernel_att
  28.         self.kernel_conv = kernel_conv
  29.         self.stride = stride
  30.         self.dilation = dilation
  31.         self.rate1 = torch.nn.Parameter(torch.Tensor(1))
  32.         self.rate2 = torch.nn.Parameter(torch.Tensor(1))
  33.         self.head_dim = self.out_planes // self.head
  34.         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
  35.         self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
  36.         self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
  37.         self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)
  38.         self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
  39.         self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
  40.         self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)
  41.         self.softmax = torch.nn.Softmax(dim=1)
  42.         self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
  43.         self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,
  44.                                   kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,
  45.                                   stride=stride)
  46.         self.reset_parameters()
  47.     def reset_parameters(self):
  48.         init_rate_half(self.rate1)
  49.         init_rate_half(self.rate2)
  50.         kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
  51.         for i in range(self.kernel_conv * self.kernel_conv):
  52.             kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
  53.         kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
  54.         self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
  55.         self.dep_conv.bias = init_rate_0(self.dep_conv.bias)
  56.     def forward(self, x):
  57.         q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
  58.         scaling = float(self.head_dim) ** -0.5
  59.         b, c, h, w = q.shape
  60.         h_out, w_out = h // self.stride, w // self.stride
  61.         # ### att
  62.         # ## positional encoding
  63.         pe = self.conv_p(position(h, w, x.is_cuda))
  64.         q_att = q.view(b * self.head, self.head_dim, h, w) * scaling
  65.         k_att = k.view(b * self.head, self.head_dim, h, w)
  66.         v_att = v.view(b * self.head, self.head_dim, h, w)
  67.         if self.stride > 1:
  68.             q_att = stride(q_att, self.stride)
  69.             q_pe = stride(pe, self.stride)
  70.         else:
  71.             q_pe = pe
  72.         unfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,
  73.                                                          self.kernel_att * self.kernel_att, h_out,
  74.                                                          w_out) # b*head, head_dim, k_att^2, h_out, w_out
  75.         unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,
  76.                                                         w_out) # 1, head_dim, k_att^2, h_out, w_out
  77.         att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(
  78.             1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)
  79.         att = self.softmax(att)
  80.         out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,
  81.                                                         h_out, w_out)
  82.         out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)
  83.         ## conv
  84.         f_all = self.fc(torch.cat(
  85.             [q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),
  86.              v.view(b, self.head, self.head_dim, h * w)], 1))
  87.         f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
  88.         out_conv = self.dep_conv(f_conv)
  89.         return self.rate1 * out_att + self.rate2 * out_conv
  90. #输入 B C H W, 输出 B C H W
  91. if __name__ == '__main__':
  92.     block = ACmix(in_planes=64, out_planes=64)
  93.     input = torch.rand(3, 64, 32, 32)
  94.     output = block(input)
  95.     print(input.size(), output.size())
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

小小小幸运

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表