


In this paper, we propose a novel saliency detection method, which contains:

  1. a context-aware pyramid feature extraction module and a channel-wise attention module to capture context-aware multi-scale multi-receptive-field high-level features
  2. a spatial attention module for low-level feature maps to refine salient object details and an effective edge preservation loss to guide network to learn more detailed information in boundary localization.


  1. 低级特征上使用了空间注意力模块。
  2. 高级特征上使用了通道注意力模块和类ASPP结构CPFE。





  def VGG16(img_input, dropout=False, with_CPFE=False, with_CA=False, with_SA=False, droup_rate=0.3):
      """Build the VGG16-based saliency model and return it as a Keras ``Model``.

      Args:
          img_input: Keras input tensor fed to the first convolution.
          dropout: when true, a ``Dropout(droup_rate)`` follows every pooling
              layer and the last convolution of block 5.
          with_CPFE: when true, C3/C4/C5 each go through ``CFE`` before fusion.
          with_CA: when true, ``ChannelWiseAttention`` is applied to the fused
              high-level features.
          with_SA: when true, the fused low-level features are multiplied by
              the ``SpatialAttention`` map computed from the high-level branch.
          droup_rate: dropout rate (parameter name kept for compatibility).

      Returns:
          A ``Model`` named "BaseModel" whose output is the single-channel
          tensor produced by the ``sa`` convolution (no activation applied).
      """

      def conv_stage(x, filters, n_convs, stage):
          # n_convs 3x3 ReLU convolutions named block<stage>_conv<i>.
          for i in range(1, n_convs + 1):
              x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                         name='block%d_conv%d' % (stage, i))(x)
          return x

      def pool_stage(x, stage):
          # 2x2 max pooling, optionally followed by dropout.
          x = MaxPooling2D((2, 2), strides=(2, 2), name='block%d_pool' % stage)(x)
          if dropout:
              x = Dropout(droup_rate)(x)
          return x

      # Backbone. C1..C4 are tapped before pooling; block 5 has no pooling and
      # C5 is tapped after its optional dropout.
      C1 = conv_stage(img_input, 64, 2, 1)
      C2 = conv_stage(pool_stage(C1, 1), 128, 2, 2)
      C3 = conv_stage(pool_stage(C2, 2), 256, 3, 3)
      C4 = conv_stage(pool_stage(C3, 3), 512, 3, 4)
      C5 = conv_stage(pool_stage(C4, 4), 512, 3, 5)
      if dropout:
          C5 = Dropout(droup_rate)(C5)

      # Low-level features reduced to 64 channels.
      C1 = Conv2D(64, (3, 3), padding='same', name='C1_conv')(C1)
      C1 = BN(C1, 'C1_BN')
      C2 = Conv2D(64, (3, 3), padding='same', name='C2_conv')(C2)
      C2 = BN(C2, 'C2_BN')

      # High-level features: bring C4 and C5 to C3's resolution and concatenate.
      if with_CPFE:
          C3_cfe = CFE(C3, 32, 'C3_cfe')
          C4_cfe = CFE(C4, 32, 'C4_cfe')
          C5_cfe = CFE(C5, 32, 'C5_cfe')
          C5_cfe = BilinearUpsampling(upsampling=(4, 4), name='C5_cfe_up4')(C5_cfe)
          C4_cfe = BilinearUpsampling(upsampling=(2, 2), name='C4_cfe_up2')(C4_cfe)
          C345 = Concatenate(name='C345_aspp_concat', axis=-1)([C3_cfe, C4_cfe, C5_cfe])
          ca_name = 'C345_ChannelWiseAttention_withcpfe'
      else:
          # The original listing defined C345 only on the CPFE path, so
          # with_CPFE=False failed with an undefined name. This parameter-free
          # fallback fuses the raw backbone features instead.
          # NOTE(review): layer names on this path are new - confirm them
          # against any checkpoint trained without CPFE.
          C5_up = BilinearUpsampling(upsampling=(4, 4), name='C5_up4')(C5)
          C4_up = BilinearUpsampling(upsampling=(2, 2), name='C4_up2')(C4)
          C345 = Concatenate(name='C345_concat', axis=-1)([C3, C4_up, C5_up])
          ca_name = 'C345_ChannelWiseAttention_withoutcpfe'
      if with_CA:
          C345 = ChannelWiseAttention(C345, name=ca_name)
      C345 = Conv2D(64, (1, 1), padding='same', name='C345_conv')(C345)
      C345 = BN(C345, 'C345')
      C345 = BilinearUpsampling(upsampling=(4, 4), name='C345_up4')(C345)

      # Low-level fusion is always built; only the attention multiply is
      # optional. Previously with_SA=False left C12 undefined.
      SA = SpatialAttention(C345, 'spatial_attention') if with_SA else None
      C2 = BilinearUpsampling(upsampling=(2, 2), name='C2_up2')(C2)
      C12 = Concatenate(name='C12_concat', axis=-1)([C1, C2])
      C12 = Conv2D(64, (3, 3), padding='same', name='C12_conv')(C12)
      C12 = BN(C12, 'C12')
      if SA is not None:
          C12 = Multiply(name='C12_atten_mutiply')([SA, C12])

      fea = Concatenate(name='fuse_concat', axis=-1)([C12, C345])
      sa = Conv2D(1, (3, 3), padding='same', name='sa')(fea)
      model = Model(inputs=img_input, outputs=sa, name="BaseModel")
      return model

