

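  1. Fast:
    1. Subsequent computation uses only the output features of the last three convolutional blocks of the encoder, not the first two. The first two have large spatial sizes and account for most of the computational cost. Compared with the initial attention map, the proposed holistic attention mechanism adds almost no computational cost and further highlights the entire salient object, as shown in Figure 4.
    2. For speed, every branch uses a 1x1 convolution to reduce the number of channels to 32, and skip connections are used.
  2. Accurate:
    1. Effective structure. In the dual-branch structure, one branch generates an initial saliency map and the other generates a higher-quality saliency map; the former is used to refine the features of the latter and suppress distracting information.
    2. A holistic attention module enlarges the coverage of the initial saliency map.
    3. The decoder uses an improved RFB module, whose multi-scale receptive fields encode context effectively.
    4. Both branches use multi-scale features drawn from different levels.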
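To get a feel for why reducing every branch to 32 channels with a 1x1 convolution saves so much, the sketch below compares the weight count of a hypothetical 3x3 convolution applied directly to a 2048-channel feature (the width of ResNet-50's last stage, as used by the CPD_ResNet code in this post) with a 1x1 reduction to 32 channels followed by a 3x3 convolution at 32 channels. Bias terms are ignored, and this layer pairing is an illustrative assumption rather than the exact CPD layer layout.

```python
def conv_weights(c_in: int, c_out: int, k_h: int, k_w: int) -> int:
    """Number of weights in a 2-D convolution, ignoring the bias."""
    return c_in * c_out * k_h * k_w


# Hypothetical 3x3 convolution directly on the 2048-channel feature.
direct = conv_weights(2048, 2048, 3, 3)

# 1x1 reduction to 32 channels, then a 3x3 convolution at 32 channels.
reduced = conv_weights(2048, 32, 1, 1) + conv_weights(32, 32, 3, 3)

print(direct, reduced, round(direct / reduced))
```

Because multiply-accumulates per output pixel scale with the same channel products, the reduction in computation at a given resolution is of the same order as the reduction in weights.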


[Cascaded Partial Decoder for Fast and Accurate Salient Object Detection - Figure 2]


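  1. Low-level features actually contribute little to the performance of deep aggregation models (panel a of the figure below).
  2. Integrating low-level features with high-level features increases the computational cost, because their spatial resolution is large (panel b of the figure below).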
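The cost argument in point 2 is easy to quantify. The sketch below assumes a hypothetical 32-channel 3x3 convolution applied at the output resolution of each VGG16 convolutional block for the 352x352 input used in the paper (VGG16 halves the resolution after each of its first four blocks), and reports each level's share of the total multiply-accumulates. Only the resolution changes between levels, so the shares depend purely on spatial area.

```python
INPUT = 352    # input size used for training and testing in the paper
CHANNELS = 32  # hypothetical channel width after the 1x1 reduction


def conv3x3_macs(h: int, w: int, c_in: int, c_out: int) -> int:
    """Multiply-accumulates of a stride-1, 'same'-padded 3x3 convolution."""
    return h * w * c_in * c_out * 9


# Output resolution of conv1 ... conv5 (stride 1, 2, 4, 8, 16).
sizes = {f"conv{i}": INPUT // 2 ** (i - 1) for i in range(1, 6)}
macs = {name: conv3x3_macs(s, s, CHANNELS, CHANNELS) for name, s in sizes.items()}
total = sum(macs.values())

for name, s in sizes.items():
    print(f"{name}: {s}x{s}, {macs[name] / total:.1%} of the total")

low_level_share = (macs["conv1"] + macs["conv2"]) / total
print(f"conv1 + conv2: {low_level_share:.1%}")
```

Under these assumptions the two shallowest levels dominate the decoder cost, which is why dropping them is the main source of the speed-up.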






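The third convolutional block (Conv3) is used as the optimization layer, and the last two convolutional blocks are duplicated to build the two branches (the attention branch and the detection branch).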


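  • A partial decoder integrates the output features of the three convolutional blocks (f3, f4, f5) to produce the initial saliency map Si.
  • After the holistic attention module, an enhanced attention map Sh is obtained and used to refine the feature f3 from the third convolutional block. Because aggregating the features of the top three levels already yields a relatively accurate saliency map, this attention map can effectively remove distracting information from f3 and greatly improve its representational power.
    • However, if distractors are classified as salient, this strategy leads to abnormal segmentation results. The initial saliency map therefore needs to be made more effective, which is the role of the holistic attention module.
  • Element-wise multiplication of f3 and Sh gives the refined feature f3d; passing it through Conv4 and Conv5 yields the corresponding outputs f4d and f5d.
  • The refined features are finally fed to a second partial decoder to obtain the output saliency map Sd.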
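The data flow in the list above can be sketched in plain Python. All names here are hypothetical, and each stage (the duplicated Conv4/Conv5 blocks, the two partial decoders, holistic attention) is passed in as a callable, so the sketch only fixes the wiring between stages, not the layers themselves. It assumes the partial decoder returns a map at f3's spatial resolution, which is what allows f3 and Sh to be multiplied element-wise.

```python
from typing import Callable, List, Sequence, Tuple

Map = List[List[float]]                   # single-channel 2-D map (illustration only)
Block = Callable[[Map], Map]              # stands in for a backbone block (Conv4 or Conv5)
Decoder = Callable[[Sequence[Map]], Map]  # stands in for a partial decoder


def hadamard(a: Map, b: Map) -> Map:
    """Element-wise product of two maps of the same size."""
    return [[x * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def cascaded_forward(f3: Map,
                     conv4_att: Block, conv5_att: Block,
                     conv4_det: Block, conv5_det: Block,
                     decoder_att: Decoder, decoder_det: Decoder,
                     holistic_attention: Callable[[Map], Map]) -> Tuple[Map, Map]:
    """Return (S_i, S_d) following the cascaded partial decoder data flow."""
    # Attention branch: f3 -> Conv4 -> Conv5, then a partial decoder -> S_i.
    f4 = conv4_att(f3)
    f5 = conv5_att(f4)
    s_i = decoder_att([f3, f4, f5])
    # Holistic attention turns S_i into the enlarged map S_h.
    s_h = holistic_attention(s_i)
    # Detection branch: refine f3 with S_h, then run the duplicated Conv4/Conv5.
    f3d = hadamard(f3, s_h)
    f4d = conv4_det(f3d)
    f5d = conv5_det(f4d)
    s_d = decoder_det([f3d, f4d, f5d])
    return s_i, s_d
```

Note that the attention and detection branches receive separate Conv4/Conv5 callables, mirroring the duplicated backbone blocks (`layer3_1`/`layer3_2`, `layer4_1`/`layer4_2` in the CPD_ResNet code).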






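  • On the one hand, the attention branch provides an accurate attention map to the detection branch, which lets the detection branch segment salient objects more precisely.
  • On the other hand, the detection branch can be regarded as an auxiliary loss for the attention branch, which also helps the attention branch focus on salient objects.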
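One way to realize "the detection branch acts as an auxiliary loss" together with ground-truth supervision of the attention branch is to sum a binary cross-entropy term for each branch's logits against the same ground-truth mask. The sketch below assumes BCE-with-logits with equal weights on both outputs; it is a plain-Python illustration, not the authors' training code, and the function names are hypothetical.

```python
import math
from typing import Sequence


def bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy computed directly from logits (stable form)."""
    total = 0.0
    for z, y in zip(logits, targets):
        # max(z, 0) - z*y + log(1 + exp(-|z|)) equals BCE(sigmoid(z), y)
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def cascaded_loss(att_logits: Sequence[float], det_logits: Sequence[float],
                  gt: Sequence[float]) -> float:
    """Both branches are supervised with the same ground-truth mask."""
    return bce_with_logits(att_logits, gt) + bce_with_logits(det_logits, gt)
```

Since the detection branch's input is built from the attention map, the gradient of its loss term also reaches the attention branch. In PyTorch, each term corresponds to `torch.nn.BCEWithLogitsLoss` with its default mean reduction.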



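Compared with the traditional encoder-decoder architecture, this framework adds one more decoder and raises the computational cost of the backbone network, but because the low-level features are discarded in the decoders, the total computational complexity is still significantly lower. In addition, the cascaded optimization mechanism of the proposed framework improves performance, and experiments show that both branches outperform the original model.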

Holistic Attention Module


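  • When the attention branch produces an accurate saliency map, this strategy effectively suppresses distractors in the features.
  • Conversely, if distractors are classified as salient regions, this strategy leads to abnormal segmentation results.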



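The holistic attention map is computed as

$$S_h = \mathrm{MAX}\left(f_{min\_max}\left(\mathrm{Conv}_g(S_i, k)\right), S_i\right)$$

where $\mathrm{Conv}_g$ denotes a convolution with a Gaussian kernel $k$ and zero bias, $f_{min\_max}(\cdot)$ is a normalization function that rescales the blurred map to $[0, 1]$, and $\mathrm{MAX}(\cdot)$ takes the element-wise maximum, which tends to increase the weight of the salient regions of $S_i$ after smoothing.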
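A plain-Python sketch of this formula on a single-channel map. Assumptions: $S_i$ is already in $[0, 1]$ (the CPD_ResNet code applies `.sigmoid()` to the attention map before calling `HA`), the kernel is a small hypothetical 3x3 normalized Gaussian with zero padding (the authors' kernel size and standard deviation are not given in these notes), and min-max normalization runs over the whole map. All function names are illustrative.

```python
import math
from typing import List

Map = List[List[float]]


def gaussian_kernel(size: int, sigma: float) -> Map:
    """Normalized 2-D Gaussian kernel; size should be odd."""
    r = size // 2
    k = [[math.exp(-(x * x + y * y) / (2 * sigma * sigma)) for x in range(-r, r + 1)]
         for y in range(-r, r + 1)]
    s = sum(map(sum, k))
    return [[v / s for v in row] for row in k]


def conv_g(m: Map, k: Map) -> Map:
    """Conv_g: zero-padded 'same' filtering with kernel k and zero bias.
    (Cross-correlation, identical to convolution for a symmetric Gaussian.)"""
    h, w, r = len(m), len(m[0]), len(k) // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in range(-r, r + 1):
                for dj in range(-r, r + 1):
                    ii, jj = i + di, j + dj
                    if 0 <= ii < h and 0 <= jj < w:
                        acc += k[di + r][dj + r] * m[ii][jj]
            out[i][j] = acc
    return out


def min_max(m: Map, eps: float = 1e-8) -> Map:
    """f_min_max: rescale the map to [0, 1]."""
    lo = min(map(min, m))
    hi = max(map(max, m))
    return [[(v - lo) / (hi - lo + eps) for v in row] for row in m]


def holistic_attention(s_i: Map, k: Map) -> Map:
    """S_h = MAX(f_min_max(Conv_g(S_i, k)), S_i), element-wise."""
    blurred = min_max(conv_g(s_i, k))
    return [[max(b, s) for b, s in zip(row_b, row_s)] for row_b, row_s in zip(blurred, s_i)]
```

Because of the element-wise maximum, $S_h \ge S_i$ everywhere: pixels already marked salient keep their weight, while the blur extends nonzero weights to their neighbourhood, enlarging the coverage of the initial map before it multiplies f3.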





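  • The decoder uses an improved RFB (receptive field block) module:
    • The original RFB is an Inception-style multi-branch block in which the 3x3 convolutions are replaced with dilated convolutions.
    • A skip connection is also added to the RFB.
    • This provides multi-scale receptive fields, further captures global contrast information, and encodes context more effectively.
  • For speed, every branch uses a 1x1 convolution to reduce the number of channels to 32.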
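The "multi-scale receptive field" claim can be checked with a little arithmetic on the RFB branches listed in the full code at the end of this post. For a stack of stride-1 convolutions, a layer with kernel size k and dilation d grows the receptive field by d·(k−1) along that axis. The sketch below applies this rule per axis; the layer tuples are copied from `branch0` to `branch3`, and the final 3x3 `conv_cat` (which adds 2 to every branch) is left out.

```python
from typing import List, Tuple

# (kernel_h, kernel_w, dilation) for each layer of the RFB branches.
BRANCHES = {
    "branch0": [(1, 1, 1)],
    "branch1": [(1, 1, 1), (1, 3, 1), (3, 1, 1), (3, 3, 3)],
    "branch2": [(1, 1, 1), (1, 5, 1), (5, 1, 1), (3, 3, 5)],
    "branch3": [(1, 1, 1), (1, 7, 1), (7, 1, 1), (3, 3, 7)],
}


def receptive_field(layers: List[Tuple[int, int, int]]) -> Tuple[int, int]:
    """Receptive field (height, width) of stacked stride-1 convolutions."""
    rf_h = rf_w = 1
    for k_h, k_w, d in layers:
        rf_h += d * (k_h - 1)
        rf_w += d * (k_w - 1)
    return rf_h, rf_w


for name, layers in BRANCHES.items():
    print(name, receptive_field(layers))
```

The four branches cover 1x1, 9x9, 15x15 and 21x21 windows while each works at only 32 channels, and the factorized 1xn/nx1 pairs keep the per-branch cost lower than a full nxn kernel would.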




```python
class CPD_ResNet(nn.Module):
    # resnet based encoder decoder
    def __init__(self, channel=32):
        super(CPD_ResNet, self).__init__()
        self.resnet = B2_ResNet()
        # Attention branch: RFBs on the layer2/3/4 features + partial decoder
        self.rfb2_1 = RFB(512, channel)
        self.rfb3_1 = RFB(1024, channel)
        self.rfb4_1 = RFB(2048, channel)
        self.agg1 = aggregation(channel)
        # Detection branch: same structure on the refined features
        self.rfb2_2 = RFB(512, channel)
        self.rfb3_2 = RFB(1024, channel)
        self.rfb4_2 = RFB(2048, channel)
        self.agg2 = aggregation(channel)
        self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        ...

    def forward(self, x):
        ...
        x2_1 = x2
        x3_1 = self.resnet.layer3_1(x2_1)  # 1024 x 16 x 16
        x4_1 = self.resnet.layer4_1(x3_1)  # 2048 x 8 x 8
        x2_1 = self.rfb2_1(x2_1)
        x3_1 = self.rfb3_1(x3_1)
        x4_1 = self.rfb4_1(x4_1)
        attention_map = self.agg1(x4_1, x3_1, x2_1)
        # Holistic attention refines the layer2 feature x2 with the initial map
        x2_2 = self.HA(attention_map.sigmoid(), x2)
        x3_2 = self.resnet.layer3_2(x2_2)  # 1024 x 16 x 16
        x4_2 = self.resnet.layer4_2(x3_2)  # 2048 x 8 x 8
        x2_2 = self.rfb2_2(x2_2)
        x3_2 = self.rfb3_2(x3_2)
        x4_2 = self.rfb4_2(x4_2)
        detection_map = self.agg2(x4_2, x3_2, x2_2)
        return self.upsample(attention_map), self.upsample(detection_map)
```


```python
class aggregation(nn.Module):
    # dense aggregation, it can be replaced by other aggregation model, such as DSS, amulet, and so on.
    # used after MSF
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3*channel, 1, 1)

    def forward(self, x1, x2, x3):
        # x1: deepest (lowest-resolution) feature, x3: shallowest (highest-resolution)
        x1_1 = x1
        # Multiply each level with upsampled deeper levels
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
            * self.conv_upsample3(self.upsample(x2)) * x3

        # Progressively concatenate along the channel dimension
        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        # N,96,H//8,W//8
        x = self.conv5(x)
        return x
```


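[Cascaded Partial Decoder for Fast and Accurate Salient Object Detection - Figure 8] The aggregation module makes its prediction with a final 1x1 convolution (conv5). In CPD_ResNet this output is at 1/8 of the input resolution, so it still has to be bilinearly upsampled by a factor of 8.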
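The sketch below traces spatial sizes and channel counts through `CPD_ResNet` for the 352x352 input used in training (the shape comments in the code, such as `512 x 32 x 32`, correspond to a 256x256 input). It assumes the standard ResNet-50 output strides of 4, 8, 16 and 32 for layer1 to layer4 and `channel=32`. It only computes sizes; it does not run the network.

```python
def resnet50_sizes(input_size: int) -> dict:
    """Spatial size of each ResNet-50 stage (output strides 4, 8, 16, 32)."""
    return {
        "layer1": input_size // 4,
        "layer2": input_size // 8,
        "layer3": input_size // 16,
        "layer4": input_size // 32,
    }


def aggregation_shapes(channel: int, size4: int) -> dict:
    """(channels, spatial size) of the intermediate maps inside `aggregation`."""
    return {
        "x1 (layer4 feature)": (channel, size4),
        "x2_1": (channel, size4 * 2),      # at layer3 resolution
        "x3_1": (channel, size4 * 4),      # at layer2 resolution
        "x2_2": (2 * channel, size4 * 2),  # after the first concatenation
        "x3_2": (3 * channel, size4 * 4),  # after the second concatenation
        "conv5 output": (1, size4 * 4),
    }


sizes = resnet50_sizes(352)
shapes = aggregation_shapes(32, sizes["layer4"])
# The decoder output lands at the layer2 (optimization layer) resolution.
assert shapes["conv5 output"][1] == sizes["layer2"]
final = shapes["conv5 output"][1] * 8  # nn.Upsample(scale_factor=8)
print(sizes, shapes, final)
```

The 96 channels of `x3_2` match the `# N,96,H//8,W//8` comment in the aggregation code, and the factor-8 upsampling restores the 352x352 input size.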


  • PyTorch framework and a GTX 1080Ti GPU
  • Training set: DUTS-TR
  • The parameters of the bifurcated backbone network are initialized from VGG16.
  • The other convolutional layers are initialized with PyTorch's default settings.
  • All training and test images are resized to 352x352.
  • No post-processing (e.g., CRF) is applied.
  • The model is trained with the Adam optimizer, with a batch size of 10 and an initial learning rate of 1e-4, which is decreased by 10% when the training loss plateaus.
  • Training the model takes nearly six hours.
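The learning-rate rule "decreased by 10% when the training loss plateaus" leaves the patience and plateau threshold unspecified. The sketch below interprets it as multiplying the learning rate by 0.9 once the loss has failed to improve for more than a hypothetical `patience` number of epochs; both that interpretation and the default values are assumptions. PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` with `factor=0.9` implements the same idea on a real optimizer.

```python
from typing import Iterable, List


def plateau_schedule(losses: Iterable[float], lr: float = 1e-4, factor: float = 0.9,
                     patience: int = 2, min_delta: float = 0.0) -> List[float]:
    """Learning rate in effect after each epoch's training loss is observed."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in losses:
        if loss < best - min_delta:
            best = loss          # improvement: reset the plateau counter
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr *= factor     # "decreased by 10%"
                bad_epochs = 0
        lrs.append(lr)
    return lrs


print(plateau_schedule([1.0, 0.8, 0.8, 0.8, 0.8, 0.7]))
```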



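  • On the one hand, the supervised attention map from the attention branch makes the detection branch focus further on salient objects.
  • On the other hand, when the model is trained, the gradients of the detection branch also propagate back to the attention branch.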



NLDF adopts a typical U-Net architecture, BMPM proposes a bidirectional decoder with gate functions, and Amulet integrates multi-level feature maps at multiple resolutions.




Effectiveness of holistic attention


Choice of the optimization layer

We do not test the proposed model with Conv12 as the optimization layer, because this setting would increase the computational cost by adding one more full decoder, so the goal of reducing computational cost would not be met.




The performance of the proposed model relies on the accuracy of the attention branch. When the attention branch detects clutter as target regions, our model will produce wrong results.





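  • This paper proposes a novel cascaded partial decoder framework for fast and accurate salient object detection.
  • When constructing the decoder, the framework discards features from shallower layers to improve computational efficiency, and uses the generated saliency map to refine features to improve accuracy.
  • A holistic attention module is proposed to further segment the whole salient object.
  • An effective decoder is proposed to abstract discriminative features and quickly integrate multi-level features.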



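  • Effective segmentation does not necessarily require features from all levels.
  • Use attention to refine the features that need refinement.
  • Supervise the attention branch with the ground truth as well.
  • Enlarge the receptive field while keeping computation as low as possible (an Inception-style module improved with dilated convolutions, combined with skip connections).
  • Test the effectiveness on other, similar tasks.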


```python
import torch
import torch.nn as nn
import torchvision.models as models

from HolisticAttention import HA  # local module, not shown in this post
from ResNet import B2_ResNet      # local module providing layer3_1/4_1 and layer3_2/4_2


class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Conv + BatchNorm only: self.relu is defined but not applied here.
        x = self.conv(x)
        x = self.bn(x)
        return x


class RFB(nn.Module):
    # RFB-like multi-scale module
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        # Every branch starts with a 1x1 conv that reduces the channels to out_channel.
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)  # 1x1 skip connection

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        # Concatenate the four branches, fuse them, then add the skip connection.
        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
        x = self.relu(x_cat + self.conv_res(x))
        return x


class aggregation(nn.Module):
    # dense aggregation, it can be replaced by other aggregation model, such as DSS, amulet, and so on.
    # used after MSF
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3*channel, 1, 1)

    def forward(self, x1, x2, x3):
        # x1: deepest (lowest-resolution) feature, x3: shallowest (highest-resolution)
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
            * self.conv_upsample3(self.upsample(x2)) * x3

        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        x = self.conv5(x)
        return x


class CPD_ResNet(nn.Module):
    # resnet based encoder decoder
    def __init__(self, channel=32):
        super(CPD_ResNet, self).__init__()
        self.resnet = B2_ResNet()
        self.rfb2_1 = RFB(512, channel)
        self.rfb3_1 = RFB(1024, channel)
        self.rfb4_1 = RFB(2048, channel)
        self.agg1 = aggregation(channel)

        self.rfb2_2 = RFB(512, channel)
        self.rfb3_2 = RFB(1024, channel)
        self.rfb4_2 = RFB(2048, channel)
        self.agg2 = aggregation(channel)

        self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.HA = HA()
        # nn.Module starts in training mode, so this loads ImageNet weights at construction.
        if self.training:
            self.initialize_weights()

    def forward(self, x):
        # Shared stem, layer1 and layer2
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x1 = self.resnet.layer1(x)  # 256 x 64 x 64
        x2 = self.resnet.layer2(x1)  # 512 x 32 x 32

        # Attention branch
        x2_1 = x2
        x3_1 = self.resnet.layer3_1(x2_1)  # 1024 x 16 x 16
        x4_1 = self.resnet.layer4_1(x3_1)  # 2048 x 8 x 8
        x2_1 = self.rfb2_1(x2_1)
        x3_1 = self.rfb3_1(x3_1)
        x4_1 = self.rfb4_1(x4_1)
        attention_map = self.agg1(x4_1, x3_1, x2_1)

        # Detection branch: refine x2 with the holistic attention map
        x2_2 = self.HA(attention_map.sigmoid(), x2)
        x3_2 = self.resnet.layer3_2(x2_2)  # 1024 x 16 x 16
        x4_2 = self.resnet.layer4_2(x3_2)  # 2048 x 8 x 8
        x2_2 = self.rfb2_2(x2_2)
        x3_2 = self.rfb3_2(x3_2)
        x4_2 = self.rfb4_2(x4_2)
        detection_map = self.agg2(x4_2, x3_2, x2_2)

        return self.upsample(attention_map), self.upsample(detection_map)

    def initialize_weights(self):
        # On torchvision >= 0.13, `pretrained=True` is deprecated in favour of
        # weights=models.ResNet50_Weights.IMAGENET1K_V1.
        res50 = models.resnet50(pretrained=True)
        pretrained_dict = res50.state_dict()
        all_params = {}
        for k, v in self.resnet.state_dict().items():
            if k in pretrained_dict.keys():
                # Shared layers (stem, layer1, layer2) have identical names.
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                # layer3_1 / layer4_1 copy the weights of layer3 / layer4.
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                # layer3_2 / layer4_2 copy the same pretrained weights.
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
        self.resnet.load_state_dict(all_params)
```
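`initialize_weights` fills both copies of the duplicated stages from the same ImageNet ResNet-50 weights by stripping the `_1`/`_2` branch suffix from each parameter name. The hypothetical helper below isolates that renaming rule so it can be checked on a few illustrative keys without downloading any weights. It splits only at the first occurrence of the suffix, which gives the same result as the original code for ResNet-50 parameter names, since none of them contain `_1` or `_2` apart from the branch suffix.

```python
def pretrained_key(key: str, pretrained_keys: set) -> str:
    """Map a B2_ResNet state_dict key to the ResNet-50 key it is copied from."""
    if key in pretrained_keys:
        return key  # shared stem, layer1 and layer2
    for suffix in ("_1", "_2"):
        if suffix in key:
            head, tail = key.split(suffix, 1)
            return head + tail  # e.g. layer3_1.0.conv1.weight -> layer3.0.conv1.weight
    raise KeyError(key)


keys = {"conv1.weight", "layer2.0.conv1.weight", "layer3.0.conv1.weight"}
print(pretrained_key("layer3_2.0.conv1.weight", keys))
```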