如安在YOLOv8网络中添加自界说留意力机制

商道如狼道  金牌会员 | 2024-11-4 20:00:44 | 来自手机 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 838|帖子 838|积分 2514

在目标检测任务中,加入留意力机制可以提升模型的检测结果。本文将介绍如安在YOLOv8模型中集成多种留意力机制,如 SimAM、ShuffleAttention、TripletAttention、MHSA、CBAM 和 EMA,以增强模型对图像特征的提取能力。我们将展示每个留意力机制的代码示例,并讨论如何将这些模块添加到YOLOv8网络中。

目次



  • 1. 留意力机制示例代码

    • 1.1 SimAM 模块代码
    • 1.2 ShuffleAttention 模块代码
    • 1.3 TripletAttention 模块代码
    • 1.4 MHSA 模块代码
    • 1.5 CBAM 模块代码
    • 1.6 EMA 模块代码
    • 1.7 ECA 模块代码

  • 2. 添加留意力机制的步骤

    • 2.1 修改YOLOv8的配置文件
    • 2.2 编写自界说留意力机制模块
    • 2.3 训练和验证留意力机制

  • 3. 总结

1. 留意力机制示例代码

下面介绍六种常用的留意力机制模块,并提供代码示例。每种留意力机制都有其独特的优点,可以根据任务需求选择最适合的机制。
1.1 SimAM 模块代码

SimAM (Simple Attention Module) 是一种轻量级的留意力机制;通过简单的操纵实现了留意力机制的结果,适用于对计算资源敏感的项目。它适合那些盼望在提升模型性能的同时,尽量淘汰计算开销的任务,比如嵌入式设备上的实时目标检测。。
  1. import torch
  2. import torch.nn as nn
  3. class SimAM(torch.nn.Module):
  4.     def __init__(self, e_lambda=1e-4):
  5.         super(SimAM, self).__init__()
  6.         self.activaton = nn.Sigmoid()
  7.         self.e_lambda = e_lambda
  8.     def __repr__(self):
  9.         s = self.__class__.__name__ + '('
  10.         s += ('lambda=%f)' % self.e_lambda)
  11.         return s
  12.     @staticmethod
  13.     def get_module_name():
  14.         return "simam"
  15.     def forward(self, x):
  16.         b, c, h, w = x.size()
  17.         n = w * h - 1
  18.         x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
  19.         y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
  20.         return x * self.activaton(y)
  21. # if __name__ == '__main__':
  22. #     input = torch.randn(3, 64, 7, 7)
  23. #     model = SimAM()
  24. #     outputs = model(input)
  25. #     print(outputs.shape)
复制代码
1.2 ShuffleAttention 模块代码

ShuffleAttention 适合需要全局特征交互的场景;通过通道洗牌操纵重新排列特征,确保模型能够在不同通道间传递信息,提升特征的全局表达能力。对于需要处理复杂、具有多样性特征的图像(如交通场景、复杂的自然环境),这种机制能有用提升模型的感知能力。
  1. import torch
  2. from torch import nn
  3. from torch.nn import init
  4. from torch.nn.parameter import Parameter
  5. class ShuffleAttention(nn.Module):
  6.     def __init__(self, channel=512, reduction=16, G=8):
  7.         super().__init__()
  8.         self.G = G
  9.         self.channel = channel
  10.         self.avg_pool = nn.AdaptiveAvgPool2d(1)
  11.         self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
  12.         self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
  13.         self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
  14.         self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
  15.         self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
  16.         self.sigmoid = nn.Sigmoid()
  17.     def init_weights(self):
  18.         for m in self.modules():
  19.             if isinstance(m, nn.Conv2d):
  20.                 init.kaiming_normal_(m.weight, mode='fan_out')
  21.                 if m.bias is not None:
  22.                     init.constant_(m.bias, 0)
  23.             elif isinstance(m, nn.BatchNorm2d):
  24.                 init.constant_(m.weight, 1)
  25.                 init.constant_(m.bias, 0)
  26.             elif isinstance(m, nn.Linear):
  27.                 init.normal_(m.weight, std=0.001)
  28.                 if m.bias is not None:
  29.                     init.constant_(m.bias, 0)
  30.     @staticmethod
  31.     def channel_shuffle(x, groups):
  32.         b, c, h, w = x.shape
  33.         x = x.reshape(b, groups, -1, h, w)
  34.         x = x.permute(0, 2, 1, 3, 4)
  35.         # flatten
  36.         x = x.reshape(b, -1, h, w)
  37.         return x
  38.     def forward(self, x):
  39.         b, c, h, w = x.size()
  40.         # group into subfeatures
  41.         x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w
  42.         # channel_split
  43.         x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w
  44.         # channel attention
  45.         x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
  46.         x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
  47.         x_channel = x_0 * self.sigmoid(x_channel)
  48.         # spatial attention
  49.         x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
  50.         x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
  51.         x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w
  52.         # concatenate along channel axis
  53.         out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
  54.         out = out.contiguous().view(b, -1, h, w)
  55.         # channel shuffle
  56.         out = self.channel_shuffle(out, 2)
  57.         return out
  58. if __name__ == '__main__':
  59.     input = torch.randn(50, 512, 7, 7)
  60.     se = ShuffleAttention(channel=512, G=8)
  61.     output = se(input)
