1 MNIST数据集
1.1 简介
提供6万张28*28像素点的0~9手写数字图片和标签,用于训练
提供1万张28*28像素点的0~9手写数字图片和标签,用于测试
每张图片的784个像素点(28*28=784)组成长度为784的一维数组,作为输入特征。
图片的标签以一维数组形式给出,每个元素表示对应分类出现的概率。
1.2 常用函数
(1)从指定集合中取出全部变量,生成一个列表
tf.get_collection("")
(2)列表对应元素相加
tf.add_n([])
(3)把x转为dtype类型
tf.cast(x,dtype)
(4)返回最大值索引号,如tf.argmax([1,0,0])返回0
tf.argmax(x,axis)
(5)返回home/name
os.path.join("home","name")
(6)其内定义的节点在计算图g中
with tf.Graph().as_default() as g:
(7)按指定拆分符对字符串切片,返回分割后的列表。如'./model/mnist_model-1001'.split('-')[-1]返回'1001'
字符串.split()
(8)保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
    ...
    # 将当前会话保存到指定路径
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
(9)加载模型
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(存储路径)
    # 若模型存在,则加载出模型到当前会话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
(10)实例化可还原平均值的saver
# 实例化具有滑动平均的saver对象,从而在会话被加载时模型中的所有参数被赋值为各自的滑动平均值,增强模型的稳定性
ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)
(11)准确率计算方法
# 计算模型在测试集上的准确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
(12)断点续训
表示如果程序中断,下次训练从中断的位置继续训练模型,而不是从头开始训练
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
# 若模型存在,则加载出模型到当前会话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
2 代码实现
2.1 代码结构
(1)forward.py
# 定义前向传播过程
def forward(x, regularizer):
    w =
    b =
    y =
    return y

# 给权重赋初值
def get_weight(shape, regularizer):

# 给偏置赋初值
def get_bias(shape):
(2)backward.py
def backward(mnist):
    x =
    y_ =
    y =
    global_step =
    loss =
    <正则化、指数衰减学习率、滑动平均>
    train_step =
    实例化saver
    with tf.Session() as sess:
        初始化
        for i in range(STEPS):
            sess.run(train_step, feed_dict={x: , y_: })
            if i % 轮数 == 0:
                saver.save()
其中损失函数loss包含正则化regularization
backward.py中加入
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
# 正则化的loss
loss = cem + tf.add_n(tf.get_collection('losses'))
forward.py中加入
if regularizer is not None:
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
学习率learning_rate
backward.py中加入
# 指数衰减学习率
learning_rate = tf.train.exponential_decay(
    LEARNING_RATE_BASE,
    global_step,
    数据集总样本数 / BATCH_SIZE,
    LEARNING_RATE_DECAY,
    staircase=True)
滑动平均ema
backward.py中加入
# 滑动平均
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
    train_op = tf.no_op(name='train')
(3)test.py
def test(mnist):
    with tf.Graph().as_default() as g:
        定义x y_ y
        实例化可还原滑动平均值的saver
        计算正确率
        while True:
            with tf.Session() as sess:
                # 加载ckpt模型
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                # 如果已经有ckpt模型则恢复
                if ckpt and ckpt.model_checkpoint_path:
                    # 恢复会话
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # 恢复轮数
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # 计算准确率
                    accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                    # 打印提示
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                # 如果没有模型
                else:
                    # 给出提示
                    print('No checkpoint file found')
                    return

def main():
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    test(mnist)

if __name__ == '__main__':
    main()
2.2 完整代码
(1)mnist_forward.py
# 1 Forward propagation: a fully connected network with one hidden layer.
import tensorflow as tf

# Number of input nodes: one per pixel of an input image.
INPUT_NODE = 784
# Number of output nodes: one per digit class (0-9).
OUTPUT_NODE = 10
# Number of hidden-layer nodes.
LAYER1_NODE = 500


