Swin Transformer模子详解(附pytorch实现)

打印 上一主题 下一主题

主题 1048|帖子 1048|积分 3144

写在前面

Swin Transformer(Shifted Window Transformer)是一种新颖的视觉Transformer模子,在2021年由微软亚洲研究院提出。这一模子提出了一种基于局部窗口的自注意力机制,明显改善了Vision Transformer(ViT)在处理高分辨率图像时的性能,尤其是在图像分类、物体检测等计算机视觉任务中表现出色。
Swin Transformer的最大创新之一是其引入了“平移窗口”机制,降服了传统自注意力方法在大图像处理时计算资源斲丧过大的问题。这一机制使得模子可以或许在不同层次上以局部的方式计算自注意力,同时保持全局信息的处理能力。
在本文中,我们将通过具体的分析,介绍Swin Transformer的模子布局、核心头脑及实在现,最后提供一个基于PyTorch的简单实现。
论文地址:https://arxiv.org/pdf/2103.14030
官方代码实现:https://github.com/microsoft/Swin-Transformer
Swin网络布局

如下图所示,Swin Transformer的Encoder接纳分层的方式,通过多个阶段(Stage)渐渐镌汰特性图的分辨率,同时增加特性维度。每个Stage包罗多少个Transformer Block。

每个Block通常由以下几个部分组成:



  • Window-based Self-Attention:每个Block使用窗口自注意力机制,在每个窗口内计算自注意力。这种方式镌汰了计算量,因为自注意力只在局部窗口内进行计算,而不是整个图像。
  • Shifted Window:为了增强不同窗口之间的接洽,Swin Transformer在每一层的Block中接纳了“窗口位移”策略。每一层中的窗口会偏移肯定的步长,使得窗口之间的重叠地区增加,从而促进信息交换。
Patch Partition

Patch Partition 是将输入图像分割成固定大小的块(patch)并将其映射到高维空间的操作。就相称于是VIT模子当中的 Patch Embedding。
  1. from functools import partial
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from pyzjr.utils.FormatConver import to_2tuple
  6. from pyzjr.nn.models.bricks.drop import DropPath
  7. LayerNorm = partial(nn.LayerNorm, eps=1e-6)
  8. class PatchPartition(nn.Module):
  9.     def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
  10.         super().__init__()
  11.         self.patch_size = to_2tuple(patch_size)
  12.         self.embed_dim = embed_dim
  13.         self.proj = nn.Conv2d(in_channels, self.embed_dim,
  14.                               kernel_size=self.patch_size, stride=self.patch_size)
  15.         self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()
  16.     def forward(self, x):
  17.         _, _, H, W = x.shape
  18.         if H % self.patch_size[0] != 0:
  19.             pad_h = self.patch_size[0] - H % self.patch_size[0]
  20.             x = F.pad(x, (0, 0, 0, pad_h))
  21.         if W % self.patch_size[1] != 0:
  22.             pad_w = self.patch_size[1] - W % self.patch_size[1]
  23.             x = F.pad(x, (0, pad_w, 0, 0))
  24.         x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]
  25.         Wh, Ww = x.shape[2:]
  26.         x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]
  27.         # Linear Embedding
  28.         x = self.norm(x)
  29.         # x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
  30.         return x, Wh, Ww
  31. if __name__=="__main__":
  32.     batch_size = 1
  33.     in_channels = 3
  34.     height, width = 30, 32
  35.     patch_size = 4
  36.     embed_dim = 96
  37.     x = torch.randn(batch_size, in_channels, height, width)
  38.     patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
  39.     output,_ ,_ = patch_partition(x)
  40.     print(f"Output shape: {output.shape}")
复制代码
Patch Merging

PatchMerging 这一层用于将输入的特性图进行下采样,雷同于卷积神经网络中的池化层。

