
Core Content

The paper gives a formally unified representation of depthwise separable convolution and MLP-Mixer, which use static weights, and of the Self-Attention operation, which uses dynamic weights, and proposes Container, a structure that integrates the ideas of all three.
It first defines a multi-head context aggregation framework: $Y = \mathrm{Concat}(A_1 V_1, \ldots, A_M V_M)\, W_2 + X$. Here $A_m$ is the affinity matrix used inside head $m$ (an independent group of channels) to aggregate context information, and $V = X W_1$ is a linear transformation of the input $X$, split into $V_1, \ldots, V_M$ along the channel dimension. The three operations differ mainly in how the affinity matrix is defined.
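As a reading aid, the framework can be sketched in a few lines of PyTorch. This is a minimal sketch, not the paper's code: the function name aggregate, the (N, C) token layout, and the pre-supplied affinity matrices are assumptions for illustration.

import torch

def aggregate(X, A, W1, W2):
    """Multi-head context aggregation: Y = Concat(A_1 V_1, ..., A_M V_M) W2 + X.

    X:  (N, C)    input tokens
    A:  (M, N, N) one affinity matrix per head (static or dynamic)
    W1: (C, C)    projection producing V = X W1
    W2: (C, C)    output projection
    """
    N, C = X.shape
    M = A.shape[0]
    V = (X @ W1).reshape(N, M, C // M).permute(1, 0, 2)  # split V into M heads: (M, N, C/M)
    Y = torch.bmm(A, V)                                  # per-head aggregation A_m V_m
    Y = Y.permute(1, 0, 2).reshape(N, C)                 # concatenate heads back to (N, C)
    return Y @ W2 + X

# example: 16 tokens, 64 channels, 8 heads, attention-like (row-stochastic) affinities
N, C, M = 16, 64, 8
X = torch.randn(N, C)
A = torch.softmax(torch.randn(M, N, N), dim=-1)
Y = aggregate(X, A, torch.randn(C, C), torch.randn(C, C))   # (16, 64)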

  • Self-Attention: $A^{sa}_m = \mathrm{Softmax}\!\left(Q_m K_m^{\top} / \sqrt{C/M}\right)$, a dense affinity matrix recomputed from the input for every sample.
  • Depthwise separable convolution: $A^{conv}_m$ is a static, sparse (banded) matrix whose non-zero entries are the shared kernel weights; the number of heads equals the number of channels ($M = C$), so each channel gets its own affinity matrix.
  • MLP-Mixer: $A^{mlp} = (W_{MLP})^{\top}$, a static, dense learned matrix applied with a single head ($M = 1$).

The first form uses dynamic, input-dependent weights; in the latter two forms the weights are static and independent of the input.
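To make the static/dynamic distinction concrete, the following small sketch (single head, a 1-D sequence of N tokens, random weights; not from the paper's code) builds each kind of affinity matrix and applies it in the same way:

import torch

N, C = 8, 16
X = torch.randn(N, C)

# Self-Attention: the affinity matrix is dynamic -- recomputed from the input every time.
Wq, Wk = torch.randn(C, C), torch.randn(C, C)
Q, K = X @ Wq, X @ Wk
A_sa = torch.softmax(Q @ K.T / C ** 0.5, dim=-1)   # (N, N), depends on X

# Depthwise convolution (kernel size 3, one channel): the affinity matrix is static and
# sparse -- a banded matrix whose non-zero entries are the kernel weights, independent of X.
kernel = torch.randn(3)
A_conv = torch.zeros(N, N)
for i in range(N):
    for j, w in zip((i - 1, i, i + 1), kernel):
        if 0 <= j < N:
            A_conv[i, j] = w

# MLP-Mixer (token-mixing weight matrix, non-linearity omitted): static and dense.
A_mlp = torch.randn(N, N)

# All three aggregate context the same way: Y = A V with V = X W1.
W1 = torch.randn(C, C)
V = X @ W1
Y_sa, Y_conv, Y_mlp = A_sa @ V, A_conv @ V, A_mlp @ V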
By combining dynamic and static weights, the paper arrives at the Container structure: $Y = \big(\alpha\, A(X) + \beta\, A\big) V\, W_2 + X$, where $A(X)$ is a dynamic affinity matrix and $A$ is a static, learnable one. The dynamic weights are still computed in the form of the original attention mechanism, while the static weights can follow standard depthwise separable convolution (weights independent across channels, i.e. one channel per head) or the standard MLP-Mixer form. Beyond that, an MLP with channel-independent weights can also be used, which may be called a multi-head MLP. With different settings of the mixing weights $\alpha$ and $\beta$, different specializations can be realized, which is how the three forms are unified.
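For illustration, the combination can be written down at the affinity level as a toy single-head module. The class name ContainerAggregation and its parameter names are assumptions for this sketch; the released code shown later instead mixes the outputs of the two branches with a sigmoid gate.

import torch
import torch.nn as nn

class ContainerAggregation(nn.Module):
    # Toy sketch of A = alpha * A(X) + beta * A_static, then Y = (A V) W2 + X (single head).
    def __init__(self, n_tokens, dim):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.w1 = nn.Linear(dim, dim, bias=False)   # V = X W1
        self.w2 = nn.Linear(dim, dim, bias=False)   # output projection
        self.A_static = nn.Parameter(torch.zeros(n_tokens, n_tokens))  # learned, input-independent
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):                           # x: (N, dim)
        q, k, v = self.wq(x), self.wk(x), self.w1(x)
        A_dyn = torch.softmax(q @ k.T / x.shape[-1] ** 0.5, dim=-1)   # dynamic affinity
        A = self.alpha * A_dyn + self.beta * self.A_static            # mixed affinity
        return self.w2(A @ v) + x

Setting alpha = 1, beta = 0 recovers plain self-attention, while alpha = 0 with a banded (or dense) A_static corresponds to the depthwise-convolution (or MLP-Mixer) style of static aggregation.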
When the static branch is a depthwise separable convolution, the authors provide the following figure to compare it with the traditional self-attention weight form:
(Figure omitted: comparison of the depthwise-convolution affinity weights with the traditional self-attention weights.)

Core Code

import torch
import torch.nn as nn
from timm.models.layers import DropPath


class CMlp(nn.Module):
    # CMlp is referenced but not shown in the original snippet; a minimal 1x1-conv MLP
    # is assumed here so that the block below is self-contained.
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x


class Attention_pure(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # "pure" attention: the input already carries the concatenated Q, K, V channels
    # (projected by a 1x1 conv in MixBlock), so there is no qkv or output projection here.
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        C = int(C // 3)  # the input has 3*C channels: Q, K and V
        qkv = x.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # dynamic affinity matrix A(X)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(x)
        return x


class MixBlock(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)  # depthwise conv as positional encoding
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, 3 * dim, 1)  # 1x1 conv producing the concatenated Q, K, V
        self.conv2 = nn.Conv2d(dim, dim, 1)      # output projection after branch fusion
        self.conv = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)  # 5x5 depthwise conv: static aggregation branch
        self.attn = Attention_pure(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # learnable weight for mixing the self-attention and static aggregation branches
        self.sa_weight = nn.Parameter(torch.Tensor([0.0]))

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, _, H, W = x.shape
        residual = x
        x = self.norm1(x)
        qkv = self.conv1(x)

        # static branch: apply the depthwise convolution directly to V
        conv = qkv[:, 2 * self.dim:, :, :]
        conv = self.conv(conv)

        # dynamic branch: self-attention over the flattened Q, K, V
        sa = qkv.flatten(2).transpose(1, 2)
        sa = self.attn(sa)
        sa = sa.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # sigmoid-gated fusion of the two branches, followed by the residual connection
        x = residual + self.drop_path(self.conv2(torch.sigmoid(self.sa_weight) * sa
                                                 + (1 - torch.sigmoid(self.sa_weight)) * conv))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Links