The deep residual shrinkage network is actually a general-purpose feature learning method. It integrates the deep residual network (ResNet), the attention mechanism, and soft thresholding, and it can be used for image classification. This post uses TensorFlow 1.0 and TFLearn 0.3.2 to build an image classification program on the CIFAR-10 dataset. CIFAR-10 is a very widely used image dataset containing images in 10 classes; a detailed description is available at https://www.cs.toronto.edu/~kriz/cifar.html
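The soft thresholding at the heart of the method sets features whose magnitude falls below a threshold to zero and shrinks the remaining features toward zero by the threshold, which suppresses noise-related features. A minimal NumPy sketch of the operation (the function name here is ours, for illustration only):

```python
import numpy as np

def soft_threshold(x, tau):
    # Entries with |x| <= tau are zeroed out (treated as noise);
    # the rest are shrunk toward zero by tau, with their sign preserved.
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-1.5, -0.2, 0.0, 0.3, 2.0])
print(soft_threshold(x, 0.5))  # small-magnitude entries become 0
```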

    Following the TFLearn ResNet example (https://github.com/tflearn/tflearn/blob/master/examples/images/residual_network_cifar10.py), the code for the deep residual shrinkage network is as follows:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3)) * 0.1
testX = testX + np.random.random((10000, 32, 32, 3)) * 0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)


def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2, activation='relu', batch_norm=True,
                             bias=True, weights_init='variance_scaling',
                             bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                             trainable=True, restore=True, reuse=False, scope=None,
                             name="ResidualBlock"):
    # Residual shrinkage blocks with channel-wise thresholds
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                               downsample_strides, 'same', 'linear',
                               bias, weights_init, bias_init,
                               regularizer, weight_decay, trainable,
                               restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                               'linear', bias, weights_init,
                               bias_init, regularizer, weight_decay,
                               trainable, restore)

            # Get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True),
                                      axis=1, keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels // 4, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
            # Soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch + 1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Building Deep Residual Shrinkage Network
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
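To make the thresholding sub-network inside residual_shrinkage_block easier to follow, here is a NumPy sketch of how a channel-wise threshold is derived from a feature map. Random weights stand in for the learned fully connected layers, and batch normalization is omitted; shapes and variable names are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.normal(size=(2, 8, 8, 16))  # (batch, height, width, channels)

# Step 1: channel-wise average of absolute feature values
# (the "abs_mean" in the block above), shape (batch, 1, 1, channels)
abs_mean = np.mean(np.abs(features), axis=(1, 2), keepdims=True)

# Step 2: a tiny two-layer net maps abs_mean to per-channel scaling factors;
# random weights stand in for the learned FC layers
w1 = rng.normal(size=(16, 4))
w2 = rng.normal(size=(4, 16))
hidden = np.maximum(abs_mean.reshape(2, 16) @ w1, 0.0)  # FC + ReLU
scales = 1.0 / (1.0 + np.exp(-(hidden @ w2)))           # FC + sigmoid, in (0, 1)

# Step 3: threshold = abs_mean * sigmoid(scales), so each threshold is
# positive but never exceeds the mean absolute value of its channel
thresholds = abs_mean * scales.reshape(2, 1, 1, 16)

# Step 4: apply soft thresholding to the feature map
shrunk = np.sign(features) * np.maximum(np.abs(features) - thresholds, 0.0)
```

Keeping each threshold between zero and the channel's mean absolute value is what prevents the network from thresholding all features to zero.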

    The code above builds a small deep residual shrinkage network containing only three basic residual shrinkage blocks, and the other hyperparameters have not been tuned. To pursue higher accuracy, you can increase the depth, train for more iterations, and tune the hyperparameters appropriately.
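For reference, the staircase learning-rate schedule configured above via tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True) drops the rate by a factor of 10 every 20,000 steps. A small sketch of that behavior (our own helper illustrating the schedule, not TFLearn's implementation):

```python
def staircase_lr(base_lr, decay, decay_step, step):
    # With staircase=True, the decay is applied in discrete jumps,
    # once every `decay_step` training steps.
    return base_lr * decay ** (step // decay_step)

for step in (0, 20000, 40000):
    print(step, staircase_lr(0.1, 0.1, 20000, step))
```

If you train for many more iterations than 20,000 x a few, the rate quickly becomes tiny, so a larger decay_step may be worth trying alongside a deeper network.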

    The previous five posts in this series:

    Deep Residual Shrinkage Networks: (1) Background https://www.cnblogs.com/yc-9527/p/11598844.html

    Deep Residual Shrinkage Networks: (2) Overall Idea https://www.cnblogs.com/yc-9527/p/11601322.html

    Deep Residual Shrinkage Networks: (3) Network Architecture https://www.cnblogs.com/yc-9527/p/11603320.html

    Deep Residual Shrinkage Networks: (4) Threshold Setting via the Attention Mechanism https://www.cnblogs.com/yc-9527/p/11604082.html

    Deep Residual Shrinkage Networks: (5) Experimental Validation https://www.cnblogs.com/yc-9527/p/11610073.html

    Link to the original paper:

    M. Zhao, S. Zhong, X. Fu, B. Tang, and M. Pecht, “Deep Residual Shrinkage Networks for Fault Diagnosis,” IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

    https://ieeexplore.ieee.org/document/8850096