def get_weight(shape, regularizer):
    """Create a weight variable and optionally register its L2 penalty.

    Args:
        shape: Shape of the weight tensor to create.
        regularizer: L2 regularization coefficient, or None to skip
            regularization for this weight.

    Returns:
        The newly created weight variable.
    """
    # Initial values come from a truncated normal distribution
    # (tf.random_normal(shape, stddev=0.1) was the alternative tried here).
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    # Add this weight's regularization loss to the 'losses' collection so the
    # training script can sum it into the total loss.
    if regularizer is not None:
        tf.add_to_collection(
            'losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w


def get_bias(shape):
    """Create a bias variable of the given shape, initialised to all zeros."""
    b = tf.Variable(tf.zeros(shape))
    return b


def forward(x, regularizer):
    """Build the forward pass and return the un-normalised output y.

    Args:
        x: Input tensor; it is matrix-multiplied with a
            [INPUT_NODE, LAYER1_NODE] weight.
        regularizer: L2 regularization coefficient passed on to
            get_weight, or None to disable regularization.

    Returns:
        The output of the second layer, before any softmax is applied.
    """
    # Input layer -> hidden layer: w1 is [784, 500], b1 has length 500.
    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    b1 = get_bias([LAYER1_NODE])
    # First layer: x * w1 + b1, passed through relu.
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    # Hidden layer -> output layer: w2 is [500, 10], b2 has length 10.
    w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias([OUTPUT_NODE])
    # Second layer: y1 * w2 + b2. No relu here because the output is later
    # fed to softmax to turn it into a probability distribution.
    y = tf.matmul(y1, w2) + b2
    return y
(2)mnist_backward.py
# 2 Backward propagation: build the loss, train the network, save checkpoints.
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_forward

# Number of images fed to the network per training step.
BATCH_SIZE = 200
# Initial learning rate.
LEARNING_RATE_BASE = 0.1
# Learning-rate decay rate.
LEARNING_RATE_DECAY = 0.99
# Regularization coefficient.
REGULARIZER = 0.0001
# Number of training steps.
STEPS = 50000
# Decay rate of the moving average.
MOVING_AVERAGE_DECAY = 0.99
# Directory the model checkpoints are written to.
MODEL_SAVE_PATH = "./model/"
# File-name prefix of the saved model.
MODEL_NAME = "mnist_model"


def backward(mnist):
    """Train the network defined in mnist_forward and checkpoint it.

    If a checkpoint already exists under MODEL_SAVE_PATH, training resumes
    from it instead of starting from freshly initialised parameters.

    Args:
        mnist: Dataset object as returned by input_data.read_data_sets;
            mnist.train.num_examples and mnist.train.next_batch are used.
    """
    # Placeholders for the training images x and their labels y_.
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    # Forward pass with regularization enabled.
    y = mnist_forward.forward(x, REGULARIZER)
    # Step counter; marked non-trainable.
    global_step = tf.Variable(0, trainable=False)

    # Loss = mean cross entropy + all regularization losses collected by
    # mnist_forward.get_weight.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # Exponentially decaying learning rate; the decay interval is the number
    # of batches in one pass over the training set.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    # Optimizer that minimises the loss. Alternatives that were tried:
    # train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
    train_step = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(
        loss, global_step=global_step)

    # Moving average over all trainable variables; train_op runs the
    # optimizer step and the moving-average update together.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()

    # Make sure the checkpoint directory exists before the first save.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        # Initialise all variables.
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Resume from the latest checkpoint, if there is one.
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Feed BATCH_SIZE images and labels per step, for STEPS steps.
        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                # Save the current session to the checkpoint path.
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=global_step)


def main():
    """Load the MNIST data from ./data/ and run training."""
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    backward(mnist)


if __name__ == '__main__':
    main()
(3)mnist_test.py
# 3 Evaluation: measure the accuracy of the trained network on the test set.
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_forward
import mnist_backward

# Seconds to wait between two evaluation rounds.
TEST_INTERVAL_SECS = 5


def test(mnist):
    """Repeatedly evaluate the latest checkpoint on the test set.

    Every TEST_INTERVAL_SECS seconds the latest checkpoint under
    mnist_backward.MODEL_SAVE_PATH is restored and its test accuracy is
    printed. The function returns as soon as no checkpoint is found.

    Args:
        mnist: Dataset object as returned by input_data.read_data_sets;
            mnist.test.images and mnist.test.labels are used.
    """
    # Rebuild the computation graph inside a fresh tf.Graph.
    with tf.Graph().as_default():
        # Placeholders for the test images x and their labels y_.
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        # Forward pass without regularization.
        y = mnist_forward.forward(x, None)

        # Saver that restores every parameter from its moving-average value.
        ema = tf.train.ExponentialMovingAverage(
            mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        # Accuracy: fraction of samples whose predicted class equals the label.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        while True:
            with tf.Session() as sess:
                # Look up the latest checkpoint in the model directory.
                ckpt = tf.train.get_checkpoint_state(
                    mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The step count is the suffix after the last '-' in the
                    # checkpoint file name.
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(
                        accuracy,
                        feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                else:
                    # No checkpoint: report it and stop testing.
                    print('No checkpoint file found')
                    return
            time.sleep(TEST_INTERVAL_SECS)


def main():
    """Load the MNIST data from ./data/ and run the evaluation loop."""
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    test(mnist)


if __name__ == '__main__':
    main()
