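Preface
Swin Transformer (Shifted Window Transformer) is a vision Transformer proposed by Microsoft Research Asia in 2021. It computes self-attention within local windows. This markedly improves on the Vision Transformer (ViT) for high-resolution images, and the model performs especially well on computer vision tasks such as image classification and object detection.
One of Swin Transformer's main innovations is the shifted window mechanism. Restricting self-attention to local windows avoids the excessive compute cost that conventional global self-attention incurs on large images. Shifting the window partition between consecutive layers then lets information flow across windows. As a result, the model computes self-attention locally at every level but can still aggregate global information.
In this article we walk through Swin Transformer's architecture, its core ideas and how it is implemented, and finish with a simple PyTorch implementation.
Paper: https://arxiv.org/pdf/2103.14030
Official implementation: https://github.com/microsoft/Swin-Transformer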
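Swin network architecture
As shown in the figure below, the Swin Transformer encoder is hierarchical. It consists of several stages that progressively reduce the resolution of the feature map while increasing the feature dimension. Each stage contains several Transformer blocks.
The blocks rely on the following two mechanisms:
- Window-based self-attention: each block computes self-attention only inside local windows rather than over the whole image. This greatly reduces the amount of computation.
- Shifted windows: to connect neighboring windows, consecutive blocks alternate between a regular window partition and one shifted by half a window (⌊M/2⌋ patches, where M is the window size). Windows in the shifted layout straddle the borders of the previous layout, so information is exchanged between windows.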
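As a rough sketch of how the hierarchy evolves, the pure-Python helper below tracks the token grid and channel count that each stage's blocks operate on. `stage_shapes` is a hypothetical helper, not part of the original code. It assumes the Swin-T defaults used later in this article (patch_size=4, embed_dim=96) and the same rounding as the implementation. The patch partition pads up to a multiple of the patch size, and each patch merging halves H and W (rounding up) and doubles C.

```python
import math

def stage_shapes(img_h, img_w, patch_size=4, embed_dim=96, num_stages=4):
    """Return the (H, W, C) token grid seen by each stage (hypothetical helper)."""
    # Patch partition: pad to a multiple of patch_size, then one token per patch
    h = math.ceil(img_h / patch_size)
    w = math.ceil(img_w / patch_size)
    c = embed_dim
    shapes = []
    for stage in range(num_stages):
        shapes.append((h, w, c))
        if stage < num_stages - 1:
            # Patch merging between stages: halve H and W (odd sizes are padded), double C
            h, w, c = (h + 1) // 2, (w + 1) // 2, c * 2
    return shapes

print(stage_shapes(224, 224))
# [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```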
Patch Partition
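Patch Partition splits the input image into fixed-size patches and maps each patch into a high-dimensional embedding space. It is equivalent to Patch Embedding in ViT. In the code below, it is implemented as a Conv2d whose kernel_size and stride both equal patch_size, and the input is first padded if H or W is not a multiple of patch_size.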
```python
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath

LayerNorm = partial(nn.LayerNorm, eps=1e-6)


class PatchPartition(nn.Module):
    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.embed_dim = embed_dim
        # kernel_size == stride: each output position sees exactly one non-overlapping patch
        self.proj = nn.Conv2d(in_channels, self.embed_dim,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # F.pad starts from the last dim: (W_left, W_right, H_top, H_bottom)
        if H % self.patch_size[0] != 0:
            pad_h = self.patch_size[0] - H % self.patch_size[0]
            x = F.pad(x, (0, 0, 0, pad_h))
        if W % self.patch_size[1] != 0:
            pad_w = self.patch_size[1] - W % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, 0))
        x = self.proj(x)  # [B, embed_dim, H/patch_size, W/patch_size]
        Wh, Ww = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        # Linear Embedding
        x = self.norm(x)
        # x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x, Wh, Ww


if __name__ == "__main__":
    batch_size = 1
    in_channels = 3
    height, width = 30, 32
    patch_size = 4
    embed_dim = 96
    x = torch.randn(batch_size, in_channels, height, width)
    patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
    output, _, _ = patch_partition(x)
    # H = 30 is padded to 32, giving an 8 x 8 grid: torch.Size([1, 64, 96])
    print(f"Output shape: {output.shape}")
```
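To see why a convolution can stand in for "split into patches + linear embedding", here is a small NumPy sketch. The helper names `patchify` and `strided_conv` are hypothetical and only for illustration. It shows that flattening each non-overlapping p×p patch and multiplying by a weight matrix gives the same tokens as a convolution whose kernel_size and stride both equal p.

```python
import numpy as np

def patchify(img, p):
    """Split a (C, H, W) image into non-overlapping p x p patches -> (num_patches, C*p*p)."""
    C, H, W = img.shape
    x = img.reshape(C, H // p, p, W // p, p)
    x = x.transpose(1, 3, 0, 2, 4)  # (H/p, W/p, C, p, p)
    return x.reshape((H // p) * (W // p), C * p * p)

def strided_conv(img, weight):
    """Naive convolution with kernel_size = stride = p; weight has shape (D, C, p, p)."""
    D, _, p, _ = weight.shape
    _, H, W = img.shape
    out = np.zeros((D, H // p, W // p))
    for i in range(H // p):
        for j in range(W // p):
            patch = img[:, i * p:(i + 1) * p, j * p:(j + 1) * p]
            # sum over the channel axis and both kernel axes
            out[:, i, j] = np.tensordot(weight, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out

rng = np.random.default_rng(0)
C, H, W, p, D = 3, 8, 8, 4, 96
img = rng.standard_normal((C, H, W))
weight = rng.standard_normal((D, C, p, p))

# Route 1: flatten patches, then a linear layer (the paper's "Linear Embedding")
tokens_linear = patchify(img, p) @ weight.reshape(D, -1).T   # (4, 96)
# Route 2: strided conv, then flatten(2).transpose(1, 2) as in PatchPartition
tokens_conv = strided_conv(img, weight).reshape(D, -1).T      # (4, 96)
print(np.allclose(tokens_linear, tokens_conv))  # True
```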
Patch Merging
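The PatchMerging layer downsamples the input feature map, similar to a pooling layer in a convolutional neural network.

If the height or width of the feature map is odd, PatchMerging pads it to an even size. Downsampling takes every other row and column (a stride of 2), so an odd size would make the four slices unequal in size. Padding guarantees that all four have the same shape.
As in the figure above, the elements with the same color are extracted and concatenated along the channel dimension into one larger feature (4C channels). The merged tensor is then reshaped, so the new spatial resolution is half of the original (H/2 and W/2).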
```python
# Continues from the previous snippet (imports, LayerNorm, PatchPartition).
class PatchMerging(nn.Module):
    def __init__(self, dim, norm_layer=LayerNorm):
        super().__init__()
        self.dim = dim
        # 4*C concatenated channels -> 2*C output channels
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        # pad to an even H and W; F.pad lists dims from the last one: (C, W, H)
        if H % 2 == 1 or W % 2 == 1:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C, top-left of each 2x2 block
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C, bottom-left
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C, top-right
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C, bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        x = self.norm(x)
        x = self.reduction(x)
        return x


if __name__ == "__main__":
    batch_size = 1
    in_channels = 3
    height, width = 30, 32
    patch_size = 4
    embed_dim = 96
    x = torch.randn(batch_size, in_channels, height, width)
    patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
    output, Wh, Ww = patch_partition(x)
    patch_merging = PatchMerging(dim=embed_dim)
    output = patch_merging(output, Wh, Ww)
    # 8 x 8 tokens -> 4 x 4 tokens, 96 -> 192 channels: torch.Size([1, 16, 192])
    print(output.shape)
```
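In the code, the four parts are obtained by slicing along the height and width dimensions. The tensor layout is (B, H, W, C), so x0 is the top-left element of each 2×2 block, x1 the bottom-left, x2 the top-right and x3 the bottom-right. After concatenation (4C channels) and LayerNorm, a linear layer reduces the channels to 2C, so the number of channels doubles relative to the input.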
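A quick way to check which position each slice picks is to run the same slicing on a tiny labeled grid. The NumPy sketch below drops the batch dimension for readability but otherwise mirrors the slicing and padding in PatchMerging.forward.

```python
import numpy as np

# A single-channel 4 x 4 map laid out as (H, W, C); each value is its row-major position
x = np.arange(16).reshape(4, 4, 1)

x0 = x[0::2, 0::2, :]  # even rows, even cols -> top-left of each 2x2 block
x1 = x[1::2, 0::2, :]  # odd rows,  even cols -> bottom-left
x2 = x[0::2, 1::2, :]  # even rows, odd cols  -> top-right
x3 = x[1::2, 1::2, :]  # odd rows,  odd cols  -> bottom-right
merged = np.concatenate([x0, x1, x2, x3], axis=-1)  # (H/2, W/2, 4*C)

print(merged.shape)  # (2, 2, 4)
print(merged[0, 0])  # [0 4 1 5]: the four pixels of the top-left 2x2 block

# Odd sizes: pad one row/column at the bottom/right first, as F.pad does in PatchMerging
odd = np.arange(15).reshape(3, 5, 1)
padded = np.pad(odd, ((0, odd.shape[0] % 2), (0, odd.shape[1] % 2), (0, 0)))
print(padded.shape)  # (4, 6, 1)
```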
W-MSA
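W-MSA (Window-based Multi-Head Self-Attention) is one of the core innovations of Swin Transformer. It was proposed to address the inefficiency of conventional self-attention on high-resolution input images.
The original paper gives the following computational complexity for global MSA and for W-MSA. Here h, w and C are the height, width and depth (channels) of the feature map, and M is the window size:
$$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C$$
$$\Omega(\mathrm{W\text{-}MSA}) = 4hwC^2 + 2M^2hwC$$
In a standard Transformer, self-attention is computed over the entire input, so compute and memory grow quadratically with the number of tokens hw. Images in vision tasks often have very high resolution, which makes standard global self-attention computationally impractical.
W-MSA solves this by computing self-attention inside local windows. This greatly reduces compute and memory cost (linear in hw for a fixed M) while preserving the model's representational power.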
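Plugging numbers into the two complexity formulas makes the difference concrete. The sketch below simply evaluates them (as in the paper, softmax is ignored). The example sizes assume Swin-T's first stage on a 224×224 input, i.e. a 56×56 token grid with C=96 and M=7.

```python
def msa_complexity(h, w, C):
    """Omega(MSA) = 4hwC^2 + 2(hw)^2 C (global multi-head self-attention)."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def wmsa_complexity(h, w, C, M):
    """Omega(W-MSA) = 4hwC^2 + 2 M^2 hw C (attention inside M x M windows)."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

# Stage-1 feature map of Swin-T on a 224 x 224 image: 56 x 56 tokens, C = 96, M = 7
print(msa_complexity(56, 56, 96))      # 2003828736
print(wmsa_complexity(56, 56, 96, 7))  # 145108992
```

Here the global version costs roughly 14 times more. The ratio keeps growing with resolution, because the first formula grows quadratically in hw and the second only linearly.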
SW-MSA
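SW-MSA (Shifted Window-based Multi-Head Self-Attention) combines windowed local self-attention with a window-shifting strategy. It keeps the computational efficiency of W-MSA and, while capturing local information, retains the ability to model global information.
On the left is the W-MSA just described; shifting the windows gives the SW-MSA on the right. The shift lets each layer capture dependencies between different windows, removing W-MSA's limitation of computing attention only within a single window. Information from adjacent windows can thus be exchanged through the shifted, interleaved partitions, which strengthens the model's global perception.
However, the number of windows now grows from four to nine, and some of them are smaller than M×M. Running W-MSA on each window separately (for example, by padding every window to M×M) would be cumbersome and costly. To handle this, the authors propose an efficient batch computation approach for the shifted configuration. Its core idea is to process the shifted windows together as a batch instead of computing each window on its own.
Concretely, regions A, B and C in the figure are moved by a cyclic shift toward the top-left. After the shift, all windows can be computed together in one batch instead of one at a time, which significantly reduces computation time and memory use.
Personally, I find this process somewhat reminiscent of a Karnaugh map; the details are shown in the figure I drew below:
Here, region 4 stays as it was. Regions 5 and 3 are combined into one window, 1 and 7 into another, and 8, 2, 6 and 0 into a third. The result is again four 4×4 windows, just as before the shift, so the computational cost is unchanged. However, this mixes regions that were not adjacent in the original feature map. The authors therefore use a masking mechanism to restrict self-attention to each sub-window. In practice, a mask is built that blocks attention between tokens from different regions: such pairs receive a large negative value (-100 in the implementation) before the softmax, so their weights become nearly zero.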
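The batching trick boils down to a cyclic shift followed by the ordinary window partition. Below is a NumPy sketch on an 8×8 map with M=4 and shift s=2 (sizes chosen to match the paper's illustration). The 2-D `window_partition` here is a simplified stand-in for the one in the implementation, without batch and channel dimensions.

```python
import numpy as np

def window_partition(x, M):
    """(H, W) map -> (num_windows, M, M); simplified version without B and C."""
    H, W = x.shape
    x = x.reshape(H // M, M, W // M, M).transpose(0, 2, 1, 3)
    return x.reshape(-1, M, M)

H = W = 8   # feature map size used in the paper's illustration
M, s = 4, 2 # window size and shift size (s = M // 2)
x = np.arange(H * W).reshape(H, W)

# Cyclic shift toward the top-left, like torch.roll(x, shifts=(-s, -s), dims=(1, 2))
shifted = np.roll(x, shift=(-s, -s), axis=(0, 1))
windows = window_partition(shifted, M)
print(windows.shape)  # (4, 4, 4): still 4 windows of 4 x 4, processed as one batch

# After attention, rolling back by +s restores the original layout
restored = np.roll(shifted, shift=(s, s), axis=(0, 1))
print(np.array_equal(restored, x))  # True
```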
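The mask can be built without any attention at all, only from region labels. The NumPy sketch below reproduces the mask construction in BasicLayer.forward on the same 8×8 example (M=4, s=2). Note that the region labels follow the `cnt` numbering in the code, which need not match the numbers in the figure.

```python
import numpy as np

H = W = 8
M, s = 4, 2  # window size and shift size

# Label the 9 regions of the shifted map, using the same slices as BasicLayer.forward
img_mask = np.zeros((H, W), dtype=int)
slices = (slice(0, -M), slice(-M, -s), slice(-s, None))
cnt = 0
for hs in slices:
    for ws in slices:
        img_mask[hs, ws] = cnt
        cnt += 1

# Split into windows and flatten each window to M*M tokens
mask_windows = img_mask.reshape(H // M, M, W // M, M).transpose(0, 2, 1, 3).reshape(-1, M * M)

# Token pairs from different regions get -100, pairs from the same region get 0
diff = mask_windows[:, None, :] - mask_windows[:, :, None]
attn_mask = np.where(diff != 0, -100.0, 0.0)  # (num_windows, M*M, M*M)

print(attn_mask.shape)                           # (4, 16, 16)
print([int((m == 0).sum()) for m in attn_mask])  # [256, 128, 128, 64]
```

The top-left window contains a single region, so nothing is masked. The bottom-right window mixes four regions, so only a quarter of its token pairs may attend to each other.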
Relative Position Bias
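The paper says relatively little about this component beyond showing that adding the relative position bias noticeably improves the results.
My understanding of this part is based on the official code and a Bilibili video walkthrough. First, a parameter table of relative position biases is created. Inside a Wh×Ww window, two tokens can be at most Wh-1 rows apart in either direction, so the vertical offset ranges over [-Wh+1, Wh-1]. That range contains Wh-1 negative values, Wh-1 positive values and 0 (the offset between tokens in the same row, including a token and itself), for 2Wh-1 values in total. Likewise, the horizontal offset takes 2Ww-1 values. Each (vertical, horizontal) offset pair has its own learnable bias per head, so the size of the relative position bias table is: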
[(2 * Wh - 1) * (2 * Ww - 1), num_heads]
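A NumPy port of get_relative_position_index shows how the offsets turn into table indices; it mirrors the PyTorch method in the implementation and is only an illustration. For a 2×2 window the table has (2·2-1)·(2·2-1)=9 rows, and every index falls in 0..8. The diagonal, where a token is compared with itself, uses the zero-offset entry 4.

```python
import numpy as np

def relative_position_index(win_h, win_w):
    """NumPy port of WindowAttention.get_relative_position_index (illustration only)."""
    coords = np.stack(np.meshgrid(np.arange(win_h), np.arange(win_w), indexing="ij"))  # 2, Wh, Ww
    coords_flatten = coords.reshape(2, -1)                                             # 2, Wh*Ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]                      # 2, N, N
    rel = rel.transpose(1, 2, 0).copy()                                                # N, N, 2
    rel[:, :, 0] += win_h - 1      # vertical offset [-(Wh-1), Wh-1] -> [0, 2Wh-2]
    rel[:, :, 1] += win_w - 1      # horizontal offset [-(Ww-1), Ww-1] -> [0, 2Ww-2]
    rel[:, :, 0] *= 2 * win_w - 1  # row-major flattening of the 2-D offset
    return rel.sum(-1)             # N, N, values in [0, (2Wh-1)(2Ww-1) - 1]

idx = relative_position_index(2, 2)
print(idx)
# [[4 3 1 0]
#  [5 4 2 1]
#  [7 6 4 3]
#  [8 7 5 4]]
```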
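In the code, each offset is shifted by Wh-1 (or Ww-1) so that it starts at 0, and the vertical offset is multiplied by 2Ww-1. This turns every (vertical, horizontal) pair into a single index into the table. The dot product between each query and key gives a similarity score, and the bias looked up for the pair's relative position is added to it:
$$\mathrm{Attention}(Q,K,V)=\mathrm{SoftMax}\left(\frac{QK^{T}}{\sqrt{d}}+B\right)V$$
Here Q, K and V are the query, key and value matrices, d is the dimension per head, and B is the bias gathered from the table according to each pair's relative position. B is added after the scaling by 1/√d, which matches the code (q is multiplied by self.scale before the bias is added). With the relative position bias, the model can better capture dependencies on local spatial structure.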
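For a single head and a single window, the formula can be sketched in a few lines of NumPy. The function name `window_attention_single_head` is hypothetical. The head dimension of 8 is arbitrary, and `rel_index` is the 2×2 index that get_relative_position_index produces, hard-coded here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def window_attention_single_head(q, k, v, bias_table, rel_index):
    """SoftMax(Q K^T / sqrt(d) + B) V for one window and one head.

    q, k, v: (N, d) with N = Wh*Ww tokens; bias_table: (num_offsets,); rel_index: (N, N).
    """
    d = q.shape[-1]
    scores = (q / np.sqrt(d)) @ k.T  # scale first, as the code does with q * self.scale
    B = bias_table[rel_index]        # gather the bias for each (query, key) pair
    attn = softmax(scores + B, axis=-1)
    return attn @ v, attn

rng = np.random.default_rng(0)
N, d = 4, 8                          # a 2 x 2 window, head dimension 8 (arbitrary)
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))
bias_table = rng.standard_normal(9)  # (2*2-1) * (2*2-1) entries for one head
rel_index = np.array([[4, 3, 1, 0], [5, 4, 2, 1], [7, 6, 4, 3], [8, 7, 5, 4]])

out, attn = window_attention_single_head(q, k, v, bias_table, rel_index)
print(out.shape)                       # (4, 8)
print(np.allclose(attn.sum(-1), 1.0))  # True
```

Because the bias is indexed by relative offset, every pair of tokens with the same offset shares the same learned value, regardless of where the pair sits in the window.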
Network implementation

```python
"""
Copyright (c) 2025, Auorui.
All rights reserved.
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/pdf/2103.14030>
use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.py
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
"""
from functools import partial
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
from pyzjr.nn.models.bricks.initer import trunc_normal_

LayerNorm = partial(nn.LayerNorm, eps=1e-6)


class PatchPartition(nn.Module):
    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, self.embed_dim,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        if H % self.patch_size[0] != 0:
            pad_h = self.patch_size[0] - H % self.patch_size[0]
            x = F.pad(x, (0, 0, 0, pad_h))
        if W % self.patch_size[1] != 0:
            pad_w = self.patch_size[1] - W % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, 0))
        x = self.proj(x)  # [B, embed_dim, H/patch_size, W/patch_size]
        Wh, Ww = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        # Linear Embedding
        x = self.norm(x)
        # x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x, Wh, Ww


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop_ratio)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PatchMerging(nn.Module):
    def __init__(self, dim, norm_layer=LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        if H % 2 == 1 or W % 2 == 1:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C, top-left
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C, bottom-left
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C, top-right
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C, bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        x = self.norm(x)
        x = self.reduction(x)
        return x


class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports shifted and non-shifted windows.
    """
    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        proj_bias=True,
        attention_dropout_ratio=0.,
        proj_drop=0.,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)
        win_h, win_w = self.window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)
        )  # [2*Wh-1 * 2*Ww-1, nHeads] Offset Range: -Wh+1, Wh-1
        self.register_buffer("relative_position_index",
                             self.get_relative_position_index(win_h, win_w), persistent=False)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attention_dropout_ratio)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def get_relative_position_index(self, win_h: int, win_w: int):
        # get pair-wise relative position index for each token inside the window
        coords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
        relative_coords[:, :, 1] += win_w - 1
        relative_coords[:, :, 0] *= 2 * win_w - 1
        return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-100) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[:3]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def window_partition(x, window_size: int):
    """
    Split the feature map into non-overlapping windows of size window_size.
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Merge the windows back into a single feature map.
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block."""
    mlp_ratio = 4

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        qkv_bias=True,
        proj_bias=True,
        attention_dropout_ratio=0.,
        proj_drop=0.,
        drop_path_ratio=0.,
        norm_layer=LayerNorm,
        act_layer=nn.GELU,
    ):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < window_size, "shift_size must be in [0, window_size)"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attention_dropout_ratio=attention_dropout_ratio,
            proj_drop=proj_drop,
        )
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * self.mlp_ratio)
        # the MLP dropout ratio is proj_drop (a float), not the boolean proj_bias
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_drop)
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            mask_matrix: Attention mask for cyclic shift.
        H and W are read from self.H / self.W, which BasicLayer sets before each call.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage."""
    def __init__(self,
                 dim,
                 num_layers,
                 num_heads,
                 drop_path,
                 window_size=7,
                 qkv_bias=True,
                 proj_bias=True,
                 attention_dropout_ratio=0.,
                 proj_drop=0.,
                 norm_layer=LayerNorm,
                 act_layer=nn.GELU,
                 downsample=None):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.num_layers = num_layers
        # build blocks: even-indexed blocks use W-MSA, odd-indexed blocks use SW-MSA
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                attention_dropout_ratio=attention_dropout_ratio,
                proj_drop=proj_drop,
                drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                act_layer=act_layer)
            for i in range(num_layers)])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        for blk in self.blocks:
            blk.H, blk.W = H, W
            x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2
        return x, H, W


class SwinTransformer(nn.Module):
    """ Swin Transformer backbone."""
    def __init__(self,
                 patch_size=4,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_size=7,
                 qkv_bias=True,
                 proj_bias=True,
                 attention_dropout_ratio=0.,
                 proj_drop=0.,
                 drop_path_rate=0.2,
                 norm_layer=LayerNorm,
                 patch_norm=True,
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels of the stage-4 output feature map
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # split image into non-overlapping patches
        self.patch_embed = PatchPartition(
            patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=proj_drop)
        # stochastic depth rates, increasing linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        layers = []
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                num_layers=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                attention_dropout_ratio=attention_dropout_ratio,
                proj_drop=proj_drop,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
            )
            layers.append(layer)
        self.layers = nn.Sequential(*layers)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, C, H, W]
        x, H, W = self.patch_embed(x)  # [B, L, C]
        x = self.pos_drop(x)
        for layer in self.layers:
            x, H, W = layer(x, H, W)
        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x


def swin_t(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes)
    return model


def swin_s(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes)
    return model


def swin_b(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes)
    return model


def swin_l(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes)
    return model


if __name__ == "__main__":
    import pyzjr
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.ones(2, 3, 224, 224).to(device)
    net = swin_l(num_classes=4)
    net = net.to(device)
    out = net(input)
    print(out)
    print(out.shape)
    pyzjr.summary_1(net, input_size=(3, 224, 224))
    # swin_t Total params: 27,499,108
    # swin_s Total params: 48,792,676
    # swin_b Total params: 86,683,780
    # swin_l Total params: 194,906,308
```
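Two pieces of bookkeeping in the constructors are easy to miss. First, the stochastic depth (DropPath) rates increase linearly from 0 to drop_path_rate over all blocks and are then sliced per stage. Second, inside BasicLayer, even-indexed blocks use regular windows (W-MSA) while odd-indexed ones use shifted windows (SW-MSA). The NumPy/pure-Python sketch below reproduces both with Swin-T's settings, just to make the numbers visible; np.linspace stands in for torch.linspace.

```python
import numpy as np

depths = (2, 2, 6, 2)   # Swin-T
drop_path_rate = 0.2    # default in SwinTransformer.__init__
window_size = 7

# Same values as torch.linspace(0, drop_path_rate, sum(depths)): one rate per block
dpr = np.linspace(0, drop_path_rate, sum(depths)).tolist()

# Slice the list per stage exactly like dpr[sum(depths[:i]):sum(depths[:i + 1])]
per_stage = [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
print([len(r) for r in per_stage])                             # [2, 2, 6, 2]
print(round(per_stage[0][0], 4), round(per_stage[-1][-1], 4))  # 0.0 0.2

# Shift sizes of the blocks in stage 3, as chosen in BasicLayer
shifts = [0 if i % 2 == 0 else window_size // 2 for i in range(depths[2])]
print(shifts)  # [0, 3, 0, 3, 0, 3]
```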
References
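Swin-Transformer network architecture explained in detail (CSDN blog)
Swin-transformer explained in detail (CSDN blog)
[Deep Learning] Swin Transformer (SwinT) explained in detail (CSDN blog)
Recommended video: 12.1 Swin-Transformer network architecture explained in detail (Bilibili)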