YOLOv11改进 | 注意力篇 | YOLOv11引入GAM注意力机制
1.GAM先容https://i-blog.csdnimg.cn/direct/599e7c825f624359910e1b1bd6cf0da2.png
摘要:为了提高各种计算机视觉任务的性能,人们研究了各种注意机制。然而,现有的方法忽略了保留通道和空间信息以加强跨维交互的重要性。因此,我们提出了一种通过减少信息减少和放大全球交互表示来提高深度神经网络性能的全球驻留机制。我们引入了具有多层单个Ceptron的3D置换用于信道注意,同时还引入了卷积空间注意子模块。对 CIFAR-100和lmageNet-1K上图像分类任务的拟议机制的评估表明我们的方法稳固地优于ResNet和轻量级的 MobileNet的几个最近的注意机制。
官方论文地址:https://ar5iv.labs.arxiv.org/html/2112.05561
官方代码地址:https://github.com/dengbuqi/GAM_Pytorch/blob/main/CAM.py
简单先容: GAM旨在通过设计一种机制,减少信息丧失并放大全局维度互动特性,从而办理传统注意力机制在通道和空间两个维度上保留信息不足的问题。GAM采用了次序的通道-空间注意力制,并对子模块进行了重新设计。具体来说,通道注意力子模块使用3D分列来跨三个维度保留信息,并通过一个两层的MLP加强跨维度的通道-空间依赖性。在空间注意力子模块中,为了更好地关注空间信息,采用了两个卷积层进行空间信息融合,同时去除了大概导致信息减少的最大池化利用。
GAM模块结构图如下:
https://i-blog.csdnimg.cn/direct/e2a0b7880fa04202923ebe5de060f669.png
2.焦点代码
import torch
import torch.nn as nn
class GAM(nn.Module):
def __init__(self, in_channels, rate=4):
super().__init__()
out_channels = in_channels
in_channels = int(in_channels)
out_channels = int(out_channels)
inchannel_rate = int(in_channels/rate)
self.linear1 = nn.Linear(in_channels, inchannel_rate)
self.relu = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(inchannel_rate, in_channels)
self.conv1=nn.Conv2d(in_channels, inchannel_rate,kernel_size=7,padding=3,padding_mode='replicate')
self.conv2=nn.Conv2d(inchannel_rate, out_channels,kernel_size=7,padding=3,padding_mode='replicate')
self.norm1 = nn.BatchNorm2d(inchannel_rate)
self.norm2 = nn.BatchNorm2d(out_channels)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
b, c, h, w = x.shape
# B,C,H,W ==> B,H*W,C
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
# B,H*W,C ==> B,H,W,C
x_att_permute = self.linear2(self.relu(self.linear1(x_permute))).view(b, h, w, c)
# B,H,W,C ==> B,C,H,W
x_channel_att = x_att_permute.permute(0, 3, 1, 2)
x = x * x_channel_att
x_spatial_att = self.relu(self.norm1(self.conv1(x)))
x_spatial_att = self.sigmoid(self.norm2(self.conv2(x_spatial_att)))
out = x * x_spatial_att
return out
if __name__ == '__main__':
img = torch.rand(1,64,32,48)
b, c, h, w = img.shape
net = GAM(in_channels=c, out_channels=c)
output = net(img)
print(output.shape) 3.YOLOv11中添加GAM方式
3.1 在ultralytics/nn下新建Extramodule
https://i-blog.csdnimg.cn/blog_migrate/2633ad1138b9917900deebad8c6b4a59.png
https://i-blog.csdnimg.cn/blog_migrate/a629754e984069d6136c0a90e6f90f44.png
3.2 在Extramodule里创建GAM
https://i-blog.csdnimg.cn/blog_migrate/8deae7f5cb8433b3098b0e9ba56947be.png
https://i-blog.csdnimg.cn/direct/0ba72ec6cc5745609c7a6130cf524d46.png
在GAM.py文件里添加给出的GAM代码
添加完GAM代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用
https://i-blog.csdnimg.cn/direct/0d2ab5f61a87441a9931d322c9d5c50d.png
3.3 在task.py里引用
在ultralytics/nn/tasks.py文件里引用Extramodule
https://i-blog.csdnimg.cn/blog_migrate/8ac3dc795027dac1dbcdbd4de2e218f1.png
在tasks.py找到parse_model(ctrl+f可以直接搜刮parse_model位置)
添加如下代码:
https://i-blog.csdnimg.cn/direct/afe58e90db994da885bf963112d0abcc.png
elif m in {GAM}:
c2 = ch
args = 4.新建一个yolo11GAM.yaml文件
# Ultralytics YOLO
页:
[1]