假如图像的高度或宽度是奇数,PatchMerging 会进行添补,使得其变为偶数。这是因为下采样操作必要将图像分割为以2为步长的地区。假如图像的高度或宽度是奇数,直接进行切片会导致不匀称的分割,因此必要添补以保证每个块的大小一致。
这里我们在吧如上图的相同颜色块提取并进行拼接,沿着通道维度合并成一个更大的特性,将合并后的张量重新调解形状,新的空间分辨率是原来的一半(H/2 和 W/2)。
  1. class PatchMerging(nn.Module):
  2.     def __init__(self, dim, norm_layer=LayerNorm):
  3.         super().__init__()
  4.         self.dim = dim
  5.         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  6.         self.norm = norm_layer(4 * dim)
  7.     def forward(self, x, H, W):
  8.         """
  9.         Args:
  10.             x: Input feature, tensor size (B, H*W, C).
  11.             H, W: Spatial resolution of the input feature.
  12.         """
  13.         B, L, C = x.shape
  14.         assert L == H * W, "input feature has wrong size"
  15.         x = x.view(B, H, W, C)
  16.         if H % 2 == 1 or W % 2 == 1:
  17.             x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  18.         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
  19.         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
  20.         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
  21.         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
  22.         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
  23.         x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
  24.         x = self.norm(x)
  25.         x = self.reduction(x)
  26.         return x
  27. if __name__=="__main__":
  28.     batch_size = 1
  29.     in_channels = 3
  30.     height, width = 30, 32
  31.     patch_size = 4
  32.     embed_dim = 96
  33.     x = torch.randn(batch_size, in_channels, height, width)
  34.     patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
  35.     output, Wh, Ww = patch_partition(x)
  36.     patch_merging = PatchMerging(dim=embed_dim)
  37.     output = patch_merging(output, Wh, Ww)
  38.     print(output.shape)
复制代码
在代码中呢就是在高和宽的维度通过切片的形式获得,x0表示的是左上角,x1表示的是右上角,x2表示的是左下角,x3表示的是右下角。经过一系列操作后,最后通过线性层实现通道数翻倍。
W-MSA

W-MSA(Window-based Multi-Head Self-Attention)是Swin Transformer中的一个核心创新,它是为了优化传统自注意力机制在高分辨率输入图像处理中的效率问题而提出的。


这是原论文当中给出的计算公式,h,w和C分别表示特性的高度,宽度和深度,M表示窗口的大小。在标准的 Transformer 模子中,自注意力机制必要对整个输入进行计算,这使得计算和内存的斲丧随着输入的增大而急剧增长。而在图像任务中,输入图像往往具有非常高的分辨率,因此直接应用标准的全局自注意力在计算上不可行。
W-MSA 通过在局部窗口内进行自注意力计算来解决这一问题,极大地镌汰了计算和内存开销,同时保持了模子的表示能力。
SW-MSA

SW-MSA (Shifted Window-based Multi-Head Self-Attention)结合了局部窗口化自注意力和窗口偏移(shifted)策略,既提升了计算效率,又能在捕捉局部信息的底子上,保持对全局信息的建模能力。

左侧就是刚刚说到的W-MSA,经过窗口的偏移酿成了右边的SW-MSA,偏移的策略可以或许让模子在每一层的计算中捕捉到不同窗口之间的依赖关系,制止了 W-MSA 只能在单一窗口内计算的局限。这样,相邻窗口之间的信息就可以或许通过偏移和交错的方式进行交换,增强了模子的全局感知能力。
但是,现在的窗口从原来的四个酿成了九个,假如对每一个窗口再进行W-MSA那就太麻烦了。为了应对这种情况,作者提出了一种 高效批处理计算方法,旨在优化窗口偏移后的大规模窗口计算。其核心头脑是:通过批处理计算的方式来有效地处理这些偏移后的窗口,而不是每个窗口单独计算。

意思就是说将图中的A,B,C的位置通过偏移和交错方式变化后,可以将这些窗口的计算同一进行批处理,而不是一个一个地处理。这样可以明显镌汰计算时间和内存占用。
这个过程我个人感觉比力像是卡诺图,具体的过程可以看我下面画的图:

然后这里的4还和原来的一样,5和3组合成一个窗口,1和7组合成一个窗口,8、2、6、0又组合成一个窗口,这样就和原来一样是4个4x4的窗口了,保证了计算量的不变。但是假如这样做了就会将不相邻的信息混合在一起了。作者这里接纳掩蔽机制将自注意力计算限定在每个子窗口内,实在就是创建一个蒙板来屏蔽信息。
Relative Position Bias

关于这一部分,作者没有怎么提,只是经过了相对位置偏移,指标有明显的提示。

关于这一部分,我是参考的官方代码以及b站的讲解视频理解的。首先必要创建一个相对位置偏置的参数表,它的范围是从[-Wh+1, Wh-1],这里的 +1 和 -1 是因为偏移量是相对于当前元素的位置而言的,当前元素自身的偏移量为0,但我们不包罗0在偏移量的计算中(因为0表示没有偏移,通常会在自注意力机制中以其他方式处理)。因此,对于垂直方向(或水平方向),总的偏移量数量是 win_h(或 win_w)的正偏移量数量加上 win_h(或 win_w)的负偏移量数量,再减去一个(因为我们不计算0偏移量)。因此,相对位置偏置表的尺寸为:
   [(2 * Wh - 1) * (2 * Ww - 1), num_heads]
  每个元素的查询(Query)和键(Key)之间的内积会得到一个相似度分数,在这些分数的底子上,会加入相对位置偏置,调解相似度:
   Attention = softmax((QK^T + Relative_Position_Bias) / sqrt(d_k))
  其中,Q 是查询向量,K 是键向量,Relative_Position_Bias 是根据相对位置计算得到的偏置。加入相对位置偏置后,模子可以更好地捕捉到局部布局的依赖关系。