复制代码
1.3 TripletAttention 模块代码

TripletAttention 适合需要捕捉多方向特征的场景;在通道上引入了三个方向的留意力(水平、垂直、深度),能够帮助模型更好地感知多方向上的特征。这种机制特别适用于那些需要捕捉方向性信息的任务,比如门路标记检测和自然场景理解。
  1. import torch
  2. import torch.nn as nn
  3. class BasicConv(nn.Module):  # https://arxiv.org/pdf/2010.03045.pdf
  4.     def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
  5.                  bn=True, bias=False):
  6.         super(BasicConv, self).__init__()
  7.         self.out_channels = out_planes
  8.         self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
  9.                               dilation=dilation, groups=groups, bias=bias)
  10.         self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
  11.         self.relu = nn.ReLU() if relu else None
  12.     def forward(self, x):
  13.         x = self.conv(x)
  14.         if self.bn is not None:
  15.             x = self.bn(x)
  16.         if self.relu is not None:
  17.             x = self.relu(x)
  18.         return x
  19. class ZPool(nn.Module):
  20.     def forward(self, x):
  21.         return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
  22. class AttentionGate(nn.Module):
  23.     def __init__(self):
  24.         super(AttentionGate, self).__init__()
  25.         kernel_size = 7
  26.         self.compress = ZPool()
  27.         self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
  28.     def forward(self, x):
  29.         x_compress = self.compress(x)
  30.         x_out = self.conv(x_compress)
  31.         scale = torch.sigmoid_(x_out)
  32.         return x * scale
  33. class TripletAttention(nn.Module):
  34.     def __init__(self, no_spatial=False):
  35.         super(TripletAttention, self).__init__()
  36.         self.cw = AttentionGate()
  37.         self.hc = AttentionGate()
  38.         self.no_spatial = no_spatial
  39.         if not no_spatial:
  40.             self.hw = AttentionGate()
  41.     def forward(self, x):
  42.         x_perm1 = x.permute(0, 2, 1, 3).contiguous()
  43.         x_out1 = self.cw(x_perm1)
  44.         x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
  45.         x_perm2 = x.permute(0, 3, 2, 1).contiguous()
  46.         x_out2 = self.hc(x_perm2)
  47.         x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
  48.         if not self.no_spatial:
  49.             x_out = self.hw(x)
  50.             x_out = 1 / 3 * (x_out + x_out11 + x_out21)
  51.         else:
  52.             x_out = 1 / 2 * (x_out11 + x_out21)
  53.         return x_out
复制代码
1.4 MHSA 模块代码

MHSA (Multi-Head Self-Attention) 是常用于Transformer模型的留意力机制,适合大规模上下文建模的场景;通过多头自留意力的机制,能够帮助模型捕捉图像中的长距离依赖关系。它在需要处理上下文信息的任务中体现出色,如自然场景中的多物体检测。对于需要全局信息并且图像内物体之间具有复杂相互关系的任务,MHSA 是理想的选择。
  1. import torch
  2. import torch.nn as nn
  3. class MHSA(nn.Module):
  4.     def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
  5.         super(MHSA, self).__init__()
  6.         self.heads = heads
  7.         self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
  8.         self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
  9.         self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
  10.         self.pos = pos_emb
  11.         if self.pos:
  12.             self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]),
  13.                                              requires_grad=True)
  14.             self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]),
  15.                                              requires_grad=True)
  16.         self.softmax = nn.Softmax(dim=-1)
  17.     def forward(self, x):
  18.         n_batch, C, width, height = x.size()
  19.         q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
  20.         k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
  21.         v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
  22.         content_content = torch.matmul(q.permute(0, 1, 3, 2), k)  # 1,C,h*w,h*w
  23.         c1, c2, c3, c4 = content_content.size()
  24.         if self.pos:
  25.             content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute(
  26.                 0, 1, 3, 2)  # 1,4,1024,64
  27.             content_position = torch.matmul(content_position, q)  # ([1, 4, 1024, 256])
  28.             content_position = content_position if (
  29.                     content_content.shape == content_position.shape) else content_position[:, :, :c3, ]
  30.             assert (content_content.shape == content_position.shape)
  31.             energy = content_content + content_position
  32.         else:
  33.             energy = content_content
  34.         attention = self.softmax(energy)
  35.         out = torch.matmul(v, attention.permute(0, 1, 3, 2))  # 1,4,256,64
  36.         out = out.view(n_batch, C, width, height)
  37.         return out
  38. # if __name__ == '__main__':
  39. #     input = torch.randn(50, 512, 7, 7)
  40. #     mhsa = MHSA(n_dims=512)
  41. #     output = mhsa(input)
  42. #     print(output.shape)