CPFE:context-aware pyramid feature extraction



  class BatchNorm(BatchNormalization):
      """BatchNormalization variant that always runs with ``training=True``."""

      def call(self, inputs, training=None):
          # The caller's `training` flag is intentionally ignored.
          # Name the class explicitly: super(self.__class__, self) recurses
          # forever as soon as this class is subclassed.
          return super(BatchNorm, self).call(inputs, training=True)
  def BN(input_tensor, block_id):
      """Apply ``BatchNorm`` followed by ReLU.

      The two layers are named ``<block_id>_BN`` and ``<block_id>_relu``.
      """
      normalized = BatchNorm(name=block_id + '_BN')(input_tensor)
      return Activation('relu', name=block_id + '_relu')(normalized)
  def AtrousBlock(input_tensor, filters, rate, block_id, stride=1):
      """Apply one bias-free 3x3 dilated convolution named ``<block_id>_dilation``."""
      dilated_conv = Conv2D(
          filters,
          (3, 3),
          strides=(stride, stride),
          dilation_rate=(rate, rate),
          padding='same',
          use_bias=False,
          name=block_id + '_dilation',
      )
      return dilated_conv(input_tensor)
  def CFE(input_tensor, filters, block_id):
      """Context-aware feature extraction over one feature map.

      Concatenates a 1x1 convolution with three 3x3 dilated convolutions
      (rates 3, 5 and 7), each producing ``filters`` channels, then applies
      ``BN`` (batch norm + ReLU) to the concatenation.
      """
      pointwise = Conv2D(filters, (1, 1), padding='same', use_bias=False,
                         name=block_id + '_cfe0')(input_tensor)
      branches = [pointwise]
      for index, dilation in enumerate((3, 5, 7), start=1):
          branch_id = '%s_cfe%d' % (block_id, index)
          branches.append(AtrousBlock(input_tensor, filters, dilation, branch_id))
      merged = Concatenate(name=block_id + 'concatcfe', axis=-1)(branches)
      return BN(merged, block_id)