网络实现


  1. """
  2. Copyright (c) 2025, Auorui.
  3. All rights reserved.
  4. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  5.     <https://arxiv.org/pdf/2103.14030>
  6. use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.py
  7.                    https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
  8. """
  9. from functools import partial
  10. import torch
  11. import numpy as np
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from pyzjr.utils.FormatConver import to_2tuple
  15. from pyzjr.nn.models.bricks.drop import DropPath
  16. from pyzjr.nn.models.bricks.initer import trunc_normal_
  17. LayerNorm = partial(nn.LayerNorm, eps=1e-6)
  18. class PatchPartition(nn.Module):
  19.     def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
  20.         super().__init__()
  21.         self.patch_size = to_2tuple(patch_size)
  22.         self.embed_dim = embed_dim
  23.         self.proj = nn.Conv2d(in_channels, self.embed_dim,
  24.                               kernel_size=self.patch_size, stride=self.patch_size)
  25.         self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()
  26.     def forward(self, x):
  27.         _, _, H, W = x.shape
  28.         if H % self.patch_size[0] != 0:
  29.             pad_h = self.patch_size[0] - H % self.patch_size[0]
  30.             x = F.pad(x, (0, 0, 0, pad_h))
  31.         if W % self.patch_size[1] != 0:
  32.             pad_w = self.patch_size[1] - W % self.patch_size[1]
  33.             x = F.pad(x, (0, pad_w, 0, 0))
  34.         x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]
  35.         Wh, Ww = x.shape[2:]
  36.         x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]
  37.         # Linear Embedding
  38.         x = self.norm(x)
  39.         # x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
  40.         return x, Wh, Ww
  41. class MLP(nn.Module):
  42.     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):
  43.         super().__init__()
  44.         out_features = out_features or in_features
  45.         hidden_features = hidden_features or in_features
  46.         self.fc1 = nn.Linear(in_features, hidden_features)
  47.         self.act = act_layer()
  48.         self.fc2 = nn.Linear(hidden_features, out_features)
  49.         self.drop = nn.Dropout(drop_ratio)
  50.     def forward(self, x):
  51.         x = self.fc1(x)
  52.         x = self.act(x)
  53.         x = self.drop(x)
  54.         x = self.fc2(x)
  55.         x = self.drop(x)
  56.         return x
  57. class PatchMerging(nn.Module):
  58.     def __init__(self, dim, norm_layer=LayerNorm):
  59.         super().__init__()
  60.         self.dim = dim
  61.         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  62.         self.norm = norm_layer(4 * dim)
  63.     def forward(self, x, H, W):
  64.         """
  65.         Args:
  66.             x: Input feature, tensor size (B, H*W, C).
  67.             H, W: Spatial resolution of the input feature.
  68.         """
  69.         B, L, C = x.shape
  70.         assert L == H * W, "input feature has wrong size"
  71.         x = x.view(B, H, W, C)
  72.         if H % 2 == 1 or W % 2 == 1:
  73.             x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  74.         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
  75.         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
  76.         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
  77.         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
  78.         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
  79.         x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
  80.         x = self.norm(x)
  81.         x = self.reduction(x)
  82.         return x
  83. class WindowAttention(nn.Module):
  84.     """
  85.     Window based multi-head self attention (W-MSA) module with relative position bias.
  86.     It supports shifted and non-shifted windows.
  87.     """
  88.     def __init__(
  89.             self,
  90.             dim,
  91.             window_size,
  92.             num_heads,
  93.             qkv_bias=True,
  94.             proj_bias=True,
  95.             attention_dropout_ratio=0.,
  96.             proj_drop=0.,
  97.     ):
  98.         super().__init__()
  99.         self.dim = dim
  100.         self.window_size = to_2tuple(window_size)
  101.         win_h, win_w = self.window_size
  102.         self.num_heads = num_heads
  103.         head_dim = dim // num_heads
  104.         self.scale = head_dim ** -0.5
  105.         # define a parameter table of relative position bias
  106.         self.relative_position_bias_table = nn.Parameter(
  107.             torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)
  108.         )   # [2*Wh-1 * 2*Ww-1, nHeads]   Offset Range: -Wh+1, Wh-1
  109.         self.register_buffer("relative_position_index",
  110.                              self.get_relative_position_index(win_h, win_w), persistent=False)
  111.         trunc_normal_(self.relative_position_bias_table, std=.02)
  112.         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  113.         self.attn_drop = nn.Dropout(attention_dropout_ratio)
  114.         self.proj = nn.Linear(dim, dim, bias=proj_bias)
  115.         self.proj_drop = nn.Dropout(proj_drop)
  116.         self.softmax = nn.Softmax(dim=-1)
  117.     def get_relative_position_index(self, win_h: int, win_w: int):
  118.         # get pair-wise relative position index for each token inside the window
  119.         coords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij'))  # 2, Wh, Ww
  120.         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
  121.         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
  122.         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
  123.         relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
  124.         relative_coords[:, :, 1] += win_w - 1
  125.         relative_coords[:, :, 0] *= 2 * win_w - 1
  126.         return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
  127.     def forward(self, x, mask=None):
  128.         """
  129.         Args:
  130.             x: input features with shape of (num_windows*B, N, C)
  131.             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  132.         """
  133.         B, N, C = x.shape
  134.         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  135.         q, k, v = qkv[:3]
  136.         q = q * self.scale
  137.         attn = (q @ k.transpose(-2, -1))
  138.         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  139.             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
  140.         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
  141.         attn = attn + relative_position_bias.unsqueeze(0)
  142.         if mask is not None:
  143.             nW = mask.shape[0]
  144.             attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  145.             attn = attn.view(-1, self.num_heads, N, N)
  146.             attn = self.softmax(attn)
  147.         else:
  148.             attn = self.softmax(attn)
  149.         attn = self.attn_drop(attn)
  150.         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  151.         x = self.proj(x)
  152.         x = self.proj_drop(x)
  153.         return x
  154. def window_partition(x, window_size: int):
  155.     """
  156.     将feature map按照window_size划分成一个个没有重叠的window
  157.     Args:
  158.         x: (B, H, W, C)
  159.         window_size (int): window size(M)
  160.     Returns:
  161.         windows: (num_windows*B, window_size, window_size, C)
  162.     """
  163.     B, H, W, C = x.shape
  164.     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  165.     # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
  166.     # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
  167.     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  168.     return windows
  169. def window_reverse(windows, window_size: int, H: int, W: int):
  170.     """
  171.     将一个个window还原成一个feature map
  172.     Args:
  173.         windows: (num_windows*B, window_size, window_size, C)
  174.         window_size (int): Window size(M)
  175.         H (int): Height of image
  176.         W (int): Width of image
  177.     Returns:
  178.         x: (B, H, W, C)
  179.     """
  180.     B = int(windows.shape[0] / (H * W / window_size / window_size))
  181.     # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
  182.     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  183.     # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
  184.     # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
  185.     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  186.     return x
  187. class SwinTransformerBlock(nn.Module):
  188.     r""" Swin Transformer Block."""
  189.     mlp_ratio = 4
  190.     def __init__(
  191.         self,
  192.         dim,
  193.         num_heads,
  194.         window_size=7,
  195.         shift_size=0,
  196.         qkv_bias=True,
  197.         proj_bias=True,
  198.         attention_dropout_ratio=0.,
  199.         proj_drop=0.,
  200.         drop_path_ratio=0.,
  201.         norm_layer=LayerNorm,
  202.         act_layer=nn.GELU,
  203.     ):
  204.         super(SwinTransformerBlock, self).__init__()
  205.         self.dim = dim
  206.         self.num_heads = num_heads
  207.         self.window_size = window_size
  208.         self.shift_size = shift_size
  209.         assert 0 <= self.shift_size < window_size, "shift_size must in 0-window_size"
  210.         self.norm1 = norm_layer(dim)
  211.         self.attn = WindowAttention(
  212.             dim,
  213.             window_size=self.window_size,
  214.             num_heads=num_heads,
  215.             qkv_bias=qkv_bias,
  216.             proj_bias=proj_bias,
  217.             attention_dropout_ratio=attention_dropout_ratio,
  218.             proj_drop=proj_drop,
  219.         )
  220.         self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
  221.         self.norm2 = norm_layer(dim)
  222.         mlp_hidden_dim = int(dim * self.mlp_ratio)
  223.         self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_bias)
  224.         self.H = None
  225.         self.W = None
  226.     def forward(self, x, mask_matrix):
  227.         """
  228.         Args:
  229.             x: Input feature, tensor size (B, H*W, C).
  230.             H, W: Spatial resolution of the input feature.
  231.             mask_matrix: Attention mask for cyclic shift.
  232.         """
  233.         B, L, C = x.shape
  234.         H, W = self.H, self.W
  235.         assert L == H * W, "input feature has wrong size"
  236.         shortcut = x
  237.         x = self.norm1(x)
  238.         x = x.view(B, H, W, C)
  239.         # pad feature maps to multiples of window size
  240.         pad_l = pad_t = 0
  241.         pad_r = (self.window_size - W % self.window_size) % self.window_size
  242.         pad_b = (self.window_size - H % self.window_size) % self.window_size
  243.         x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
  244.         _, Hp, Wp, _ = x.shape
  245.         # cyclic shift
  246.         if self.shift_size > 0:
  247.             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  248.             attn_mask = mask_matrix
  249.         else:
  250.             shifted_x = x
  251.             attn_mask = None
  252.         # partition windows
  253.         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
  254.         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
  255.         # W-MSA/SW-MSA
  256.         attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
  257.         # merge windows
  258.         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  259.         shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
  260.         # reverse cyclic shift
  261.         if self.shift_size > 0:
  262.             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  263.         else:
  264.             x = shifted_x
  265.         if pad_r > 0 or pad_b > 0:
  266.             x = x[:, :H, :W, :].contiguous()
  267.         x = x.view(B, H * W, C)
  268.         # FFN
  269.         x = shortcut + self.drop_path(x)
  270.         x = x + self.drop_path(self.mlp(self.norm2(x)))
  271.         return x
  272. class BasicLayer(nn.Module):
  273.     """ A basic Swin Transformer layer for one stage."""
  274.     def __init__(self,
  275.                  dim,
  276.                  num_layers,
  277.                  num_heads,
  278.                  drop_path,
  279.                  window_size=7,
  280.                  qkv_bias=True,
  281.                  proj_bias=True,
  282.                  attention_dropout_ratio=0.,
  283.                  proj_drop=0.,
  284.                  norm_layer=LayerNorm,
  285.                  act_layer=nn.GELU,
  286.                  downsample=None):
  287.         super().__init__()
  288.         self.window_size = window_size
  289.         self.shift_size = window_size // 2
  290.         self.num_layers = num_layers
  291.         # build blocks
  292.         self.blocks = nn.ModuleList([
  293.             SwinTransformerBlock(
  294.                 dim=dim,
  295.                 num_heads=num_heads,
  296.                 window_size=window_size,
  297.                 shift_size=0 if (i % 2 == 0) else window_size // 2,
  298.                 qkv_bias=qkv_bias,
  299.                 proj_bias=proj_bias,
  300.                 attention_dropout_ratio=attention_dropout_ratio,
  301.                 proj_drop=proj_drop,
  302.                 drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,
  303.                 norm_layer=norm_layer,
  304.                 act_layer=act_layer)
  305.             for i in range(num_layers)])
  306.         # patch merging layer
  307.         if downsample is not None:
  308.             self.downsample = downsample(dim=dim, norm_layer=norm_layer)
  309.         else:
  310.             self.downsample = None
  311.     def forward(self, x, H, W):
  312.         """ Forward function.
  313.         Args:
  314.             x: Input feature, tensor size (B, H*W, C).
  315.             H, W: Spatial resolution of the input feature.
  316.         """
  317.         # calculate attention mask for SW-MSA
  318.         Hp = int(np.ceil(H / self.window_size)) * self.window_size
  319.         Wp = int(np.ceil(W / self.window_size)) * self.window_size
  320.         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
  321.         h_slices = (slice(0, -self.window_size),
  322.                     slice(-self.window_size, -self.shift_size),
  323.                     slice(-self.shift_size, None))
  324.         w_slices = (slice(0, -self.window_size),
  325.                     slice(-self.window_size, -self.shift_size),
  326.                     slice(-self.shift_size, None))
  327.         cnt = 0
  328.         for h in h_slices:
  329.             for w in w_slices:
  330.                 img_mask[:, h, w, :] = cnt
  331.                 cnt += 1
  332.         mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
  333.         mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  334.         attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  335.         attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  336.         for blk in self.blocks:
  337.             blk.H, blk.W = H, W
  338.             x = blk(x, attn_mask)
  339.         if self.downsample is not None:
  340.             x = self.downsample(x, H, W)
  341.             H, W = (H + 1) // 2, (W + 1) // 2
  342.         return x, H, W
  343. class SwinTransformer(nn.Module):
  344.     """ Swin Transformer backbone."""
  345.     def __init__(self,
  346.                  patch_size=4,
  347.                  in_channels=3,
  348.                  num_classes=1000,
  349.                  embed_dim=96,
  350.                  depths=(2, 2, 6, 2),
  351.                  num_heads=(3, 6, 12, 24),
  352.                  window_size=7,
  353.                  qkv_bias=True,
  354.                  proj_bias=True,
  355.                  attention_dropout_ratio=0.,
  356.                  proj_drop=0.,
  357.                  drop_path_rate=0.2,
  358.                  norm_layer=LayerNorm,
  359.                  patch_norm=True,
  360.                  ):
  361.         super().__init__()
  362.         self.num_classes = num_classes
  363.         self.num_layers = len(depths)
  364.         self.num_layers = len(depths)
  365.         self.embed_dim = embed_dim
  366.         self.patch_norm = patch_norm
  367.         # stage4输出特征矩阵的channels
  368.         self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
  369.         # split image into non-overlapping patches
  370.         self.patch_embed = PatchPartition(
  371.             patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
  372.             norm_layer=norm_layer if self.patch_norm else None)
  373.         self.pos_drop = nn.Dropout(p=proj_drop)
  374.         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  375.         layers = []
  376.         for i_layer in range(self.num_layers):
  377.             layer = BasicLayer(
  378.                 dim=int(embed_dim * 2 ** i_layer),
  379.                 num_layers=depths[i_layer],
  380.                 num_heads=num_heads[i_layer],
  381.                 window_size=window_size,
  382.                 qkv_bias=qkv_bias,
  383.                 proj_bias=proj_bias,
  384.                 attention_dropout_ratio=attention_dropout_ratio,
  385.                 proj_drop=proj_drop,
  386.                 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  387.                 norm_layer=norm_layer,
  388.                 downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
  389.                 )
  390.             layers.append(layer)
  391.         self.layers = nn.Sequential(*layers)
  392.         self.norm = norm_layer(self.num_features)
  393.         self.avgpool = nn.AdaptiveAvgPool1d(1)
  394.         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  395.         self._initialize_weights()
  396.     def _initialize_weights(self):
  397.         for m in self.modules():
  398.             if isinstance(m, nn.Linear):
  399.                 trunc_normal_(m.weight, std=.02)
  400.                 if isinstance(m, nn.Linear) and m.bias is not None:
  401.                     nn.init.constant_(m.bias, 0)
  402.             elif isinstance(m, nn.LayerNorm):
  403.                 nn.init.constant_(m.bias, 0)
  404.                 nn.init.constant_(m.weight, 1.0)
  405.     def forward(self, x):
  406.         # x: [B, L, C]
  407.         x, H, W = self.patch_embed(x)
  408.         x = self.pos_drop(x)
  409.         for layer in self.layers:
  410.             x, H, W = layer(x, H, W)
  411.         x = self.norm(x)  # [B, L, C]
  412.         x = self.avgpool(x.transpose(1, 2))
  413.         x = torch.flatten(x, 1)
  414.         x = self.head(x)
  415.         return x
  416. def swin_t(num_classes) -> SwinTransformer:
  417.     model = SwinTransformer(in_channels=3,
  418.                             patch_size=4,
  419.                             window_size=7,
  420.                             embed_dim=96,
  421.                             depths=(2, 2, 6, 2),
  422.                             num_heads=(3, 6, 12, 24),
  423.                             num_classes=num_classes)
  424.     return model
  425. def swin_s(num_classes) -> SwinTransformer:
  426.     model = SwinTransformer(in_channels=3,
  427.                             patch_size=4,
  428.                             window_size=7,
  429.                             embed_dim=96,
  430.                             depths=(2, 2, 18, 2),
  431.                             num_heads=(3, 6, 12, 24),
  432.                             num_classes=num_classes)
  433.     return model
  434. def swin_b(num_classes) -> SwinTransformer:
  435.     model = SwinTransformer(in_channels=3,
  436.                             patch_size=4,
  437.                             window_size=7,
  438.                             embed_dim=128,
  439.                             depths=(2, 2, 18, 2),
  440.                             num_heads=(4, 8, 16, 32),
  441.                             num_classes=num_classes)
  442.     return model
  443. def swin_l(num_classes) -> SwinTransformer:
  444.     model = SwinTransformer(in_channels=3,
  445.                             patch_size=4,
  446.                             window_size=7,
  447.                             embed_dim=192,
  448.                             depths=(2, 2, 18, 2),
  449.                             num_heads=(6, 12, 24, 48),
  450.                             num_classes=num_classes)
  451.     return model
  452. if __name__=="__main__":
  453.     import pyzjr
  454.     device = 'cuda' if torch.cuda.is_available() else 'cpu'
  455.     input = torch.ones(2, 3, 224, 224).to(device)
  456.     net = swin_l(num_classes=4)
  457.     net = net.to(device)
  458.     out = net(input)
  459.     print(out)
  460.     print(out.shape)
  461.     pyzjr.summary_1(net, input_size=(3, 224, 224))
  462.     # swin_t Total params: 27,499,108
  463.     # swin_s Total params: 48,792,676
  464.     # swin_b Total params: 86,683,780
  465.     # swin_l Total params: 194,906,308
复制代码
参考文章

Swin-Transformer网络布局详解_swin transformer-CSDN博客
Swin-transformer详解_swin transformer-CSDN博客 
【深度学习】详解 Swin Transformer (SwinT)-CSDN博客 
推荐的视频:12.1 Swin-Transformer网络布局详解_哔哩哔哩_bilibili 

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

圆咕噜咕噜

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表