Title

Coordinate Attention for Efficient Mobile Network Design
代码:https://github.com/Andrew-Qibin/CoordAttention

National University of Singapore
没好好看,但觉得这个适应性挺广的,2021年的CVPR,记录一下关键结构

Summary

提出了一个轻量化的注意力机制coordinate attention,在通道信息的基础上融入了空间位置上的信息,显著提高了分类、检测、分割等模型的效果。

Method(s)

image.png
(a)是SE block中的channel attention结构
(b)是CBAM方法(一篇18年的paper,这篇paper其实也是希望空间位置+通道信息,但对于空间信息的编码上不如本文方法有优势)
(c)是本文提出的结构,和a相比关键在于它利用到了H方向和W方向的avg pool,在原有通道注意力的基础上拿到了位置注意力。

如果上图的(c)看了有疑问的话可以重点看下面这版:
image.png
可以看出,和(a)的区别主要在:H、W通道分别的avg pool以及对应的concat和split操作

Evaluation

检测
image.png
语义分割
image.png

Code

挺好理解的

  1. class CoordAtt(nn.Module):
  2. def __init__(self, inp, oup, reduction=32):
  3. super(CoordAtt, self).__init__()
  4. self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
  5. self.pool_w = nn.AdaptiveAvgPool2d((1, None))
  6. mip = max(8, inp // reduction)
  7. self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
  8. self.bn1 = nn.BatchNorm2d(mip)
  9. self.act = h_swish()
  10. self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
  11. self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
  12. def forward(self, x):
  13. identity = x
  14. n,c,h,w = x.size()
  15. x_h = self.pool_h(x)
  16. x_w = self.pool_w(x).permute(0, 1, 3, 2)
  17. y = torch.cat([x_h, x_w], dim=2)
  18. y = self.conv1(y)
  19. y = self.bn1(y)
  20. y = self.act(y)
  21. x_h, x_w = torch.split(y, [h, w], dim=2)
  22. x_w = x_w.permute(0, 1, 3, 2)
  23. a_h = self.conv_h(x_h).sigmoid()
  24. a_w = self.conv_w(x_w).sigmoid()
  25. out = identity * a_w * a_h
  26. return out