Paper: https://arxiv.org/abs/1807.06521 (CBAM: Convolutional Block Attention Module, ECCV 2018)

Code: https://github.com/Jongchan/attention-module

Chinese write-up: https://mp.weixin.qq.com/s?__biz=MzA4MjY4NTk0NQ==&mid=2247484531&idx=1&sn=625065862b28608428acb21da3330717&chksm=9f80bee5a8f737f399f0f564883337154dd8ca3ad5c246c85a86a88b0ac8ede7bf59ffc04554&token=897871599&lang=zh_CN&scene=21#wechat_redirect

Code

CBAM applies a channel-attention gate (ChannelGate) followed by a spatial-attention gate (SpatialGate) to an input feature map; the implementation below follows the repository linked above.

import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    """Conv2d followed by optional BatchNorm and ReLU."""
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)
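# Note: recent PyTorch versions ship a built-in nn.Flatten() that could replace
# this helper; it is kept here to match the original repository.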
class ChannelGate(nn.Module):
    """Channel attention: squeeze the spatial dims with pooling, then a shared MLP."""
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':  # global average pooling
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':  # global max pooling
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':  # global Lp pooling (here p=2)
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':  # log-sum-exp pooling
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)
            if channel_att_sum is None:  # first pooling branch
                channel_att_sum = channel_att_raw
            else:  # accumulate attention from the different pooling branches
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
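# Shape walk-through for ChannelGate (my annotation, assuming x is (B, C, H, W)):
#   global pooling:        (B, C, H, W) -> (B, C, 1, 1)
#   Flatten + shared MLP:  (B, C, 1, 1) -> (B, C)
#   sigmoid + unsqueeze:   (B, C)       -> (B, C, 1, 1), broadcast over H and W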
def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over the spatial dimensions.

    tensor: (B, C, H, W)
    returns: (B, C, 1)
    """
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)  # (B, C, H*W)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)  # (B, C, 1), per-channel max
    # subtract the per-channel max before exponentiating to avoid overflow
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs
class ChannelPool(nn.Module):
    def forward(self, x):
        # at each spatial location, take the max and the mean across channels: (B, C, H, W) -> (B, 2, H, W)
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
class SpatialGate(nn.Module):
    """Spatial attention: compress channels, then one conv produces a per-pixel gate."""
    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        self.kernel_size = kernel_size
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, self.kernel_size, stride=1,
                                 padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)  # (B, 2, H, W)
        x_out = self.spatial(x_compress)  # (B, 1, H, W)
        scale = torch.sigmoid(x_out)  # broadcast over the channel dimension
        return x * scale
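# Note: the kernel_size=7 convolution over the 2-channel (max, mean) descriptor
# follows the paper's ablation, where the larger receptive field produced a
# better spatial attention map than a 3x3 kernel.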
class CBAM(nn.Module):
    """CBAM: channel attention followed (optionally) by spatial attention."""
    def __init__(self, gate_channels, reduction_ratio=16, kernel_size=7,
                 pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate(kernel_size)

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
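A minimal smoke test (my addition, not part of the repository) showing that CBAM is a drop-in module that preserves the input shape:

if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
    cbam = CBAM(gate_channels=64, reduction_ratio=16, kernel_size=7)
    y = cbam(x)
    print(y.shape)  # torch.Size([2, 64, 32, 32]): attention rescales features without changing shape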