
Main Structure

[Figure: overall architecture of the proposed network]

In this paper, we propose a novel saliency detection method, which contains:

  1. a context-aware pyramid feature extraction module and a channel-wise attention module to capture context-aware multi-scale multi-receptive-field high-level features
  2. a spatial attention module for low-level feature maps to refine salient object details, and an effective edge preservation loss to guide the network to learn more detailed information in boundary localization.

Interesting points:

  1. A spatial attention module is applied to the low-level features.
  2. A channel-wise attention module and the ASPP-like CPFE module are applied to the high-level features.

The two kinds of attention do not have to be used in the same place. Here, spatial attention is applied to the lower-level features, which retain more structural information, while CPFE and channel-wise attention are applied to the higher-level, semantically rich features. Multi-scale receptive fields plus channel weighting make better use of the features and help separate foreground from background. A simplified sketch of this wiring is given below.
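To make the wiring concrete before the full implementation, here is a minimal sketch using only standard Keras layers. It is not the paper's exact model: an SE-style squeeze-and-excitation stands in for channel-wise attention, a single sigmoid convolution stands in for the spatial attention block, and the layer sizes are illustrative.

from keras.layers import (Input, Conv2D, GlobalAveragePooling2D, Dense, Reshape,
                          UpSampling2D, Concatenate, Lambda)
from keras.models import Model

inp = Input(shape=(256, 256, 3))
# stand-ins for the real backbone features
low = Conv2D(64, (3, 3), padding='same', activation='relu')(inp)                    # low-level: structure/detail
high = Conv2D(128, (3, 3), strides=(4, 4), padding='same', activation='relu')(inp)  # high-level: semantics

# channel weighting of the high-level branch (SE-style stand-in for channel-wise attention)
w = GlobalAveragePooling2D()(high)
w = Dense(32, activation='relu')(w)
w = Dense(128, activation='sigmoid')(w)
w = Reshape((1, 1, 128))(w)
high = Lambda(lambda t: t[0] * t[1])([high, w])   # broadcast channel weights
high = UpSampling2D((4, 4))(high)

# spatial attention map computed from the high-level branch, applied to the low-level branch
att = Conv2D(1, (3, 3), padding='same', activation='sigmoid')(high)
low = Lambda(lambda t: t[0] * t[1])([low, att])   # broadcast spatial map

# fuse both branches into a single-channel saliency map
out = Conv2D(1, (3, 3), padding='same')(Concatenate()([low, high]))
model = Model(inp, out)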

For comparison, below is the architecture of R3Net; the way features are concatenated and combined there is somewhat similar.

[Figure: R3Net architecture]

Code for the overall model (the helpers BN, CFE, ChannelWiseAttention, SpatialAttention and BilinearUpsampling are defined or imported in the modules shown below):

def VGG16(img_input, dropout=False, with_CPFE=False, with_CA=False, with_SA=False, droup_rate=0.3):
    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    C1 = x
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    if dropout:
        x = Dropout(droup_rate)(x)
    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    C2 = x
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
    if dropout:
        x = Dropout(droup_rate)(x)
    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    C3 = x
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
    if dropout:
        x = Dropout(droup_rate)(x)
    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    C4 = x
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
    if dropout:
        x = Dropout(droup_rate)(x)
    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    if dropout:
        x = Dropout(droup_rate)(x)
    C5 = x

    # low-level features C1/C2: 3x3 conv + BN
    C1 = Conv2D(64, (3, 3), padding='same', name='C1_conv')(C1)
    C1 = BN(C1, 'C1_BN')
    C2 = Conv2D(64, (3, 3), padding='same', name='C2_conv')(C2)
    C2 = BN(C2, 'C2_BN')

    # high-level features C3-C5: CPFE, then channel-wise attention
    if with_CPFE:
        C3_cfe = CFE(C3, 32, 'C3_cfe')
        C4_cfe = CFE(C4, 32, 'C4_cfe')
        C5_cfe = CFE(C5, 32, 'C5_cfe')
        C5_cfe = BilinearUpsampling(upsampling=(4, 4), name='C5_cfe_up4')(C5_cfe)
        C4_cfe = BilinearUpsampling(upsampling=(2, 2), name='C4_cfe_up2')(C4_cfe)
        C345 = Concatenate(name='C345_aspp_concat', axis=-1)([C3_cfe, C4_cfe, C5_cfe])
        if with_CA:
            C345 = ChannelWiseAttention(C345, name='C345_ChannelWiseAttention_withcpfe')
    C345 = Conv2D(64, (1, 1), padding='same', name='C345_conv')(C345)
    C345 = BN(C345, 'C345')
    C345 = BilinearUpsampling(upsampling=(4, 4), name='C345_up4')(C345)

    # spatial attention derived from the high-level branch, applied to the fused low-level features
    if with_SA:
        SA = SpatialAttention(C345, 'spatial_attention')
        C2 = BilinearUpsampling(upsampling=(2, 2), name='C2_up2')(C2)
        C12 = Concatenate(name='C12_concat', axis=-1)([C1, C2])
        C12 = Conv2D(64, (3, 3), padding='same', name='C12_conv')(C12)
        C12 = BN(C12, 'C12')
        C12 = Multiply(name='C12_atten_mutiply')([SA, C12])

    # fuse the two branches into single-channel saliency logits
    fea = Concatenate(name='fuse_concat', axis=-1)([C12, C345])
    sa = Conv2D(1, (3, 3), padding='same', name='sa')(fea)
    model = Model(inputs=img_input, outputs=sa, name="BaseModel")
    return model

CPFE: Context-aware Pyramid Feature Extraction

[Figure: CPFE module]

It is essentially a variant of the ASPP structure: parallel dilated 3×3 convolutions with different rates plus a 1×1 branch, concatenated along the channel axis.

class BatchNorm(BatchNormalization):
    def call(self, inputs, training=None):
        return super(self.__class__, self).call(inputs, training=True)


def BN(input_tensor, block_id):
    bn = BatchNorm(name=block_id + '_BN')(input_tensor)
    a = Activation('relu', name=block_id + '_relu')(bn)
    return a


def AtrousBlock(input_tensor, filters, rate, block_id, stride=1):
    x = Conv2D(filters, (3, 3), strides=(stride, stride),
               dilation_rate=(rate, rate),
               padding='same', use_bias=False,
               name=block_id + '_dilation')(input_tensor)
    return x


def CFE(input_tensor, filters, block_id):
    rate = [3, 5, 7]
    cfe0 = Conv2D(filters, (1, 1), padding='same', use_bias=False,
                  name=block_id + '_cfe0')(input_tensor)
    cfe1 = AtrousBlock(input_tensor, filters, rate[0], block_id + '_cfe1')
    cfe2 = AtrousBlock(input_tensor, filters, rate[1], block_id + '_cfe2')
    cfe3 = AtrousBlock(input_tensor, filters, rate[2], block_id + '_cfe3')
    cfe_concat = Concatenate(
        name=block_id + 'concatcfe', axis=-1)([cfe0, cfe1, cfe2, cfe3])
    cfe_concat = BN(cfe_concat, block_id)
    return cfe_concat
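As a quick check on what the dilation rates [3, 5, 7] above buy, the effective kernel size of a dilated k×k convolution is k + (k − 1)(rate − 1):

# effective kernel size of the dilated 3x3 branches used in CFE
k = 3
for rate in [3, 5, 7]:
    effective = k + (k - 1) * (rate - 1)
    print('rate %d -> effective kernel %dx%d' % (rate, effective, effective))
# rate 3 -> effective kernel 7x7
# rate 5 -> effective kernel 11x11
# rate 7 -> effective kernel 15x15

Together with the plain 1×1 branch, a single CFE block therefore mixes four receptive-field sizes at the same resolution.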

Channel-wise Attention & Spatial Attention

[Figure: channel-wise attention module]

The channel-wise attention here is similar to the SE block in SENet, while the spatial attention uses asymmetric convolutions: the channels are progressively compressed and the convolution direction alternates, which enlarges the receptive field (k = 9 in the paper) while keeping the computational cost low.

[Figure: spatial attention module]

The two branches are summed, a sigmoid produces a weight map, and this weight is used to reweight the original features.

import tensorflow as tf
from keras import backend as K
from keras.engine import Layer
from keras.layers import *
from bilinear_upsampling import BilinearUpsampling


class BatchNorm(BatchNormalization):
    def call(self, inputs, training=None):
        return super(self.__class__, self).call(inputs, training=True)


def BN(input_tensor, block_id):
    bn = BatchNorm(name=block_id + '_BN')(input_tensor)
    a = Activation('relu', name=block_id + '_relu')(bn)
    return a


def l1_reg(weight_matrix):
    return K.mean(weight_matrix)


class Repeat(Layer):
    def __init__(self, repeat_list, **kwargs):
        super(Repeat, self).__init__(**kwargs)
        self.repeat_list = repeat_list

    def call(self, inputs):
        outputs = tf.tile(inputs, self.repeat_list)
        return outputs

    def get_config(self):
        config = {
            'repeat_list': self.repeat_list
        }
        base_config = super(Repeat, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        output_shape = [None]
        for i in range(1, len(input_shape)):  # original Python 2 code used xrange
            output_shape.append(input_shape[i] * self.repeat_list[i])
        return tuple(output_shape)


def SpatialAttention(inputs, name):
    k = 9
    H, W, C = map(int, inputs.get_shape()[1:])
    attention1 = Conv2D(C // 2, (1, k), padding='same', name=name + '_1_conv1')(inputs)  # integer division for Python 3
    attention1 = BN(attention1, 'attention1_1')
    attention1 = Conv2D(1, (k, 1), padding='same', name=name + '_1_conv2')(attention1)
    attention1 = BN(attention1, 'attention1_2')
    attention2 = Conv2D(C // 2, (k, 1), padding='same', name=name + '_2_conv1')(inputs)
    attention2 = BN(attention2, 'attention2_1')
    attention2 = Conv2D(1, (1, k), padding='same', name=name + '_2_conv2')(attention2)
    attention2 = BN(attention2, 'attention2_2')
    attention = Add(name=name + '_add')([attention1, attention2])
    attention = Activation('sigmoid')(attention)
    attention = Repeat(repeat_list=[1, 1, 1, C])(attention)
    return attention


def ChannelWiseAttention(inputs, name):
    H, W, C = map(int, inputs.get_shape()[1:])
    attention = GlobalAveragePooling2D(name=name + '_GlobalAveragePooling2D')(inputs)
    attention = Dense(C // 4, activation='relu')(attention)
    attention = Dense(C, activation='sigmoid', activity_regularizer=l1_reg)(attention)
    attention = Reshape((1, 1, C), name=name + '_reshape')(attention)
    attention = Repeat(repeat_list=[1, H, W, 1], name=name + '_repeat')(attention)
    attention = Multiply(name=name + '_multiply')([attention, inputs])
    return attention
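To see why the factorised 1×k / k×1 pair is cheap, here is a rough parameter count (biases ignored) comparing the asymmetric design in SpatialAttention above with an equivalent pair of full k×k convolutions over the same channel path C → C/2 → 1, for C = 64 and k = 9 as in the code:

C, k = 64, 9
full_kxk = k * k * C * (C // 2) + k * k * (C // 2) * 1   # 9x9 (C -> C/2) + 9x9 (C/2 -> 1)
factorised = k * C * (C // 2) + k * (C // 2) * 1         # 1x9 (C -> C/2) + 9x1 (C/2 -> 1)
print(full_kxk, factorised)   # 168480 vs 18720, roughly a 9x reduction

The saving factor is essentially k per layer, which is why the paper can afford the large k = 9 receptive field in the attention branch.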

Loss Function (the Highlight)

The final loss function:

[Formula: total loss L = α·L_S + (1 − α)·L_B]

Here L_S is the ordinary cross-entropy loss; the key term is L_B, a boundary loss, defined as follows:

[Formula: definition of the boundary loss L_B]
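Reconstructing the definitions from the code below (an approximation based on the implementation rather than a verbatim copy of the paper's formulas): the edge maps of the predicted saliency map S and the ground truth G are

    \Delta S = \lvert \tanh(K_{laplace} * S) \rvert, \qquad \Delta G = \lvert \tanh(K_{laplace} * G) \rvert

and the losses are

    L_B = \mathrm{CrossEntropy}(\Delta S, \Delta G), \qquad L = \alpha\, L_S + (1 - \alpha)\, L_B

where K_{laplace} is the 3×3 Laplace kernel and α weights the saliency term against the boundary term.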

A Laplace operator is used to extract edge information (gradients), and abs and tanh operations then produce the final edge map. The related code is below; the main difference from the paper is that abs has been replaced with relu:

import tensorflow as tf
from keras import backend as K
from keras.backend.common import epsilon


def _to_tensor(x, dtype):
    return tf.convert_to_tensor(x, dtype=dtype)


def logit(inputs):
    _epsilon = _to_tensor(epsilon(), inputs.dtype.base_dtype)
    inputs = tf.clip_by_value(inputs, _epsilon, 1 - _epsilon)
    inputs = tf.log(inputs / (1 - inputs))
    return inputs


# extract edges with a Laplace filter
def tfLaplace(x):
    laplace = tf.constant([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], tf.float32)
    laplace = tf.reshape(laplace, [3, 3, 1, 1])
    edge = tf.nn.conv2d(x, laplace, strides=[1, 1, 1, 1], padding='SAME')
    edge = tf.nn.relu(tf.tanh(edge))
    return edge


def EdgeLoss(y_true, y_pred):
    y_true_edge = tfLaplace(y_true)
    edge_pos = 2.
    edge_loss = K.mean(tf.nn.weighted_cross_entropy_with_logits(y_true_edge, y_pred, edge_pos), axis=-1)
    return edge_loss


def EdgeHoldLoss(y_true, y_pred):
    y_pred2 = tf.sigmoid(y_pred)
    y_true_edge = tfLaplace(y_true)
    y_pred_edge = tfLaplace(y_pred2)
    y_pred_edge = logit(y_pred_edge)
    edge_loss = K.mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true_edge, logits=y_pred_edge), axis=-1)
    saliency_pos = 1.12
    saliency_loss = K.mean(tf.nn.weighted_cross_entropy_with_logits(y_true, y_pred, saliency_pos), axis=-1)
    return 0.7 * saliency_loss + 0.3 * edge_loss

Experimental Details

Data Preparation

  • We don't use the validation set and train the model until the training loss converges.
  • Some data augmentation techniques are used:
    • random rotating
    • random cropping
    • random brightness, saturation and contrast changing
    • random horizontal flipping.

Below is the training-time data augmentation used in the experiments. The ground-truth mask and the input image are processed slightly differently, summarized as follows:

Image                        Mask
random_crop                  random_crop
random_rotate                random_rotate
random_light                 —
Zero-center by mean pixel    y = y / y.max()
import numpy as np
import cv2
import random


def padding(x, y):
    h, w, c = x.shape
    size = max(h, w)
    paddingh = (size - h) // 2
    paddingw = (size - w) // 2
    temp_x = np.zeros((size, size, c))
    temp_y = np.zeros((size, size))
    temp_x[paddingh:h + paddingh, paddingw:w + paddingw, :] = x
    temp_y[paddingh:h + paddingh, paddingw:w + paddingw] = y
    return temp_x, temp_y


def random_crop(x, y):
    h, w = y.shape
    randh = np.random.randint(h // 8)
    randw = np.random.randint(w // 8)
    randf = np.random.randint(10)
    offseth = 0 if randh == 0 else np.random.randint(randh)
    offsetw = 0 if randw == 0 else np.random.randint(randw)
    p0, p1, p2, p3 = offseth, h + offseth - randh, offsetw, w + offsetw - randw
    if randf >= 5:
        x = x[::, ::-1, ::]
        y = y[::, ::-1]
    return x[p0:p1, p2:p3], y[p0:p1, p2:p3]


def random_rotate(x, y):
    angle = np.random.randint(-25, 25)
    h, w = y.shape
    center = (w / 2, h / 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    return cv2.warpAffine(x, M, (w, h)), cv2.warpAffine(y, M, (w, h))


def random_light(x):
    contrast = np.random.rand(1) + 0.5
    light = np.random.randint(-20, 20)
    x = contrast * x + light
    return np.clip(x, 0, 255)


def getTrainGenerator(file_path, target_size, batch_size, israndom=False):
    f = open(file_path, 'r')
    trainlist = f.readlines()
    f.close()
    while True:
        random.shuffle(trainlist)
        batch_x = []
        batch_y = []
        for name in trainlist:
            p = name.strip('\r\n').split(' ')
            img_path = p[0]
            mask_path = p[1]
            x = cv2.imread(img_path)
            y = cv2.imread(mask_path)
            x = np.array(x, dtype=np.float32)
            y = np.array(y, dtype=np.float32)
            ############# core of the preprocessing ######################
            if len(y.shape) == 3:
                y = y[:, :, 0]
            y = y / y.max()
            if israndom:
                x, y = random_crop(x, y)
                x, y = random_rotate(x, y)
                x = random_light(x)
            x = x[..., ::-1]
            # Zero-center by mean pixel
            x[..., 0] -= 103.939
            x[..., 1] -= 116.779
            x[..., 2] -= 123.68
            x, y = padding(x, y)
            ############# core of the preprocessing ######################
            x = cv2.resize(x, target_size, interpolation=cv2.INTER_LINEAR)
            y = cv2.resize(y, target_size, interpolation=cv2.INTER_NEAREST)
            y = y.reshape((target_size[0], target_size[1], 1))
            batch_x.append(x)
            batch_y.append(y)
            if len(batch_x) == batch_size:
                yield (np.array(batch_x, dtype=np.float32), np.array(batch_y, dtype=np.float32))
                batch_x = []
                batch_y = []
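A hypothetical usage of the generator ('train_pair.lst' is a placeholder list file; each line is expected to contain an image path and a mask path separated by a space):

train_gen = getTrainGenerator('train_pair.lst', target_size=(256, 256), batch_size=22, israndom=True)
x_batch, y_batch = next(train_gen)
print(x_batch.shape, y_batch.shape)   # (22, 256, 256, 3) (22, 256, 256, 1)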

At test time the setup is as follows: zero-centering, padding and resizing are applied to the input, and the generated prediction is then cut back to the original size before being compared with the ground truth (this last part is my guess).

import numpy as np
import cv2
import os
from keras.layers import Input
from model import VGG16
import matplotlib.pyplot as plt


def padding(x):
    h, w, c = x.shape
    size = max(h, w)
    paddingh = (size - h) // 2
    paddingw = (size - w) // 2
    temp_x = np.zeros((size, size, c))
    temp_x[paddingh:h + paddingh, paddingw:w + paddingw, :] = x
    return temp_x


def load_image(path):
    x = cv2.imread(path)
    sh = x.shape
    x = np.array(x, dtype=np.float32)
    # does this line actually have any effect?
    x = x[..., ::-1]
    # Zero-center by mean pixel
    x[..., 0] -= 103.939
    x[..., 1] -= 116.779
    x[..., 2] -= 123.68
    x = padding(x)
    x = cv2.resize(x, target_size, interpolation=cv2.INTER_LINEAR)
    x = np.expand_dims(x, 0)
    return x, sh


def cut(pridict, shape):
    h, w, c = shape
    size = max(h, w)
    pridict = cv2.resize(pridict, (size, size))
    paddingh = (size - h) // 2
    paddingw = (size - w) // 2
    return pridict[paddingh:h + paddingh, paddingw:w + paddingw]


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def getres(pridict, shape):
    pridict = sigmoid(pridict) * 255
    pridict = np.array(pridict, dtype=np.uint8)
    pridict = np.squeeze(pridict)
    pridict = cut(pridict, shape)
    return pridict


def laplace_edge(x):
    laplace = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
    edge = cv2.filter2D(x / 255., -1, laplace)
    edge = np.maximum(np.tanh(edge), 0)
    edge = edge * 255
    edge = np.array(edge, dtype=np.uint8)
    return edge


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_name = 'model/PFA_00050.h5'
target_size = (256, 256)
dropout = False
with_CPFE = True
with_CA = True
with_SA = True
if target_size[0] % 32 != 0 or target_size[1] % 32 != 0:
    raise ValueError('Image height and width must be a multiple of 32')
model_input = Input(shape=(target_size[0], target_size[1], 3))
model = VGG16(model_input, dropout=dropout, with_CPFE=with_CPFE, with_CA=with_CA, with_SA=with_SA)
model.load_weights(model_name, by_name=True)
for layer in model.layers:
    layer.trainable = False
image_path = 'image/2.jpg'
img, shape = load_image(image_path)
img = np.array(img, dtype=np.float32)
sa = model.predict(img)
sa = getres(sa, shape)
edge = laplace_edge(sa)
...

Training

  • When training, we set α = 1.0 at the beginning to generate a rough saliency map. In this period, the model is trained using SGD with an initial learning rate of 1e-2, an image size of 256×256, and a batch size of 22.
  • Then we adjust α to refine the boundaries of the saliency map, and find α = 0.7 to be the optimal setting in the experiments (Tab. 2). In this period, the image size and batch size are the same as in the previous period, but the initial learning rate is 1e-3.

Here α is the weight that balances the two losses. A sketch of this two-stage schedule is given below.
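The repository's exact training script is not shown here, but a minimal sketch of the two-stage schedule could look like the following, reusing tf, K, tfLaplace and logit from the loss code above, plus model and train_gen from the previous sections. The epoch counts, steps_per_epoch and the SGD momentum are assumptions, not values taken from the paper.

from keras.optimizers import SGD

def make_loss(alpha):
    # same as EdgeHoldLoss above, but with alpha exposed as a parameter
    def loss(y_true, y_pred):
        y_pred2 = tf.sigmoid(y_pred)
        y_true_edge = tfLaplace(y_true)
        y_pred_edge = logit(tfLaplace(y_pred2))
        edge_loss = K.mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true_edge, logits=y_pred_edge), axis=-1)
        saliency_loss = K.mean(tf.nn.weighted_cross_entropy_with_logits(y_true, y_pred, 1.12), axis=-1)
        return alpha * saliency_loss + (1 - alpha) * edge_loss
    return loss

steps = 1000  # placeholder: number of batches per epoch

# Stage 1: rough saliency maps, alpha = 1.0, SGD with initial lr 1e-2
model.compile(optimizer=SGD(lr=1e-2, momentum=0.9), loss=make_loss(1.0))  # momentum is an assumption
model.fit_generator(train_gen, steps_per_epoch=steps, epochs=30)          # epochs are a placeholder

# Stage 2: refine boundaries, alpha = 0.7, initial lr 1e-3
model.compile(optimizer=SGD(lr=1e-3, momentum=0.9), loss=make_loss(0.7))
model.fit_generator(train_gen, steps_per_epoch=steps, epochs=20)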

[Figure: saliency maps predicted with and without the edge preservation loss]

As can be seen in the figure above, adding the edge preservation loss indeed produces somewhat sharper boundaries.

Ablation Study

[Table: ablation study results]

References