Core Content
The paper gives a unified formal representation of depthwise separable convolution and MLP-Mixer, which use static weights, and of Self-Attention, which uses dynamic weights, and proposes Container, a structure that integrates the ideas behind all of these methods.
The paper first defines a multi-head context-aggregation framework:

$$Y = \mathrm{Concat}\big(A_1 V_1,\; A_2 V_2,\; \ldots,\; A_M V_M\big), \qquad V = XW$$

Here $A_m$ is the affinity matrix used inside each head (an independent group of channels) to aggregate context information, and $V = XW$ is the result of a linear transformation of the input $X$, with $V_m$ the channel group belonging to head $m$. The three operations differ mainly in how the affinity matrix is defined:
- Self-Attention: $A_m^{\text{sa}} = \mathrm{softmax}\!\left(Q_m K_m^{\top} / \sqrt{C/M}\right)$, where $Q = XW_q$ and $K = XW_k$; each head computes its own affinity matrix from the input.
- Depthwise separable convolution: $A_m^{\text{conv}}$ is a static, sparse (banded) matrix whose non-zero entries are the kernel weights of channel $m$, so $A^{\text{conv}}_{m,ij} \neq 0$ only when position $j$ lies in the local $k \times k$ neighborhood of position $i$; every channel forms its own head ($M = C$).
- MLP-Mixer: $A^{\text{mlp}} = (W_{\text{MLP}})^{\top}$, a dense, learnable static matrix over the spatial positions that is shared by all channels (a single head, $M = 1$).
The first form uses dynamic, input-dependent weights; the latter two use static weights that do not depend on the input.
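To make the unified form concrete, the following minimal sketch (my own illustration, not code from the paper) checks on a single channel that a 3x3 convolution is exactly $Y = AV$ with a static, sparse affinity matrix $A$, whereas self-attention would fill the same $A$ dynamically with $\mathrm{softmax}(QK^{\top}/\sqrt{d})$:

import torch
import torch.nn.functional as F

H = W = 5
N = H * W
x = torch.randn(1, 1, H, W)        # one channel = one "head" in the depthwise-conv view
kernel = torch.randn(1, 1, 3, 3)

# Build the static affinity matrix A (N x N): A[i, j] is the kernel weight linking position j to i.
A = torch.zeros(N, N)
for i in range(H):
    for j in range(W):
        for di in range(-1, 2):
            for dj in range(-1, 2):
                ni, nj = i + di, j + dj
                if 0 <= ni < H and 0 <= nj < W:
                    A[i * W + j, ni * W + nj] = kernel[0, 0, di + 1, dj + 1]

v = x.flatten()                                        # V = X (identity projection for simplicity)
y_affinity = A @ v                                     # unified form: Y = A V
y_conv = F.conv2d(x, kernel, padding=1).flatten()      # ordinary convolution
print(torch.allclose(y_affinity, y_conv, atol=1e-5))   # True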
By combining dynamic and static weights, the paper proposes the structure

$$A_m = \alpha\, A_m^{\text{dyn}}(X) + \beta\, A_m^{\text{static}}$$

with learnable mixing coefficients $\alpha$ and $\beta$. The dynamic part is still computed in the form of the original attention mechanism, while the static part can follow either standard depthwise separable convolution (weights independent across channels, i.e. one channel per head) or the standard MLP-Mixer form. Beyond that, one can also use an MLP whose weights are independent across channels, which may be called a multi-head MLP. Pairing different weight combinations yields different specialized forms, and in this way the three forms are unified.
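A minimal sketch of this combined affinity (illustrative only; the module and parameter names ContainerAffinityMixer, alpha, beta and static_affinity are mine, not the paper's API). Each head mixes a dynamically computed attention affinity with a learnable static one:

import torch
import torch.nn as nn

class ContainerAffinityMixer(nn.Module):
    # Sketch of A_m = alpha * A_m_dyn(X) + beta * A_m_static, applied as Y_m = A_m @ V_m per head.
    def __init__(self, dim, num_heads, num_tokens):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        # one static (input-independent) affinity matrix per head, learned end to end
        self.static_affinity = nn.Parameter(torch.zeros(num_heads, num_tokens, num_tokens))
        self.alpha = nn.Parameter(torch.ones(1))   # weight of the dynamic branch
        self.beta = nn.Parameter(torch.ones(1))    # weight of the static branch
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                          # x: (B, N, C)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]           # each: (B, heads, N, head_dim)
        dyn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)    # dynamic affinity
        affinity = self.alpha * dyn + self.beta * self.static_affinity  # combined affinity
        out = (affinity @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

Setting beta to zero recovers plain self-attention; freezing alpha at zero and constraining the static matrix to be sparse and banded recovers a depthwise convolution, while a dense static matrix shared across heads recovers MLP-Mixer-style token mixing.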
For the combination with depthwise separable convolution, the authors provide a figure comparing this static affinity matrix with the conventional self-attention affinity matrix.
Core Code
import torch
import torch.nn as nn
from timm.models.layers import DropPath  # stochastic depth, used by MixBlock below
# CMlp (the 1x1-convolution MLP used for channel mixing) is defined elsewhere in the Container repo.


class Attention_pure(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # "pure" attention: the qkv projection lives outside (conv1 in MixBlock), so the input
    # already carries q, k and v concatenated along the channel dimension.
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        C = int(C // 3)  # the input holds q, k and v, each of width C
        qkv = x.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # dynamic affinity matrix, one per head
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(x)
        return x
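Note that Attention_pure expects the q, k and v channels to already be concatenated in its input (they are produced by conv1 in MixBlock below). A quick illustrative check of that interface (the shapes are my own example, not from the repo):

attn = Attention_pure(dim=64, num_heads=8)
tokens = torch.randn(2, 196, 3 * 64)   # (B, N, 3*dim): q, k, v stacked along the channel axis
out = attn(tokens)
print(out.shape)                       # torch.Size([2, 196, 64])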
class MixBlock(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)  # depthwise conv used as positional encoding
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, 3 * dim, 1)  # shared 1x1 projection producing q, k, v
        self.conv2 = nn.Conv2d(dim, dim, 1)      # output projection after mixing the two branches
        self.conv = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)  # static branch: 5x5 depthwise convolution
        self.attn = Attention_pure(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # learnable gate that balances the self-attention branch against the static aggregation branch
        self.sa_weight = nn.Parameter(torch.Tensor([0.0]))

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, _, H, W = x.shape
        residual = x
        x = self.norm1(x)
        qkv = self.conv1(x)
        # static branch: apply the depthwise convolution directly to V
        conv = qkv[:, 2 * self.dim:, :, :]
        conv = self.conv(conv)
        # dynamic branch: flatten q, k, v into tokens and run pure self-attention
        sa = qkv.flatten(2).transpose(1, 2)
        sa = self.attn(sa)
        sa = sa.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        # gate the two branches, project with conv2, and add the residual
        x = residual + self.drop_path(self.conv2(torch.sigmoid(self.sa_weight) * sa + (1 - torch.sigmoid(self.sa_weight)) * conv))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
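MixBlock additionally depends on DropPath (imported from timm above) and on CMlp, a 1x1-convolution MLP defined elsewhere in the Container repo. With a minimal stand-in for CMlp (a sketch under that assumption, not the repo's exact class), the block can be exercised end to end:

class CMlp(nn.Module):
    # Minimal stand-in for the repo's convolutional MLP: two 1x1 convolutions with activation and dropout.
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))


block = MixBlock(dim=64, num_heads=8, drop_path=0.1)
feat = torch.randn(2, 64, 14, 14)      # (B, C, H, W) feature map
print(block(feat).shape)               # torch.Size([2, 64, 14, 14])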