0 Overview

  • Model architecture from the paper

  • Code repository

  • Detailed model architecture corresponding to the code

1 Vision Transformer

    # PatchEmbed, Block, Attention and Mlp are defined in the sections below;
    # DropPath and _init_vit_weights are small helpers from the same source file (not shown here).
    import torch
    import torch.nn as nn
    from functools import partial
    from collections import OrderedDict


    class VisionTransformer(nn.Module):
        def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                     embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                     qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                     attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                     act_layer=None):
            """
            Args:
                img_size (int, tuple): input image size
                patch_size (int, tuple): patch size
                in_c (int): number of input channels
                num_classes (int): number of classes for classification head
                embed_dim (int): embedding dimension
                depth (int): depth of transformer
                num_heads (int): number of attention heads
                mlp_ratio (int): ratio of mlp hidden dim to embedding dim
                qkv_bias (bool): enable bias for qkv if True
                qk_scale (float): override default qk scale of head_dim ** -0.5 if set
                representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
                distilled (bool): model includes a distillation token and head as in DeiT models
                drop_ratio (float): dropout rate
                attn_drop_ratio (float): attention dropout rate
                drop_path_ratio (float): stochastic depth rate
                embed_layer (nn.Module): patch embedding layer
                norm_layer (nn.Module): normalization layer
                act_layer (nn.Module): activation layer
            """
            super(VisionTransformer, self).__init__()
            self.num_classes = num_classes
            self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
            self.num_tokens = 2 if distilled else 1
            norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
            act_layer = act_layer or nn.GELU
            self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
            num_patches = self.patch_embed.num_patches
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            # dist_token is only needed to build DeiT; plain ViT does not use it
            self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
            self.pos_drop = nn.Dropout(p=drop_ratio)
            dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
            self.blocks = nn.Sequential(*[
                Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                      norm_layer=norm_layer, act_layer=act_layer)
                for i in range(depth)
            ])
            self.norm = norm_layer(embed_dim)

            # Representation layer
            if representation_size and not distilled:
                self.has_logits = True
                self.num_features = representation_size
                self.pre_logits = nn.Sequential(OrderedDict([
                    ("fc", nn.Linear(embed_dim, representation_size)),
                    ("act", nn.Tanh())
                ]))
            else:
                self.has_logits = False
                self.pre_logits = nn.Identity()

            # Classifier head(s)
            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            self.head_dist = None
            if distilled:
                self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

            # Weight init
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
            if self.dist_token is not None:
                nn.init.trunc_normal_(self.dist_token, std=0.02)
            nn.init.trunc_normal_(self.cls_token, std=0.02)
            self.apply(_init_vit_weights)

        def forward_features(self, x):
            # step 1: generate the patch tokens
            # [B, C, H, W] -> [B, num_patches, embed_dim], e.g. [B, 196, 768]
            x = self.patch_embed(x)
            # step 2: concatenate the class token
            # [1, 1, 768] -> [B, 1, 768]
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            if self.dist_token is None:
                # [B, 197, 768]
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
            # step 3: add the position embedding
            x = self.pos_drop(x + self.pos_embed)
            # step 4: pass through the stacked transformer encoder blocks
            x = self.blocks(x)
            x = self.norm(x)
            if self.dist_token is None:
                return self.pre_logits(x[:, 0])
            else:
                return x[:, 0], x[:, 1]

        def forward(self, x):
            x = self.forward_features(x)
            # head_dist is None when building a plain ViT
            if self.head_dist is not None:
                x, x_dist = self.head(x[0]), self.head_dist(x[1])
                if self.training and not torch.jit.is_scripting():
                    # during training, return both classifier predictions separately
                    return x, x_dist
                else:
                    # during inference, return the average of both classifier predictions
                    return (x + x_dist) / 2
            else:
                # step 5: pass through the MLP head
                x = self.head(x)
            return x
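
As a quick sanity check, here is a minimal sketch of my own (not part of the original post; it assumes the classes above plus the helpers from the same file are available): build a ViT-B/16 with the default arguments and push a dummy batch through it.

    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                              depth=12, num_heads=12, num_classes=1000)  # ViT-B/16 defaults
    dummy = torch.randn(2, 3, 224, 224)   # a batch of two RGB images
    logits = model(dummy)
    print(logits.shape)                   # torch.Size([2, 1000])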

2 Patch Embedding

  • Code

    class PatchEmbed(nn.Module):
        """
        2D Image to Patch Embedding
        """
        def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
            super().__init__()
            img_size = (img_size, img_size)
            patch_size = (patch_size, patch_size)
            self.img_size = img_size
            self.patch_size = patch_size
            self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
            # kernel size = stride = patch_size, so each patch is projected independently
            self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
            self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        def forward(self, x):
            # an error is raised if the input H and W differ from the configured image size
            B, C, H, W = x.shape
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            # Conv2d:    e.g. [B, 3, 224, 224] -> [B, 768, 14, 14]
            # flatten:   e.g. [B, 768, 14, 14] -> [B, 768, 196]
            # transpose: e.g. [B, 768, 196]    -> [B, 196, 768]
            x = self.proj(x).flatten(2).transpose(1, 2)
            x = self.norm(x)
            return x
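
Because kernel_size equals stride equals patch_size, the convolution reads each 16x16 patch exactly once, which is the same as flattening every patch and applying one shared linear projection. A small shape check of my own (assuming the PatchEmbed class above):

    patch_embed = PatchEmbed(img_size=224, patch_size=16, in_c=3, embed_dim=768)
    img = torch.randn(1, 3, 224, 224)
    tokens = patch_embed(img)
    print(patch_embed.grid_size, patch_embed.num_patches)  # (14, 14) 196
    print(tokens.shape)                                    # torch.Size([1, 196, 768])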

3 Encoder-Block

  • Code

    class Block(nn.Module):
        def __init__(self,
                     dim,
                     num_heads,
                     mlp_ratio=4.,
                     qkv_bias=False,
                     qk_scale=None,
                     drop_ratio=0.,
                     attn_drop_ratio=0.,
                     drop_path_ratio=0.,
                     act_layer=nn.GELU,
                     norm_layer=nn.LayerNorm):
            super(Block, self).__init__()
            self.norm1 = norm_layer(dim)
            self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
            self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

        def forward(self, x):
            # pre-norm residual layout: x + DropPath(MSA(LN(x))), then x + DropPath(MLP(LN(x)))
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x
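
The residual connections mean every encoder block maps [B, N, dim] to [B, N, dim], which is why section 1 can stack depth blocks in an nn.Sequential. A minimal sketch of my own (assuming Block and its dependencies above):

    block = Block(dim=768, num_heads=12, mlp_ratio=4., qkv_bias=True)
    tokens = torch.randn(2, 197, 768)   # [B, num_patches + 1, embed_dim]
    out = block(tokens)
    print(out.shape)                    # torch.Size([2, 197, 768]), shape is preserved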

4 Encoder-Block: Multi-Head Attention

  • Transformer Encoder

(figure: the Transformer Encoder block: LayerNorm -> Multi-Head Attention -> residual add, then LayerNorm -> MLP -> residual add)

  • Self-Attention

(figure: scaled dot-product self-attention, Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k)·V)

  • Multi-Head

Using a different set of weight matrices W_Q, W_K and W_V for each head, we obtain a different Q, K and V per head.
Each head then runs scaled dot-product attention on its own Q, K and V: head_i = softmax(Q_i·K_iᵀ / √d_k)·V_i.
So one input token now produces one output per head; how are these combined into a single output? The per-head results are concatenated along the channel dimension and multiplied by an output matrix W_O:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)·W_O
(see the hand-rolled sketch right below).
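
To make the combine step concrete, here is a tiny hand-rolled sketch (my own illustration, not code from the repository) that computes each head on its own slice of the projected Q/K/V, then concatenates the per-head outputs and applies W_O; the Attention class later in this section does exactly this, just with a single fused qkv linear layer:

    import torch
    import torch.nn.functional as F

    B, N, C, num_heads = 1, 4, 8, 2              # tiny sizes for readability
    head_dim = C // num_heads                    # 4
    x = torch.randn(B, N, C)
    w_q, w_k, w_v, w_o = (torch.randn(C, C) for _ in range(4))

    outputs = []
    for h in range(num_heads):
        # each head attends using its own slice of the projected Q, K, V
        sl = slice(h * head_dim, (h + 1) * head_dim)
        q, k, v = (x @ w_q)[..., sl], (x @ w_k)[..., sl], (x @ w_v)[..., sl]
        attn = F.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
        outputs.append(attn @ v)                 # [B, N, head_dim]

    b = torch.cat(outputs, dim=-1) @ w_o         # concat heads -> [B, N, C], then project with W_O
    print(b.shape)                               # torch.Size([1, 4, 8])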

  • Multi-Head Attention summary

(figure: Multi-Head Attention overview: split the embedding into num_heads sub-spaces, run scaled dot-product attention in each head, concatenate the results and project with W_O)

  • The flow in the code (each step's tensor shape is annotated in the comments below)
  • Code

    class Attention(nn.Module):
        """
        Multi-head self-attention
        """
        def __init__(self,
                     dim,                  # dimension of the input tokens
                     num_heads=8,          # number of attention heads
                     qkv_bias=False,
                     qk_scale=None,
                     attn_drop_ratio=0.,
                     proj_drop_ratio=0.):
            super(Attention, self).__init__()
            self.num_heads = num_heads
            head_dim = dim // num_heads
            self.scale = qk_scale or head_dim ** -0.5
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop_ratio)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop_ratio)

        def forward(self, x):
            # [batch_size, num_patches + 1, total_embed_dim]
            # the +1 comes from the class token
            B, N, C = x.shape
            # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
            qkv = self.qkv(x)
            # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
            # the reshape separates q/k/v and the individual heads
            # embed_dim_per_head == head_dim == C // self.num_heads
            qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
            # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
            # reorder the dimensions
            qkv = qkv.permute(2, 0, 3, 1, 4)
            # q.shape = [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
            # split into q, k, v
            q, k, v = qkv[0], qkv[1], qkv[2]
            # softmax(Q·Kᵀ / sqrt(d_k))
            # matrix multiplication (@) -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # @ -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
            x = (attn @ v)
            # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
            x = x.transpose(1, 2)
            # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
            # this concatenates the multi-head outputs
            x = x.reshape(B, N, C)
            # x = x @ W_O: the output projection
            x = self.proj(x)
            x = self.proj_drop(x)
            return x
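
With dim=768 and num_heads=12 (the ViT-B setting), head_dim = 768 // 12 = 64 and the default scale is 64 ** -0.5 = 0.125; note that dim must be divisible by num_heads. A quick shape check of my own (assuming the Attention class above):

    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    print(attn.scale)                   # 0.125
    tokens = torch.randn(2, 197, 768)   # [B, num_patches + 1, total_embed_dim]
    out = attn(tokens)
    print(out.shape)                    # torch.Size([2, 197, 768])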

5 Encoder-Block: MLP Block

  • Code

    class Mlp(nn.Module):
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x
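
With the default mlp_ratio=4 from section 3, the hidden width is 4 x 768 = 3072 and the output width equals the input width. A small check of my own (assuming the Mlp class above):

    mlp = Mlp(in_features=768, hidden_features=3072)
    tokens = torch.randn(2, 197, 768)
    out = mlp(tokens)
    print(mlp.fc1)        # Linear(in_features=768, out_features=3072, bias=True)
    print(out.shape)      # torch.Size([2, 197, 768])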