Channel-wise attention & Spatial attention





  1. import tensorflow as tf
  2. from keras.engine import Layer
  3. from keras.layers import *
  4. from bilinear_upsampling import BilinearUpsampling
  class BatchNorm(BatchNormalization):
      """BatchNormalization variant that always runs with ``training=True``."""

      def call(self, inputs, training=None):
          # The caller's `training` flag is intentionally ignored.
          # Name the class explicitly: super(self.__class__, self) recurses
          # forever as soon as this class is subclassed.
          return super(BatchNorm, self).call(inputs, training=True)
  def BN(input_tensor, block_id):
      """Apply ``BatchNorm`` followed by ReLU.

      The two layers are named ``<block_id>_BN`` and ``<block_id>_relu``.
      """
      normalized = BatchNorm(name=block_id + '_BN')(input_tensor)
      return Activation('relu', name=block_id + '_relu')(normalized)
  def l1_reg(weight_matrix):
      """Return the mean of ``weight_matrix`` as a regularization penalty.

      No absolute value is taken, so this equals a mean-L1 penalty only for
      non-negative inputs.
      """
      penalty = K.mean(weight_matrix)
      return penalty
  class Repeat(Layer):
      """Keras layer that tiles its input with ``tf.tile``.

      Args:
          repeat_list: per-axis repetition counts (batch axis included),
              passed unchanged to ``tf.tile``.
      """

      def __init__(self, repeat_list, **kwargs):
          super(Repeat, self).__init__(**kwargs)
          self.repeat_list = repeat_list

      def call(self, inputs):
          outputs = tf.tile(inputs, self.repeat_list)
          return outputs

      def get_config(self):
          # Expose repeat_list so the layer can be rebuilt from its config.
          config = {
              'repeat_list': self.repeat_list
          }
          base_config = super(Repeat, self).get_config()
          return dict(list(base_config.items()) + list(config.items()))

      def compute_output_shape(self, input_shape):
          # The batch axis is always reported as unknown.
          output_shape = [None]
          # range() instead of the Python-2-only xrange().
          for i in range(1, len(input_shape)):
              dim = input_shape[i]
              # An unknown dimension stays unknown instead of raising
              # TypeError on None * int.
              output_shape.append(None if dim is None else dim * self.repeat_list[i])
          return tuple(output_shape)
  def SpatialAttention(inputs, name):
      """Compute a sigmoid spatial attention map tiled to the input's channel count.

      Two branches of separable 9-tap convolutions (1xk then kx1, and kx1 then
      1xk) each produce one channel. They are added, passed through a sigmoid
      and repeated ``C`` times along the channel axis.

      Args:
          inputs: 4-D tensor with statically known height, width and channels.
          name: prefix for the convolution and add layer names.

      Returns:
          Attention tensor with the same static H, W and C as ``inputs``.
      """
      k = 9
      H, W, C = map(int, inputs.get_shape()[1:])
      # Floor division: the filter count must be an int, and `C / 2` is a
      # float under Python 3.
      half = C // 2
      # NOTE(review): the BN layer names below are not prefixed with `name`,
      # so a second call in one model would collide. Left unchanged so that
      # existing checkpoints still load by name.
      attention1 = Conv2D(half, (1, k), padding='same', name=name + '_1_conv1')(inputs)
      attention1 = BN(attention1, 'attention1_1')
      attention1 = Conv2D(1, (k, 1), padding='same', name=name + '_1_conv2')(attention1)
      attention1 = BN(attention1, 'attention1_2')
      attention2 = Conv2D(half, (k, 1), padding='same', name=name + '_2_conv1')(inputs)
      attention2 = BN(attention2, 'attention2_1')
      attention2 = Conv2D(1, (1, k), padding='same', name=name + '_2_conv2')(attention2)
      attention2 = BN(attention2, 'attention2_2')
      attention = Add(name=name + '_add')([attention1, attention2])
      attention = Activation('sigmoid')(attention)
      attention = Repeat(repeat_list=[1, 1, 1, C])(attention)
      return attention
  def ChannelWiseAttention(inputs, name):
      """Reweight the channels of ``inputs`` with a learned per-channel gate.

      Global average pooling feeds a two-layer dense bottleneck (C//4 ReLU
      units, then C sigmoid units regularized by ``l1_reg``). The resulting
      weights are tiled over H x W and multiplied with ``inputs``.

      Args:
          inputs: 4-D tensor with statically known height, width and channels.
          name: prefix for the pooling, reshape, repeat and multiply layers.

      Returns:
          Tensor of the same shape as ``inputs``.
      """
      H, W, C = map(int, inputs.get_shape()[1:])
      attention = GlobalAveragePooling2D(name=name + '_GlobalAveragePooling2D')(inputs)
      # Floor division: Dense needs an int unit count, and `C / 4` is a
      # float under Python 3.
      attention = Dense(C // 4, activation='relu')(attention)
      attention = Dense(C, activation='sigmoid', activity_regularizer=l1_reg)(attention)
      attention = Reshape((1, 1, C), name=name + '_reshape')(attention)
      attention = Repeat(repeat_list=[1, H, W, 1], name=name + '_repeat')(attention)
      attention = Multiply(name=name + '_multiply')([attention, inputs])
      return attention





这里使用 laplace 算子提取边缘信息(梯度),配合 abs 与 tanh 操作得到最终的边缘,下面是相关的代码,主要差异在于 abs 改为了 relu :

  1. import tensorflow as tf
  2. from keras import backend as K
  3. from keras.backend.common import epsilon
  def _to_tensor(x, dtype):
      """Convert ``x`` to a TensorFlow tensor of the given dtype."""
      tensor = tf.convert_to_tensor(x, dtype=dtype)
      return tensor
  def logit(inputs):
      """Return ``log(p / (1 - p))`` with ``p`` clipped to [eps, 1 - eps].

      ``eps`` is the Keras backend epsilon cast to the dtype of ``inputs``.
      """
      eps = _to_tensor(epsilon(), inputs.dtype.base_dtype)
      clipped = tf.clip_by_value(inputs, eps, 1 - eps)
      return tf.log(clipped / (1 - clipped))
  # Laplacian edge extraction.
  def tfLaplace(x):
      """Return ``relu(tanh(conv(x, laplacian)))`` for a single-channel map.

      The 3x3 kernel has 8 at the centre and -1 elsewhere and is applied with
      stride 1 and SAME padding; negative responses are zeroed by the ReLU.
      """
      kernel_rows = [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]
      kernel = tf.reshape(tf.constant(kernel_rows, tf.float32), [3, 3, 1, 1])
      response = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='SAME')
      return tf.nn.relu(tf.tanh(response))
  def EdgeLoss(y_true, y_pred):
      """Weighted cross-entropy between the ground-truth edge map and ``y_pred``.

      ``y_pred`` is passed as logits; positive targets are weighted by 2.
      The loss is averaged over the last axis.
      """
      positive_weight = 2.
      target_edges = tfLaplace(y_true)
      per_pixel = tf.nn.weighted_cross_entropy_with_logits(
          target_edges, y_pred, positive_weight)
      return K.mean(per_pixel, axis=-1)
  def EdgeHoldLoss(y_true, y_pred):
      """Combined saliency and edge loss: ``0.7 * saliency + 0.3 * edge``.

      The edge term is a sigmoid cross-entropy between the Laplacian edges of
      ``y_true`` and those of ``sigmoid(y_pred)`` (converted back to logits).
      The saliency term is a weighted cross-entropy on ``y_pred`` as logits
      with a positive weight of 1.12. Both are averaged over the last axis.
      """
      probabilities = tf.sigmoid(y_pred)
      target_edges = tfLaplace(y_true)
      predicted_edge_logits = logit(tfLaplace(probabilities))
      edge_term = K.mean(
          tf.nn.sigmoid_cross_entropy_with_logits(
              labels=target_edges, logits=predicted_edge_logits),
          axis=-1)

      positive_weight = 1.12
      saliency_term = K.mean(
          tf.nn.weighted_cross_entropy_with_logits(y_true, y_pred, positive_weight),
          axis=-1)
      return 0.7*saliency_term+0.3*edge_term



  • don’t use the validation set and train the model until training loss converges.
  • some data augmentation techniques:
    • random rotating
    • random cropping
    • random brightness, saturation and contrast changing
    • random horizontal flipping.


