CV05_深度学习模块之间的缝合教学(1)
1.1 在哪里缝测试文件?(×)
练习文件?(×)
模型文件?(√)
1.2 骨干网络与模块缝合
以Vision Transformer为例,模型文件里有很多类,我们只在末了集大成的那个类里添加模块。
https://i-blog.csdnimg.cn/direct/a9e2a10d87204884ab197a1f4d1beb2b.png
之后后,我们准备好我们要缝合的模块,比如SE Net模块,我们先创建一个测试文件测试可否跑通
import numpy as np
import torch
from torch import nn
from torch.nn import init
class SEAttention(nn.Module):
# 初始化SE模块,channel为通道数,reduction为降维比率
def __init__(self, channel=512, reduction=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)# 自适应平均池化层,将特征图的空间维度压缩为1x1
self.fc = nn.Sequential(# 定义两个全连接层作为激励操作,通过降维和升维调整通道重要性
nn.Linear(channel, channel // reduction, bias=False),# 降维,减少参数数量和计算量
nn.ReLU(inplace=True),# ReLU激活函数,引入非线性
nn.Linear(channel // reduction, channel, bias=False),# 升维,恢复到原始通道数
nn.Sigmoid()# Sigmoid激活函数,输出每个通道的重要性系数
)
# 权重初始化方法
def init_weights(self):
for m in self.modules():# 遍历模块中的所有子模块
if isinstance(m, nn.Conv2d):# 对于卷积层
init.kaiming_normal_(m.weight, mode='fan_out')# 使用Kaiming初始化方法初始化权重
if m.bias is not None:
init.constant_(m.bias, 0)# 如果有偏置项,则初始化为0
elif isinstance(m, nn.BatchNorm2d):# 对于批归一化层
init.constant_(m.weight, 1)# 权重初始化为1
init.constant_(m.bias, 0)# 偏置初始化为0
elif isinstance(m, nn.Linear):# 对于全连接层
init.normal_(m.weight, std=0.001)# 权重使用正态分布初始化
if m.bias is not None:
init.constant_(m.bias, 0)# 偏置初始化为0
# 前向传播方法
def forward(self, x):
b, c, _, _ = x.size()# 获取输入x的批量大小b和通道数c
y = self.avg_pool(x).view(b, c)# 通过自适应平均池化层后,调整形状以匹配全连接层的输入
y = self.fc(y).view(b, c, 1, 1)# 通过全连接层计算通道重要性,调整形状以匹配原始特征图的形状
return x * y.expand_as(x)# 将通道重要性系数应用到原始特征图上,进行特征重新校准
# 示例使用
if __name__ == '__main__':
input = torch.randn(50, 512, 7, 7)# 随机生成一个输入特征图
se = SEAttention(channel=512, reduction=8)# 实例化SE模块,设置降维比率为8
output = se(input)# 将输入特征图通过SE模块进行处理
print(output.shape)# 打印处理后的特征图形状,验证SE模块的作用 https://i-blog.csdnimg.cn/direct/bf62d13bc19a4c26937a56f0763ff907.png
打印处置处罚后的形状,我们这里要留意,缝合模块时只必要留意第一维,也就是这个channel,要和骨干网络保持一致,只要你把输入输出的通道数对齐,那么这个通道数就可以缝合乐成。
把模块复制进骨干网络中:
https://i-blog.csdnimg.cn/direct/aa2797d1055b4521ab5548c4e3cbd889.png
然后举行缝合,在缝合之前要先测试通道是否匹配,否则肯定报错。
怎样验证通道数
我们找到骨干网络前向传播的部门,在你想加入这个模块地方print(x.shape)即可。运行练习文件:
放在最前面:
https://i-blog.csdnimg.cn/direct/0e529589c50b40ab983b4344a0da41e4.png
https://i-blog.csdnimg.cn/direct/cffdf786040c418881a42348a4cee767.png
通道数为3(8为batch size)。
将模块添加进骨干网络
在骨干网络的init函数下添加:(ctrl+p可查看参数)通道数与之前查的对齐。
https://i-blog.csdnimg.cn/direct/ba72d2a5f3674c109d260fdfdf764f66.png
在前向传播中添加:
https://i-blog.csdnimg.cn/direct/ca47f6a6358c4624bc88c657084acf53.png
看看是否正常运行:
https://i-blog.csdnimg.cn/direct/7ef768805b0d4a389954570ceb43907d.png
正常运行,阐明模块缝合乐成!
打印缝合后的模型结构
该操作在模型文件中举行。
https://i-blog.csdnimg.cn/direct/2839763ff406461ea08b98d6d423cff8.png
VisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
(norm): Identity()
)
(pos_drop): Dropout(p=0.0, inplace=False)
(blocks): Sequential(
(0): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(4): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(5): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(6): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(7): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(8): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(9): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(10): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(11): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(pre_logits): Sequential(
(fc): Linear(in_features=768, out_features=768, bias=True)
(act): Tanh()
)
(head): Linear(in_features=768, out_features=21843, bias=True)
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=3, out_features=0, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=0, out_features=3, bias=False)
(3): Sigmoid()
)
)
)
我们可以看到多了一个SEAttention,阐明模块缝合进去了!
https://i-blog.csdnimg.cn/direct/5acbf0d85a8c462594d07a7431d5647b.png
1.3 模块之间缝合
以SENet和ECA模块为例。
https://i-blog.csdnimg.cn/direct/410d03295d3947a48359c2ad80e840fd.png
串联模块
方式1
同1.2。照猫画虎。(留意通道数保持一致)
https://i-blog.csdnimg.cn/direct/e45d848547124d23870c6f6dada65781.pnghttps://i-blog.csdnimg.cn/direct/fd5c72d8f3db4f36b17cc650231dd2e0.png
打印模型结构:
ECAAttention(
(gap): AdaptiveAvgPool2d(output_size=1)
(conv): Conv1d(1, 1, kernel_size=(3,), stride=(1,), padding=(1,))
(sigmoid): Sigmoid()
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=64, out_features=4, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=4, out_features=64, bias=False)
(3): Sigmoid()
)))
方式2
我们界说一个串联函数,将模块之间串联起来:
https://i-blog.csdnimg.cn/direct/469b0dc4b0c04bc2847990e52db75987.png
实例化查看一下模型结构
https://i-blog.csdnimg.cn/direct/65442771359f4213b703c646be0782a8.png
输出结果:
torch.Size() torch.Size()
Cascade(
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=63, out_features=3, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=3, out_features=63, bias=False)
(3): Sigmoid()
)
)
(eca): ECAAttention(
(gap): AdaptiveAvgPool2d(output_size=1)
(conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
(sigmoid): Sigmoid()
)
)
并联模块
对于并联模块,方法有很多种,两个两个模块输出的张量可以:
(1)逐元素相加(2)逐元素相乘(3)concat拼接(4)等等
https://i-blog.csdnimg.cn/direct/d92bd1728b8049fd9482b57265ebcc0a.png
输出结果:
torch.Size() torch.Size()
Cascade(
(se): SEAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(fc): Sequential(
(0): Linear(in_features=63, out_features=3, bias=False)
(1): ReLU(inplace=True)
(2): Linear(in_features=3, out_features=63, bias=False)
(3): Sigmoid()
)
)
(eca): ECAAttention(
(gap): AdaptiveAvgPool2d(output_size=1)
(conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
(sigmoid): Sigmoid()
)
)
1.4 思考
我们不要拘泥于只串联获并联,可以将二者联合,多个模块中,部门模块并联后又与其他模块串联,等等。。这种排列组合之后,总会有一个你想要的模型!!!
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]