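The deep residual shrinkage network is in fact a general-purpose feature learning method. It integrates the deep residual network (ResNet), an attention mechanism and soft thresholding, and it can be used for image classification. This post uses TensorFlow 1.0 and TFLearn 0.3.2 to write an image classification program on the CIFAR-10 dataset. CIFAR-10 is a widely used image dataset with 10 classes, split into 50,000 training images and 10,000 test images, each 32×32 pixels with 3 color channels. A detailed description is available at https://www.cs.toronto.edu/~kriz/cifar.html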
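Soft thresholding, the nonlinearity the network is named after, keeps the sign of each feature and shrinks its magnitude toward zero by a threshold τ, so that every feature with |x| ≤ τ becomes exactly zero: y = sign(x)·max(|x| − τ, 0). The minimal NumPy sketch below shows this elementwise operation in isolation; the function name `soft_threshold` is hypothetical and is not part of TFLearn.

```python
import numpy as np


def soft_threshold(x, tau):
    """Shrink every element of x toward zero by tau (tau >= 0).

    Elements with |x| <= tau become 0; the remaining elements keep
    their sign and lose tau from their magnitude.
    """
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)


x = np.array([-2.0, -0.5, 0.0, 0.3, 1.5])
y = soft_threshold(x, 1.0)
print(y)
```

The derivative of this mapping is 1 where |x| > τ and 0 where |x| < τ, which is part of why it can sit inside a network trained by backpropagation.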
Based on the TFLearn ResNet example (https://github.com/tflearn/tflearn/blob/master/examples/images/residual_network_cifar10.py), the code for the deep residual shrinkage network is as follows:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019

M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: super_9527
"""

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)


def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2, activation='relu', batch_norm=True,
                             bias=True, weights_init='variance_scaling',
                             bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                             trainable=True, restore=True, reuse=False, scope=None,
                             name="ResidualBlock"):

    # residual shrinkage blocks with channel-wise thresholds

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                               downsample_strides, 'same', 'linear',
                               bias, weights_init, bias_init,
                               regularizer, weight_decay, trainable,
                               restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                               'linear', bias, weights_init,
                               bias_init, regularizer, weight_decay,
                               trainable, restore)

            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True),
                                      axis=1, keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',
                                             regularizer='L2', weight_decay=0.0001,
                                             weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Building Deep Residual Shrinkage Network
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)

# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
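The thresholding branch inside `residual_shrinkage_block` can be traced with plain NumPy to check the tensor shapes. The sketch below is a forward pass only, under stated assumptions: it uses the same NHWC layout as the TFLearn code, random untrained matrices (hypothetical `w1`, `w2`) stand in for the two `fully_connected` layers, and batch normalization and biases are omitted. The helper names `sigmoid` and `channelwise_shrinkage` are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def channelwise_shrinkage(residual, w1, w2):
    """Forward pass of the thresholding branch for an NHWC array.

    residual: shape (N, H, W, C)
    w1: shape (C, C // 4), stand-in for the first dense layer (untrained)
    w2: shape (C // 4, C), stand-in for the second dense layer (untrained)
    Batch normalization and biases from the TFLearn code are omitted.
    """
    n = residual.shape[0]
    # Mean of |x| over the two spatial axes -> shape (N, 1, 1, C)
    abs_mean = np.abs(residual).mean(axis=(1, 2), keepdims=True)
    # Dense -> ReLU -> dense gives one scale per channel, shape (N, C)
    hidden = np.maximum(abs_mean.reshape(n, -1) @ w1, 0.0)
    scales = hidden @ w2
    # Sigmoid maps each scale into (0, 1), so 0 < threshold < abs_mean
    thres = abs_mean * sigmoid(scales)[:, None, None, :]
    # Soft thresholding, broadcast over the spatial axes H and W
    shrunk = np.sign(residual) * np.maximum(np.abs(residual) - thres, 0.0)
    return shrunk, thres


rng = np.random.default_rng(0)
N, H, W, C = 2, 8, 8, 16
x = rng.normal(size=(N, H, W, C))
w1 = rng.normal(scale=0.1, size=(C, C // 4))
w2 = rng.normal(scale=0.1, size=(C // 4, C))
y, thres = channelwise_shrinkage(x, w1, w2)
print(y.shape, thres.shape)
```

Because each threshold is the channel's mean absolute value multiplied by a sigmoid output, it stays positive and strictly below that mean. The largest-magnitude feature in every channel is therefore never zeroed out, while the attention branch still picks a different threshold for each channel of each sample.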
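The code above builds a small deep residual shrinkage network with only 3 basic residual shrinkage blocks, and its other hyperparameters have not been tuned. To reach higher accuracy, you can increase the depth, train for more iterations and adjust the hyperparameters as needed.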
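When training for more iterations, the learning-rate schedule usually has to be rescaled as well. With the settings in the code (50,000 training images, `batch_size=100`, `n_epoch=100`), training runs 500 steps per epoch and 50,000 steps in total, and `tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)` multiplies the learning rate by 0.1 every 20,000 steps. The pure-Python sketch below reproduces that staircase schedule, assuming the standard exponential-decay form lr · decay^floor(step / decay_step); the helper names `staircase_lr` and `total_steps` are hypothetical.

```python
import math


def staircase_lr(step, base_lr=0.1, lr_decay=0.1, decay_step=20000):
    """Learning rate after `step` mini-batches under staircase exponential decay."""
    return base_lr * lr_decay ** (step // decay_step)


def total_steps(n_samples=50000, batch_size=100, n_epoch=100):
    """Total optimizer steps: one per mini-batch, for every epoch."""
    return math.ceil(n_samples / batch_size) * n_epoch


steps = total_steps()
for step in (0, 19999, 20000, 40000, steps - 1):
    print(step, staircase_lr(step))
```

With the original settings the learning rate is 0.1, then 0.01, then 0.001 over the three phases of training. If `n_epoch` is doubled, one simple option is to double `decay_step` as well, which keeps the same three-phase shape over the longer run.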
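The previous five posts:
Deep Residual Shrinkage Networks: (1) Background Knowledge https://www.cnblogs.com/yc-9527/p/11598844.html
Deep Residual Shrinkage Networks: (2) Overall Idea https://www.cnblogs.com/yc-9527/p/11601322.html
Deep Residual Shrinkage Networks: (3) Network Structure https://www.cnblogs.com/yc-9527/p/11603320.html
Deep Residual Shrinkage Networks: (4) Threshold Setting under the Attention Mechanism https://www.cnblogs.com/yc-9527/p/11604082.html
Deep Residual Shrinkage Networks: (5) Experimental Validation https://www.cnblogs.com/yc-9527/p/11610073.html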
Original paper:
M. Zhao, S. Zhong, X. Fu, B. Tang, and M. Pecht, “Deep Residual Shrinkage Networks for Fault Diagnosis,” IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