| Image | Mask |
| --- | --- |
| random_crop | random_crop |
| random_rotate | random_rotate |
| Zero-center by mean pixel | y = y/y.max() |
  1. import numpy as np
  2. import cv2
  3. import random
  def padding(x, y):
      """Zero-pad an image and its mask to a centred square.

      Args:
          x: array of shape (h, w, c).
          y: array of shape (h, w).

      Returns:
          ``(x_sq, y_sq)`` as float64 arrays of shape (s, s, c) and (s, s),
          where ``s = max(h, w)``.
      """
      rows, cols, channels = x.shape
      side = max(rows, cols)
      top = (side - rows) // 2
      left = (side - cols) // 2
      window = (slice(top, top + rows), slice(left, left + cols))
      square_x = np.zeros((side, side, channels))
      square_y = np.zeros((side, side))
      square_x[window] = x
      square_y[window] = y
      return square_x, square_y
  def random_crop(x, y):
      """Randomly crop, and with probability 1/2 horizontally flip, an image/mask pair.

      Up to (but excluding) one eighth of each side is removed, at a random
      offset. The same crop and flip are applied to both arrays.

      Args:
          x: image array of shape (h, w, c).
          y: mask array of shape (h, w).

      Returns:
          The cropped (and possibly flipped) ``(x, y)`` views.
      """
      h, w = y.shape
      # Floor division keeps the bound an int under Python 3. A side shorter
      # than 8 px is left uncropped instead of raising from randint(0).
      max_cut_h = h // 8
      max_cut_w = w // 8
      randh = np.random.randint(max_cut_h) if max_cut_h > 0 else 0
      randw = np.random.randint(max_cut_w) if max_cut_w > 0 else 0
      randf = np.random.randint(10)
      offseth = 0 if randh == 0 else np.random.randint(randh)
      offsetw = 0 if randw == 0 else np.random.randint(randw)
      p0, p1, p2, p3 = offseth, h + offseth - randh, offsetw, w + offsetw - randw
      if randf >= 5:
          # Horizontal flip of both image and mask.
          x = x[::, ::-1, ::]
          y = y[::, ::-1]
      return x[p0:p1, p2:p3], y[p0:p1, p2:p3]
  def random_rotate(x, y):
      """Rotate an image/mask pair about the mask centre by a random angle.

      The angle is an integer drawn from [-25, 25) degrees; scale is 1.0 and
      the output keeps the mask's (h, w) size.
      """
      degrees = np.random.randint(-25, 25)
      rows, cols = y.shape
      out_size = (cols, rows)
      rotation = cv2.getRotationMatrix2D((cols / 2, rows / 2), degrees, 1.0)
      rotated_x = cv2.warpAffine(x, rotation, out_size)
      rotated_y = cv2.warpAffine(y, rotation, out_size)
      return rotated_x, rotated_y
  def random_light(x):
      """Apply a random contrast gain and brightness offset, clipped to [0, 255].

      The gain is drawn from [0.5, 1.5) and the offset is an integer from
      [-20, 20).
      """
      gain = np.random.rand(1) + 0.5
      offset = np.random.randint(-20, 20)
      adjusted = gain * x + offset
      return np.clip(adjusted, 0, 255)
  def getTrainGenerator(file_path, target_size, batch_size, israndom=False):
      """Yield ``(images, masks)`` training batches forever.

      Args:
          file_path: text file with one ``<image_path> <mask_path>`` pair per
              line, separated by a single space.
          target_size: ``(height, width)`` of the produced samples.
          batch_size: number of samples per yielded batch.
          israndom: when true, apply random crop/flip, rotation and lighting.

      Yields:
          Tuple of float32 arrays with shapes
          ``(batch_size, height, width, 3)`` and ``(batch_size, height, width, 1)``.
          Samples left over at the end of a pass are dropped.

      Raises:
          ValueError: if the list file has no entries.
          IOError: if an image or mask cannot be read.
      """
      # Context manager closes the file even on error; blank lines are
      # skipped because they would raise IndexError on p[1] below.
      with open(file_path, 'r') as f:
          trainlist = [line for line in f if line.strip()]
      if not trainlist:
          # An empty list would otherwise spin in `while True` without yielding.
          raise ValueError('No training samples listed in %s' % file_path)

      out_h, out_w = target_size[0], target_size[1]
      # cv2.resize takes dsize as (width, height).
      dsize = (out_w, out_h)

      while True:
          random.shuffle(trainlist)
          batch_x = []
          batch_y = []
          for name in trainlist:
              p = name.strip('\r\n').split(' ')
              img_path = p[0]
              mask_path = p[1]
              x = cv2.imread(img_path)
              y = cv2.imread(mask_path)
              if x is None or y is None:
                  # cv2.imread signals failure by returning None.
                  raise IOError('Cannot read training pair: %s %s' % (img_path, mask_path))
              x = np.array(x, dtype=np.float32)
              y = np.array(y, dtype=np.float32)
              ############# core preprocessing ######################
              if len(y.shape) == 3:
                  y = y[:, :, 0]
              # Scale the mask so its maximum is 1. An all-zero mask is left
              # as is rather than being turned into NaNs by 0/0.
              peak = y.max()
              if peak > 0:
                  y = y / peak
              if israndom:
                  x, y = random_crop(x, y)
                  x, y = random_rotate(x, y)
                  x = random_light(x)
              # Reverse the channel order, then subtract a per-channel mean.
              # NOTE(review): the constants are applied to channels 0/1/2
              # after the flip; left as is so existing checkpoints stay
              # valid. Confirm the intended channel order.
              x = x[..., ::-1]
              # Zero-center by mean pixel
              x[..., 0] -= 103.939
              x[..., 1] -= 116.779
              x[..., 2] -= 123.68
              x, y = padding(x, y)
              ############# core preprocessing ######################
              x = cv2.resize(x, dsize, interpolation=cv2.INTER_LINEAR)
              y = cv2.resize(y, dsize, interpolation=cv2.INTER_NEAREST)
              y = y.reshape((out_h, out_w, 1))
              batch_x.append(x)
              batch_y.append(y)
              if len(batch_x) == batch_size:
                  yield (np.array(batch_x, dtype=np.float32), np.array(batch_y, dtype=np.float32))
                  batch_x = []
                  batch_y = []

