Hacking YOLOv8 | Using MB-TaylorFormer to Improve YOLOv8 Detection on Hazy Images
1. This article introduces MB-TaylorFormer, a novel multi-branch linear Transformer network for image dehazing. It improves softmax attention by applying a Taylor expansion, and its multi-branch, multi-scale structure captures multi-level and multi-scale information, outperforming traditional methods in performance, computational cost, and parameter count.
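For intuition before the code: the linear attention used below follows from a first-order Taylor expansion of the softmax kernel. A sketch of the idea (notation assumed here: Q, K, V ∈ R^{N×d}, with N = h·w tokens):
\mathrm{Attn}(Q,K,V)_i=\frac{\sum_j \exp\!\left(Q_i K_j^{\top}\right)V_j}{\sum_j \exp\!\left(Q_i K_j^{\top}\right)}\approx\frac{\sum_j\left(1+Q_i K_j^{\top}\right)V_j}{\sum_j\left(1+Q_i K_j^{\top}\right)}=\frac{\sum_j V_j+Q_i\left(K^{\top}V\right)}{N+Q_i\sum_j K_j^{\top}}
Because K^{T}V can be computed before multiplying by Q, the cost is linear rather than quadratic in N; in the code below this is exactly what attn = k @ v, out_numerator, and out_denominator implement in Attention.forward.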
A detailed introduction to MB-TaylorFormer can be found in the paper: https://arxiv.org/pdf/2308.14036.pdf
This article explains how to integrate MB-TaylorFormer into YOLOv8.
Without further ado, on to the code!
2. Integrating MB-TaylorFormer into YOLOv8
2.1 Step 1
First, find the directory 'ultralytics/nn' and create an 'Addmodules' folder inside it. Then, in that folder, create a file named TaylorFormer.py (you can name it however you like) and copy the MB-TaylorFormer core code into it.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.deform_conv import DeformConv2d
import numbers
import math
from einops import rearrange
import numpy as np
__all__ = ['MB_TaylorFormer']
freqs_dict = dict()
##########################################################################
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
groups=hidden_features * 2, bias=bias)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
class refine_att(nn.Module):
"""Convolutional relative position encoding."""
def __init__(self, Ch, h, window):
super().__init__()
if isinstance(window, int):
# Set the same window size for all attention heads.
window = {window: h}
self.window = window
elif isinstance(window, dict):
self.window = window
else:
raise ValueError()
self.conv_list = nn.ModuleList()
self.head_splits = []
for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 by default.
padding_size = (cur_window + (cur_window - 1) *
(dilation - 1)) // 2
cur_conv = nn.Conv2d(
cur_head_split * Ch * 2,
cur_head_split,
kernel_size=(cur_window, cur_window),
padding=(padding_size, padding_size),
dilation=(dilation, dilation),
groups=cur_head_split,
)
self.conv_list.append(cur_conv)
self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch * 2 for x in self.head_splits]  # q and k are concatenated, hence the factor of 2
def forward(self, q, k, v, size):
"""foward function"""
B, h, N, Ch = q.shape
H, W = size
# We don't use CLS_TOKEN
q_img = q
k_img = k
v_img = v
        # Shape: [B, h, H*W, Ch] -> [B, h, Ch, H, W].
q_img = rearrange(q_img, "B h (H W) Ch -> B h Ch H W", H=H, W=W)
k_img = rearrange(k_img, "B h Ch (H W) -> B h Ch H W", H=H, W=W)
qk_concat = torch.cat((q_img, k_img), 2)
qk_concat = rearrange(qk_concat, "B h Ch H W -> B (h Ch) H W", H=H, W=W)
# Split according to channels.
qk_concat_list = torch.split(qk_concat, self.channel_splits, dim=1)
qk_att_list = [
conv(x) for conv, x in zip(self.conv_list, qk_concat_list)
]
qk_att = torch.cat(qk_att_list, dim=1)
        # Shape: [B, (h Ch), H, W] -> [B, h, H*W, Ch].
qk_att = rearrange(qk_att, "B (h Ch) H W -> B h (H W) Ch", h=h)
return qk_att
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias, shared_refine_att=None, qk_norm=1):
super(Attention, self).__init__()
self.norm = qk_norm
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
# self.Leakyrelu=nn.LeakyReLU(negative_slope=0.01,inplace=True)
self.sigmoid = nn.Sigmoid()
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
if num_heads == 8:
crpe_window = {
3: 2,
5: 3,
7: 3
}
elif num_heads == 1:
crpe_window = {
3: 1,
}
elif num_heads == 2:
crpe_window = {
3: 2,
}
elif num_heads == 4:
crpe_window = {
3: 2,
5: 2,
}
self.refine_att = refine_att(Ch=dim // num_heads,
h=num_heads,
window=crpe_window)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
# q = torch.nn.functional.normalize(q, dim=-1)
q_norm = torch.norm(q, p=2, dim=-1, keepdim=True) / self.norm + 1e-6
q = torch.div(q, q_norm)
k_norm = torch.norm(k, p=2, dim=-2, keepdim=True) / self.norm + 1e-6
k = torch.div(k, k_norm)
# k = torch.nn.functional.normalize(k, dim=-2)
refine_weight = self.refine_att(q, k, v, size=(h, w))
# refine_weight=self.Leakyrelu(refine_weight)
refine_weight = self.sigmoid(refine_weight)
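        # Taylor-expanded linear attention: compute (K^T V) first so the cost stays linear in h*w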
attn = k @ v
# attn = attn.softmax(dim=-1)
# print(torch.sum(k, dim=-1).unsqueeze(3).shape)
out_numerator = torch.sum(v, dim=-2).unsqueeze(2) + (q @ attn)
out_denominator = torch.full((h * w, c // self.num_heads), h * w).to(q.device) \
+ q @ torch.sum(k, dim=-1).unsqueeze(3).repeat(1, 1, 1, c // self.num_heads) + 1e-6
# out=torch.div(out_numerator,out_denominator)*self.temperature*refine_weight
out = torch.div(out_numerator, out_denominator) * self.temperature
out = out * refine_weight
out = rearrange(out, 'b head (h w) c-> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, shared_refine_att=None, qk_norm=1):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias, shared_refine_att=shared_refine_att, qk_norm=qk_norm)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class MHCAEncoder(nn.Module):
"""Multi-Head Convolutional self-Attention Encoder comprised of `MHCA`
blocks."""
def __init__(
self,
dim,
num_layers=1,
num_heads=8,
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type='BiasFree',
qk_norm=1
):
super().__init__()
self.num_layers = num_layers
self.MHCA_layers = nn.ModuleList([
TransformerBlock(
dim,
num_heads=num_heads,
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
qk_norm=qk_norm
) for idx in range(self.num_layers)
])
def forward(self, x, size):
"""foward function"""
H, W = size
B = x.shape
# return x's shape : ->
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
for layer in self.MHCA_layers:
x = layer(x)
return x
class ResBlock(nn.Module):
"""Residual block for convolutional local feature."""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.Hardswish,
norm_layer=nn.BatchNorm2d,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
# self.act0 = act_layer()
self.conv1 = Conv2d_BN(in_features,
hidden_features,
act_layer=act_layer)
self.dwconv = nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=False,
groups=hidden_features,
)
# self.norm = norm_layer(hidden_features)
self.act = act_layer()
self.conv2 = Conv2d_BN(hidden_features, out_features)
self.apply(self._init_weights)
def _init_weights(self, m):
"""
initialization
"""
if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
"""foward function"""
identity = x
# x=self.act0(x)
feat = self.conv1(x)
feat = self.dwconv(feat)
# feat = self.norm(feat)
feat = self.act(feat)
feat = self.conv2(feat)
return identity + feat
class MHCA_stage(nn.Module):
"""Multi-Head Convolutional self-Attention stage comprised of `MHCAEncoder`
layers."""
def __init__(
self,
embed_dim,
out_embed_dim,
num_layers=1,
num_heads=8,
ffn_expansion_factor=2.66,
num_path=4,
bias=False,
LayerNorm_type='BiasFree',
qk_norm=1
):
super().__init__()
self.mhca_blks = nn.ModuleList([
MHCAEncoder(
embed_dim,
num_layers,
num_heads,
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
qk_norm=qk_norm
) for _ in range(num_path)
])
self.aggregate = SKFF(embed_dim, height=num_path)
# self.InvRes = ResBlock(in_features=embed_dim, out_features=embed_dim)
# self.aggregate = Conv2d_aggregate(embed_dim * (num_path + 1),
# out_embed_dim,
# act_layer=nn.Hardswish)
def forward(self, inputs):
"""foward function"""
# att_outputs = )]
att_outputs = []
for x, encoder in zip(inputs, self.mhca_blks):
# ->
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2).contiguous()
att_outputs.append(encoder(x, size=(H, W)))
# out_concat = torch.cat(att_outputs, dim=1)
out = self.aggregate(att_outputs)
return out
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class Conv2d_BN(nn.Module):
def __init__(
self,
in_ch,
out_ch,
kernel_size=1,
stride=1,
pad=0,
dilation=1,
groups=1,
bn_weight_init=1,
norm_layer=nn.BatchNorm2d,
act_layer=None,
):
super().__init__()
self.conv = torch.nn.Conv2d(in_ch,
out_ch,
kernel_size,
stride,
pad,
dilation,
groups,
bias=False)
# self.bn = norm_layer(out_ch)
# torch.nn.init.constant_(self.bn.weight, bn_weight_init)
# torch.nn.init.constant_(self.bn.bias, 0)
for m in self.modules():
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
self.act_layer = act_layer() if act_layer is not None else nn.Identity()
def forward(self, x):
x = self.conv(x)
# x = self.bn(x)
x = self.act_layer(x)
return x
class SKFF(nn.Module):
def __init__(self, in_channels, height=2, reduction=8, bias=False):
super(SKFF, self).__init__()
self.height = height
d = max(int(in_channels / reduction), 4)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())
self.fcs = nn.ModuleList([])
for i in range(self.height):
self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))
self.softmax = nn.Softmax(dim=1)
def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
feats_U = torch.sum(inp_feats, dim=1)
feats_S = self.avg_pool(feats_U)
feats_Z = self.conv_du(feats_S)
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
attention_vectors = torch.cat(attention_vectors, dim=1)
attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
# stx()
attention_vectors = self.softmax(attention_vectors)
feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
return feats_V
class DWConv2d_BN(nn.Module):
def __init__(
self,
in_ch,
out_ch,
kernel_size=1,
stride=1,
norm_layer=nn.BatchNorm2d,
act_layer=nn.Hardswish,
bn_weight_init=1,
offset_clamp=(-1, 1)
):
super().__init__()
# dw
# self.conv=torch.nn.Conv2d(in_ch,out_ch,kernel_size,stride,(kernel_size - 1) // 2,bias=False,)
# self.mask_generator = nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3,
# stride=1, padding=1, bias=False, groups=in_ch),
# nn.Conv2d(in_channels=in_ch, out_channels=9,
# kernel_size=1,
# stride=1, padding=0, bias=False)
# )
self.offset_clamp = offset_clamp
self.offset_generator = nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3,
stride=1, padding=1, bias=False, groups=in_ch),
nn.Conv2d(in_channels=in_ch, out_channels=18,
kernel_size=1,
stride=1, padding=0, bias=False)
)
self.dcn = DeformConv2d(
in_channels=in_ch,
out_channels=in_ch,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=in_ch
)# .cuda(7)
self.pwconv = nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)
# self.bn = norm_layer(out_ch)
self.act = act_layer() if act_layer is not None else nn.Identity()
for m in self.modules():
if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
if m.bias is not None:
m.bias.data.zero_()
# print(m)
# elif isinstance(m, nn.BatchNorm2d):
# m.weight.data.fill_(bn_weight_init)
# m.bias.data.zero_()
def forward(self, x):
# x=self.conv(x)
# x = self.bn(x)
# x = self.act(x)
# mask= torch.sigmoid(self.mask_generator(x))
# print('1')
offset = self.offset_generator(x)
# print('2')
        if self.offset_clamp:
            offset = torch.clamp(offset, min=self.offset_clamp[0], max=self.offset_clamp[1])
# print(offset)
# print('3')
# x=x.cuda(7)
x = self.dcn(x, offset)
# x=x.cpu()
# print('4')
x = self.pwconv(x)
# print('5')
# x = self.bn(x)
x = self.act(x)
return x
class DWCPatchEmbed(nn.Module):
"""Depthwise Convolutional Patch Embedding layer Image to Patch
Embedding."""
def __init__(self,
in_chans=3,
embed_dim=768,
patch_size=16,
stride=1,
idx=0,
act_layer=nn.Hardswish,
offset_clamp=(-1, 1)):
super().__init__()
self.patch_conv = DWConv2d_BN(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=stride,
act_layer=act_layer,
offset_clamp=offset_clamp
)
"""
self.patch_conv = DWConv2d_BN(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=stride,
act_layer=act_layer,
)
"""
def forward(self, x):
"""foward function"""
x = self.patch_conv(x)
return x
class Patch_Embed_stage(nn.Module):
"""Depthwise Convolutional Patch Embedding stage comprised of
`DWCPatchEmbed` layers."""
def __init__(self, in_chans, embed_dim, num_path=4, isPool=False, offset_clamp=(-1, 1)):
super(Patch_Embed_stage, self).__init__()
self.patch_embeds = nn.ModuleList([
DWCPatchEmbed(
in_chans=in_chans if idx == 0 else embed_dim,
embed_dim=embed_dim,
patch_size=3,
stride=1,
idx=idx,
offset_clamp=offset_clamp
) for idx in range(num_path)
])
def forward(self, x):
"""foward function"""
att_inputs = []
for pe in self.patch_embeds:
x = pe(x)
att_inputs.append(x)
return att_inputs
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
# self.proj_dw = nn.Conv2d(in_c, in_c, kernel_size=3, stride=1, padding=1,groups=in_c, bias=bias)
# self.proj_pw = nn.Conv2d(in_c, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
# self.bn=nn.BatchNorm2d(embed_dim)
# self.act=nn.Hardswish()
def forward(self, x):
x = self.proj(x)
# x = self.proj_dw(x)
# x= self.proj_pw(x)
# x=self.bn(x)
# x=self.act(x)
return x
##########################################################################
## Resizing modules
class Downsample(nn.Module):
def __init__(self, input_feat, out_feat):
super(Downsample, self).__init__()
self.body = nn.Sequential(# nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
# dw
nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False, ),
# pw-linear
nn.Conv2d(input_feat, out_feat // 4, 1, 1, 0, bias=False),
# nn.BatchNorm2d(n_feat // 2),
# nn.Hardswish(),
nn.PixelUnshuffle(2))
def forward(self, x):
return self.body(x)
class Upsample(nn.Module):
def __init__(self, input_feat, out_feat):
super(Upsample, self).__init__()
self.body = nn.Sequential(# nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
# dw
nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False, ),
# pw-linear
nn.Conv2d(input_feat, out_feat * 4, 1, 1, 0, bias=False),
# nn.BatchNorm2d(n_feat*2),
# nn.Hardswish(),
nn.PixelShuffle(2))
def forward(self, x):
return self.body(x)
##########################################################################
##---------- MB-TaylorFormer -----------------------
class MB_TaylorFormer(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 dim=[24, 48, 72, 96],      # per-stage channel widths (defaults reconstructed from the official repo)
                 num_blocks=[2, 3, 3, 4],   # Transformer blocks per stage
                 heads=[1, 2, 4, 8],        # attention heads per stage
                 bias=False,
                 dual_pixel_task=True,      ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
                 num_path=[2, 2, 2, 2],     # parallel multi-scale branches per stage
                 qk_norm=1,
                 offset_clamp=(-1, 1)
                 ):
        super(MB_TaylorFormer, self).__init__()
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim[0])
        self.patch_embed_encoder_level1 = Patch_Embed_stage(dim[0], dim[0], num_path=num_path[0], isPool=False,
                                                            offset_clamp=offset_clamp)
        self.encoder_level1 = MHCA_stage(dim[0], dim[0], num_layers=num_blocks[0], num_heads=heads[0],
                                         ffn_expansion_factor=2.66, num_path=num_path[0],
                                         bias=False, LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.down1_2 = Downsample(dim[0], dim[1])  ## From Level 1 to Level 2
        self.patch_embed_encoder_level2 = Patch_Embed_stage(dim[1], dim[1], num_path=num_path[1], isPool=False,
                                                            offset_clamp=offset_clamp)
        self.encoder_level2 = MHCA_stage(dim[1], dim[1], num_layers=num_blocks[1], num_heads=heads[1],
                                         ffn_expansion_factor=2.66,
                                         num_path=num_path[1], bias=False, LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.down2_3 = Downsample(dim[1], dim[2])  ## From Level 2 to Level 3
        self.patch_embed_encoder_level3 = Patch_Embed_stage(dim[2], dim[2], num_path=num_path[2],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.encoder_level3 = MHCA_stage(dim[2], dim[2], num_layers=num_blocks[2], num_heads=heads[2],
                                         ffn_expansion_factor=2.66,
                                         num_path=num_path[2], bias=False, LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.down3_4 = Downsample(dim[2], dim[3])  ## From Level 3 to Level 4
        self.patch_embed_latent = Patch_Embed_stage(dim[3], dim[3], num_path=num_path[3],
                                                    isPool=False, offset_clamp=offset_clamp)
        self.latent = MHCA_stage(dim[3], dim[3], num_layers=num_blocks[3], num_heads=heads[3],
                                 ffn_expansion_factor=2.66, num_path=num_path[3], bias=False,
                                 LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.up4_3 = Upsample(int(dim[3]), dim[2])  ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Sequential(
            nn.Conv2d(dim[2] * 2, dim[2], 1, 1, 0, bias=bias),
        )
        self.patch_embed_decoder_level3 = Patch_Embed_stage(dim[2], dim[2], num_path=num_path[2],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.decoder_level3 = MHCA_stage(dim[2], dim[2], num_layers=num_blocks[2], num_heads=heads[2],
                                         ffn_expansion_factor=2.66, num_path=num_path[2], bias=False,
                                         LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.up3_2 = Upsample(int(dim[2]), dim[1])  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Sequential(
            nn.Conv2d(dim[1] * 2, dim[1], 1, 1, 0, bias=bias),
        )
        self.patch_embed_decoder_level2 = Patch_Embed_stage(dim[1], dim[1], num_path=num_path[1],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.decoder_level2 = MHCA_stage(dim[1], dim[1], num_layers=num_blocks[1], num_heads=heads[1],
                                         ffn_expansion_factor=2.66, num_path=num_path[1], bias=False,
                                         LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.up2_1 = Upsample(int(dim[1]), dim[0])  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
        self.patch_embed_decoder_level1 = Patch_Embed_stage(dim[0] * 2, dim[0] * 2, num_path=num_path[0],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.decoder_level1 = MHCA_stage(dim[0] * 2, dim[0] * 2, num_layers=num_blocks[0], num_heads=heads[0],
                                         ffn_expansion_factor=2.66, num_path=num_path[0], bias=False,
                                         LayerNorm_type='BiasFree', qk_norm=qk_norm)
        self.patch_embed_refinement = Patch_Embed_stage(dim[0] * 2, dim[0] * 2, num_path=num_path[0],
                                                        isPool=False, offset_clamp=offset_clamp)
        self.refinement = MHCA_stage(dim[0] * 2, dim[0] * 2, num_layers=num_blocks[0], num_heads=heads[0],
                                     ffn_expansion_factor=2.66, num_path=num_path[0], bias=False,
                                     LayerNorm_type='BiasFree', qk_norm=qk_norm)
        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim[0], int(dim[0] * 2), kernel_size=1, bias=bias)
        ###########################
        self.output = nn.Sequential(
            nn.Conv2d(int(dim[0] * 2), 3, kernel_size=3, stride=1, padding=1, bias=False),
        )
def forward(self, inp_img):
inp_enc_level1 = self.patch_embed(inp_img)
inp_enc_level1_list = self.patch_embed_encoder_level1(inp_enc_level1)
out_enc_level1 = self.encoder_level1(inp_enc_level1_list) + inp_enc_level1
# out_enc_level1 = self.encoder_level1(inp_enc_level1_list)
inp_enc_level2 = self.down1_2(out_enc_level1)
inp_enc_level2_list = self.patch_embed_encoder_level2(inp_enc_level2)
out_enc_level2 = self.encoder_level2(inp_enc_level2_list) + inp_enc_level2
inp_enc_level3 = self.down2_3(out_enc_level2)
inp_enc_level3_list = self.patch_embed_encoder_level3(inp_enc_level3)
out_enc_level3 = self.encoder_level3(inp_enc_level3_list) + inp_enc_level3
inp_enc_level4 = self.down3_4(out_enc_level3)
inp_latent = self.patch_embed_latent(inp_enc_level4)
latent = self.latent(inp_latent) + inp_enc_level4
inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
inp_dec_level3_list = self.patch_embed_decoder_level3(inp_dec_level3)
out_dec_level3 = self.decoder_level3(inp_dec_level3_list) + inp_dec_level3
inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
inp_dec_level2_list = self.patch_embed_decoder_level2(inp_dec_level2)
out_dec_level2 = self.decoder_level2(inp_dec_level2_list) + inp_dec_level2
inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
inp_dec_level1_list = self.patch_embed_decoder_level1(inp_dec_level1)
out_dec_level1 = self.decoder_level1(inp_dec_level1_list) + inp_dec_level1
inp_latent_list = self.patch_embed_refinement(out_dec_level1)
out_dec_level1 = self.refinement(inp_latent_list) + out_dec_level1
# nn.Hardswish()
#### For Dual-Pixel Defocus Deblurring Task ####
if self.dual_pixel_task:
out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
out_dec_level1 = self.output(out_dec_level1)
###########################
else:
out_dec_level1 = self.output(out_dec_level1) + inp_img
return out_dec_level1
def count_param(model):
param_count = 0
for param in model.parameters():
        param_count += param.numel()
return param_count
if __name__ == "__main__":
from thop import profile
model = MB_TaylorFormer()
model.eval()
print("params", count_param(model))
inputs = torch.randn(1, 3, 640, 640)
output = model(inputs)
    print(output.size())
2.2 Step 2
Under Addmodules, create a new py file named '__init__.py', then add the following code inside it:
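The original post shows this as a screenshot; a minimal version, assuming you kept the file name TaylorFormer.py from Step 1, is:
from .TaylorFormer import *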
2.3 Step 3
Do the import in tasks.py:
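The screenshot shows the import being added to 'ultralytics/nn/tasks.py'. A sketch of the usual pattern (the exact location varies across ultralytics versions, and the parse_model branch below is an assumption, not the author's exact code):
# ultralytics/nn/tasks.py, near the other model imports:
from ultralytics.nn.Addmodules import MB_TaylorFormer  # register the new module

# Inside parse_model(), a branch along these lines is typically needed so the
# channel bookkeeping stays correct (hypothetical sketch):
#     elif m is MB_TaylorFormer:
#         c2 = 3  # the dehazing module returns a restored 3-channel image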
With that, the module is registered. Copy the yaml file below and it should run directly.
The yaml file
# Ultralytics YOLO
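Only the first comment line of the author's yaml survived above, so the following is an illustrative sketch (an assumption, not the original file): a YOLOv8-style config that runs MB_TaylorFormer as a dehazing front-end before the standard backbone. Note that every layer index after layer 0 shifts by one relative to the stock yolov8.yaml.
nc: 80  # number of classes
backbone:
  # [from, repeats, module, args]
  - [-1, 1, MB_TaylorFormer, []]  # 0: dehazing front-end, 3 -> 3 channels
  - [-1, 1, Conv, [64, 3, 2]]  # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 10
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 7], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)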