复制代码
1.5 CBAM 模块代码

CBAM(Convolutional Block Attention Module)适合需要联合通道和空间特征的场景;通过联合通道留意力和空间留意力,帮助网络更加精准地捕捉图像中的关键区域。它适用于大多数量标检测任务,特别是当需要细化某些特定物体的检测时,比如在自动驾驶中的行人检测或交通标记检测。
  1. import torch
  2. from torch import nn
  3. class ChannelAttention(nn.Module):
  4.     # Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
  5.     def __init__(self, channels: int) -> None:
  6.         super().__init__()
  7.         self.pool = nn.AdaptiveAvgPool2d(1)
  8.         self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
  9.         self.act = nn.Sigmoid()
  10.     def forward(self, x: torch.Tensor) -> torch.Tensor:
  11.         return x * self.act(self.fc(self.pool(x)))
  12. class SpatialAttention(nn.Module):
  13.     # Spatial-attention module
  14.     def __init__(self, kernel_size=7):
  15.         super().__init__()
  16.         assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
  17.         padding = 3 if kernel_size == 7 else 1
  18.         self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
  19.         self.act = nn.Sigmoid()
  20.     def forward(self, x):
  21.         return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
  22. class CBAM(nn.Module):
  23.     # Convolutional Block Attention Module
  24.     def __init__(self, c1, kernel_size=7):  # ch_in, kernels
  25.         super().__init__()
  26.         self.channel_attention = ChannelAttention(c1)
  27.         self.spatial_attention = SpatialAttention(kernel_size)
  28.     def forward(self, x):
  29.         return self.spatial_attention(self.channel_attention(x))
复制代码
1.6 EMA 模块代码

EMA(Efficient Multi-Head Attention)适合盼望在多头自留意力中提升服从的场景;它通过淘汰计算复杂度而提升性能,适用于大规模数据集的训练。它在保持留意力机制强盛的特征捕捉能力的同时,还能显著低落计算成本,适合高性能要求的任务场景。
  1. import torch
  2. from torch import nn
  3. class EMA(nn.Module):
  4.     def __init__(self, channels, c2=None, factor=32):
  5.         super(EMA, self).__init__()
  6.         self.groups = factor
  7.         assert channels // self.groups > 0
  8.         self.softmax = nn.Softmax(-1)
  9.         self.agp = nn.AdaptiveAvgPool2d((1, 1))
  10.         self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
  11.         self.pool_w = nn.AdaptiveAvgPool2d((1, None))
  12.         self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
  13.         self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
  14.         self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)
  15.     def forward(self, x):
  16.         b, c, h, w = x.size()
  17.         group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
  18.         x_h = self.pool_h(group_x)
  19.         x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
  20.         hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
  21.         x_h, x_w = torch.split(hw, [h, w], dim=2)
  22.         x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
  23.         x2 = self.conv3x3(group_x)
  24.         x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
  25.         x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
  26.         x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
  27.         x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
  28.         weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
  29.         return (group_x * weights.sigmoid()).reshape(b, c, h, w)
复制代码
1.7 ECA 模块代码

ECA(Efficient Channel Attention)适合需要高效通道留意力的场景;通过消除全毗连层,利用1D卷积进行局部交互,大大淘汰了参数量,同时仍旧保留了通道留意力的能力。它适合那些对计算资源有限定的场景,比如移动设备上进行目标检测的任务。
  1. import torch
  2. from torch import nn
  3. class ECA(nn.Module):
  4.     def __init__(self, channels: int, k_size: int = 3):
  5.         super(ECA, self).__init__()
  6.         self.avg_pool = nn.AdaptiveAvgPool2d(1)
  7.         self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
  8.         self.sigmoid = nn.Sigmoid()
  9.     def forward(self, x):
  10.         # Apply global average pooling
  11.         y = self.avg_pool(x)
  12.         # Reshape and apply 1D convolution
  13.         y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
  14.         # Apply sigmoid activation and element-wise multiplication
  15.         return x * self.sigmoid(y)
复制代码

2. 添加留意力机制的步骤

2.1 修改YOLOv8的配置文件

我们可以通过在YOLOv8配置文件中指定利用留意力机制,以下是如安在第10层加入留意力机制的配置示例,以ShuffleAttention留意力机制为例,用到哪个放开哪个:
[code]# Ultralytics YOLO

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

商道如狼道

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表