在测试的时候,这样设定,使用了 zero-center、padding、resize ,而预测生成的时候就要 cut 到原始的大小,再与真值计算损失(这里是猜测)。

  1. import numpy as np
  2. import cv2
  3. import os
  4. from keras.layers import Input
  5. from model import VGG16
  6. import matplotlib.pyplot as plt
  def padding(x):
      """Zero-pad an (h, w, c) image to a centred square of side ``max(h, w)``.

      Returns a new float64 array of shape (s, s, c).
      """
      rows, cols, channels = x.shape
      side = max(rows, cols)
      top = (side - rows) // 2
      left = (side - cols) // 2
      square = np.zeros((side, side, channels))
      square[top:top + rows, left:left + cols, :] = x
      return square
  def load_image(path):
      """Load an image and preprocess it into a one-sample network input.

      The image is channel-reversed, mean-subtracted, zero-padded to a square
      and resized to the module-level ``target_size`` (height, width).

      Args:
          path: image file path.

      Returns:
          ``(x, sh)``: ``x`` has a leading batch axis of 1 and ``sh`` is the
          shape of the image as read from disk.

      Raises:
          IOError: if the image cannot be read.
      """
      x = cv2.imread(path)
      if x is None:
          # cv2.imread returns None on failure; fail with the path instead of
          # an AttributeError on `.shape`.
          raise IOError('Cannot read image: %s' % path)
      sh = x.shape
      x = np.array(x, dtype=np.float32)
      # The blog author questions whether this channel flip is needed. It is
      # kept because the training generator applies the same flip.
      x = x[..., ::-1]
      # Zero-center by mean pixel
      x[..., 0] -= 103.939
      x[..., 1] -= 116.779
      x[..., 2] -= 123.68
      x = padding(x)
      # cv2.resize takes dsize as (width, height); target_size is (height, width).
      x = cv2.resize(x, (target_size[1], target_size[0]), interpolation=cv2.INTER_LINEAR)
      x = np.expand_dims(x, 0)
      return x, sh
  def cut(pridict, shape):
      """Undo the square padding: resize to ``max(h, w)`` and crop to (h, w).

      Args:
          pridict: 2-D prediction map.
          shape: ``(h, w, c)`` shape of the original image.
      """
      rows, cols, _ = shape
      side = max(rows, cols)
      top = (side - rows) // 2
      left = (side - cols) // 2
      square = cv2.resize(pridict, (side, side))
      return square[top:top + rows, left:left + cols]
  def sigmoid(x):
      """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

      Computed as ``0.5 * (1 + tanh(x / 2))``, which is mathematically equal
      but does not overflow (and emit a RuntimeWarning) for large negative
      inputs the way ``np.exp(-x)`` does.
      """
      return 0.5 * (1 + np.tanh(0.5 * x))
  def getres(pridict, shape):
      """Turn raw network output into a uint8 saliency map of the original size.

      Applies ``sigmoid``, scales by 255, casts to uint8, drops singleton
      axes and finally crops with ``cut``.
      """
      scaled = sigmoid(pridict)*255
      as_bytes = np.squeeze(np.array(scaled, dtype=np.uint8))
      return cut(as_bytes, shape)
  def laplace_edge(x):
      """Return a uint8 edge map: ``255 * max(tanh(laplacian(x / 255)), 0)``.

      The 3x3 kernel has 8 at the centre and -1 elsewhere.
      """
      kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
      response = cv2.filter2D(x/255., -1, kernel)
      positive = np.maximum(np.tanh(response), 0)
      return np.array(positive * 255, dtype=np.uint8)
  # Inference script: build the model, load weights and predict one image.
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  model_name = 'model/PFA_00050.h5'
  target_size = (256, 256)  # (height, width); also read by load_image()

  dropout = False
  with_CPFE = True
  with_CA = True
  with_SA = True

  if target_size[0] % 32 != 0 or target_size[1] % 32 != 0:
      # Message typo fixed ("wight" -> "width").
      raise ValueError('Image height and width must be a multiple of 32')

  model_input = Input(shape=(target_size[0], target_size[1], 3))
  model = VGG16(model_input, dropout=dropout, with_CPFE=with_CPFE, with_CA=with_CA, with_SA=with_SA)
  # by_name=True matches weights to layers by layer name.
  model.load_weights(model_name, by_name=True)

  for layer in model.layers:
      layer.trainable = False

  image_path = 'image/2.jpg'
  img, shape = load_image(image_path)
  img = np.array(img, dtype=np.float32)
  sa = model.predict(img)
  sa = getres(sa, shape)
  edge = laplace_edge(sa)
  ...  # remainder of the script is elided in the original listing


  • When training, we set α = 1.0 at beginning to generate rough saliency map. In this period, our model is trained using SGD with an initial learning rate 1e-2, the image size is 256×256 , the batch size is 22.
  • Then we adjust different α to refine the boundaries of saliency map,and find α = 0.7 is the optimal setting in experiment Tab.2. In this period, the image size, batch size is same as the previous period, but the initial learning rate is 1e-3.

这里的 alpha 就是两个损失之间的权重比例。




