https://www.bilibili.com/video/BV1yg411K7Yc?spm_id_from=333.337.search-card.all.click



Class SwinTransformer

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels of the feature map output by stage 4
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # note: the stage built here differs slightly from the figure in the paper;
            # a stage here does not contain its own patch_merging layer but that of the next stage
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

Here there is a def __init__(), which is used to initialize the Swin Transformer model.

1.patch_size

In Swin Transformer, a patch is the smallest image block the model works with.
patch_size = 4 therefore means a 4x downsampling of the input.
In the architecture figure this corresponds to the H/4 x W/4 feature map.
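A quick sanity check of that arithmetic (the 224x224 input size is an assumption, not something stated in this note):

H, W, patch_size = 224, 224, 4
print(H // patch_size, W // patch_size)        # 56 56  -> the H/4 x W/4 grid
print((H // patch_size) * (W // patch_size))   # 3136 patches (tokens) per image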

2.depths=(2,2,6,2)

This specifies how many Swin Transformer blocks are stacked in Stage 1, 2, 3 and 4 respectively. The same tuple also drives the stochastic-depth schedule, as the sketch below shows.
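A small sketch (using the default depths and drop_path_rate=0.1 from __init__) of how each stage receives its slice of the linearly increasing drop-path rates:

import torch

depths = (2, 2, 6, 2)
drop_path_rate = 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i_layer in range(len(depths)):
    # each stage gets a contiguous slice of the decay rule, exactly as in __init__
    print(i_layer, dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])])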

3.num_heads = (3,6,12,24)

This is the number of heads used by multi-head self-attention in each stage; the table in the paper lists exactly these values: 3, 6, 12, 24.
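Since the channel dimension in stage i is embed_dim * 2 ** i (see the BasicLayer call above), these head counts keep the per-head dimension constant; a quick check:

embed_dim = 96
num_heads = (3, 6, 12, 24)
for i, h in enumerate(num_heads):
    dim = embed_dim * 2 ** i
    print(f"stage {i + 1}: dim={dim}, heads={h}, dim per head={dim // h}")  # 32 in every stage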


Besides these, there are many more parameters to set, but essentially they all just configure the model. Some of them are easier to understand coming from ViT (Vision Transformer).
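As a minimal usage sketch (assuming the full model file, including PatchEmbed, BasicLayer and PatchMerging, is importable), the defaults above correspond to the Swin-T style configuration:

import torch

model = SwinTransformer(patch_size=4, in_chans=3, num_classes=1000,
                        embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
x = torch.randn(1, 3, 224, 224)   # dummy batch; the 224x224 size is an assumption
logits = model(x)
print(logits.shape)               # torch.Size([1, 1000])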


The __init__ arguments are then assigned to the attributes that will be used later.

Class PatchEmbed

PatchEmbed splits the image into non-overlapping patches.
In the architecture figure it corresponds to the Patch Partition and Linear Embedding operations.

self.patch_embed = PatchEmbed(
    patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # if H or W of the input image is not an integer multiple of patch_size, pad it
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

1.embed_dim=96

This determines the dimension of the tokens fed into the transformer, i.e. the channel dimension C of the patch embeddings.

2. Downsampling

self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)

Here a convolution implements the downsampling.

The Conv2d arguments used here are:
1. in_channels
2. out_channels
3. kernel_size
4. stride

kernel_size = patch_size means a 4x4 kernel does the downsampling, and stride = patch_size means the windows do not overlap, as the sketch below verifies.
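A minimal shape check of this projection (the concrete 224x224 input is an assumed example):

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)    # in_c=3, embed_dim=96, patch_size=4
x = torch.randn(1, 3, 224, 224)                     # [B, C, H, W]
y = proj(x)
print(y.shape)                                      # torch.Size([1, 96, 56, 56]) -> H/4 x W/4
print(y.flatten(2).transpose(1, 2).shape)           # torch.Size([1, 3136, 96])  -> [B, HW, C]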

3.def forward

def forward(self, x):
    _, _, H, W = x.shape

    # padding
    # if H or W of the input image is not an integer multiple of patch_size, pad it
    pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
    if pad_input:
        # to pad the last 3 dimensions,
        # (W_left, W_right, H_top, H_bottom, C_front, C_back)
        x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                      0, self.patch_size[0] - H % self.patch_size[0],
                      0, 0))

    # downsample by a factor of patch_size
    x = self.proj(x)
    _, _, H, W = x.shape
    # flatten: [B, C, H, W] -> [B, C, HW]
    # transpose: [B, C, HW] -> [B, HW, C]
    x = x.flatten(2).transpose(1, 2)
    x = self.norm(x)
    return x, H, W

Before the convolution runs, the padding step has to happen first.

pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
if pad_input:
    # to pad the last 3 dimensions,
    # (W_left, W_right, H_top, H_bottom, C_front, C_back)
    x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                  0, self.patch_size[0] - H % self.patch_size[0],
                  0, 0))

pad_input first checks whether the input actually needs padding.

x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
              0, self.patch_size[0] - H % self.patch_size[0],
              0, 0))

Here,

self.patch_size[1] - W % self.patch_size[1]

is the number of zero columns appended on the right: the missing part is filled with zeros. For example, if patch_size = 4 but W = 3, then 4 - (3 % 4) = 1, so 1 column of zeros is added.

After padding, H and W are integer multiples of patch_size, so the downsampling can go ahead.
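A small sketch of that padding step (the 7x7 input is just an assumed example that is not a multiple of 4):

import torch
import torch.nn.functional as F

patch_size = (4, 4)
x = torch.randn(1, 3, 7, 7)                            # H=7, W=7: not multiples of 4
x = F.pad(x, (0, patch_size[1] - 7 % patch_size[1],    # pad 4 - (7 % 4) = 1 column on the right
              0, patch_size[0] - 7 % patch_size[0],    # pad 1 row at the bottom
              0, 0))                                   # no padding on the channel dim
print(x.shape)                                         # torch.Size([1, 3, 8, 8])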

Taken as a whole, def forward is simply the forward pass of PatchEmbed: pad, project, flatten, normalize. A usage sketch follows.
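Hypothetical usage of the PatchEmbed class above (assuming the class and its imports are available):

import torch
import torch.nn as nn

pe = PatchEmbed(patch_size=4, in_c=3, embed_dim=96, norm_layer=nn.LayerNorm)
x = torch.randn(2, 3, 224, 224)   # dummy batch; 224x224 is an assumed input size
tokens, H, W = pe(x)
print(tokens.shape, H, W)         # torch.Size([2, 3136, 96]) 56 56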