1 实现把任意图片放进训练好的网络进行测试

输入的测试图片是白底黑字的数字图片,测试前需要做两步预处理:
(1)把图片缩放为28*28,以符合网络的输入尺寸;
(2)把图片转换成黑底白字的二值(黑白)图片。
mnist_app.py

  1. import tensorflow as tf
  2. import numpy as np
  3. from PIL import Image
  4. import mnist_backward
  5. import mnist_forward
  6. def restore_model(testPicArr):
  7. # 利用tf.Graph()复现之前定义的计算图
  8. with tf.Graph().as_default() as tg:
  9. x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
  10. # 调用mnist_forward文件中的前向传播过程forword()函数
  11. y = mnist_forward.forward(x, None)
  12. # 得到概率最大的预测值
  13. preValue = tf.argmax(y, 1)
  14. # 实例化具有滑动平均的saver对象
  15. variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
  16. variables_to_restore = variable_averages.variables_to_restore()
  17. saver = tf.train.Saver(variables_to_restore)
  18. with tf.Session() as sess:
  19. # 通过ckpt获取最新保存的模型
  20. ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
  21. if ckpt and ckpt.model_checkpoint_path:
  22. saver.restore(sess, ckpt.model_checkpoint_path)
  23. preValue = sess.run(preValue, feed_dict={x: testPicArr})
  24. return preValue
  25. else:
  26. print("No checkpoint file found")
  27. return -1
  28. # 预处理,包括resize,转变灰度图,二值化
  29. def pre_pic(picName):
  30. img = Image.open(picName)
  31. reIm = img.resize((28, 28), Image.ANTIALIAS)
  32. #把图片转换为灰度值图片
  33. im_arr = np.array(reIm.convert('L'))
  34. # 对图片做二值化处理(这样以滤掉噪声,另外调试中可适当调节阈值)
  35. threshold = 50
  36. # 模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为255减去原值以得到互补的反色。
  37. for i in range(28):
  38. for j in range(28):
  39. im_arr[i][j] = 255 - im_arr[i][j]
  40. if (im_arr[i][j] < threshold):
  41. im_arr[i][j] = 0
  42. else:
  43. im_arr[i][j] = 255
  44. # 把图片形状拉成1行784列,并把值变为浮点型(因为要求像素点是0-1 之间的浮点数)
  45. nm_arr = im_arr.reshape([1, 784])
  46. nm_arr = nm_arr.astype(np.float32)
  47. # 接着让现有的RGB图从0-255之间的数变为0-1之间的浮点数
  48. img_ready = np.multiply(nm_arr, 1.0 / 255.0)
  49. return img_ready
  50. def application():
  51. # 输入要识别的几张图片
  52. testNum = int(input("input the number of test pictures:"))
  53. for i in range(testNum):
  54. # 给出待识别图片的路径和名称
  55. testPic = input("the path of test picture:")
  56. # 图片预处理
  57. testPicArr = pre_pic(testPic)
  58. # 获取预测结果
  59. preValue = restore_model(testPicArr)
  60. print("The prediction number is:", preValue)
  61. def main():
  62. application()
  63. if __name__ == '__main__':
  64. main()

2 实现制作数据

2.1 简介

  • 可以把数据集制作成二进制的tfrecords文件:先将图片和标签制作成该格式的文件,再通过tfrecords读取数据,这样可以提高内存利用率。
  • 用tf.train.Example协议存储训练数据,训练数据的特征用键值对的形式表示。
  • 用SerializeToString()把数据序列化为字符串存储。

    2.2 生成tfrecords文件

    1. writer = tf.python_io.TFRecordWriter(tfRecordName)
    2. # 把每张图片和标签封装到example中
    3. example = tf.train.Example(features=tf.train.Features(feature={
    4. 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),# img_raw放入原始图片
    5. 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))# labels是图片的标签
    6. }))
    7. # 把example进行序列化
    8. writer.write(example.SerializeToString())
    9. # 关闭writer
    10. writer.close()

    2.3 解析tfrecords文件

    ```python

    该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据

    filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)

    新建一个reader

    reader = tf.TFRecordReader()

    把读出的每个样本保存在serialized_example中进行解序列化,标签和图片的键名应该和制作tfrecords的键名相同,其中标签给出几分类。

    _, serialized_example = reader.read(filename_queue)

    将tf.train.Example协议内存块(protocol buffer)解析为张量

    features = tf.parse_single_example(serialized_example,
    1. features={
    2. 'label': tf.FixedLenFeature([10], tf.int64),
    3. 'img_raw': tf.FixedLenFeature([], tf.string)
    4. })

    将img_raw字符串转换为8位无符号整型

    img = tf.decode_raw(features[‘img_raw’], tf.uint8)

    将形状变为一行784列

    img.set_shape([784]) img = tf.cast(img, tf.float32) * (1. / 255)

    变成0到1之间的浮点数

    label = tf.cast(features[‘label’], tf.float32)
  1. <a name="Qh9Mi"></a>
  2. ## 2.4 生成自定义数据的完整代码
  3. 读取的标签文件中每行的格式为:图片文件名 + 空格 + 标签<br />![截屏2020-12-21 下午5.50.16.png](https://cdn.nlark.com/yuque/0/2020/png/1780216/1608544221091-34b15a24-0d52-4f21-874b-2314747e0aa6.png#align=left&display=inline&height=221&margin=%5Bobject%20Object%5D&name=%E6%88%AA%E5%B1%8F2020-12-21%20%E4%B8%8B%E5%8D%885.50.16.png&originHeight=221&originWidth=237&size=51618&status=done&style=none&width=237)
  4. <a name="hXJ73"></a>
  5. ### mnist_generateds.py文件
  6. ```python
  7. #mnist_generateds.py
  8. # coding:utf-8
  9. import tensorflow as tf
  10. import numpy as np
  11. from PIL import Image
  12. import os
  13. image_train_path = './mnist_data_jpg/mnist_train_jpg_60000/'
  14. label_train_path = './mnist_data_jpg/mnist_train_jpg_60000.txt'
  15. tfRecord_train = './data/mnist_train.tfrecords'
  16. image_test_path = './mnist_data_jpg/mnist_test_jpg_10000/'
  17. label_test_path = './mnist_data_jpg/mnist_test_jpg_10000.txt'
  18. tfRecord_test = './data/mnist_test.tfrecords'
  19. data_path = './data'
  20. resize_height = 28
  21. resize_width = 28
  22. # 生成tfrecords文件
  23. def write_tfRecord(tfRecordName, image_path, label_path):
  24. # 新建一个writer
  25. writer = tf.python_io.TFRecordWriter(tfRecordName)
  26. num_pic = 0
  27. f = open(label_path, 'r')
  28. contents = f.readlines()
  29. f.close()
  30. # 循环遍历每张图和标签
  31. for content in contents:
  32. value = content.split()
  33. img_path = image_path + value[0]
  34. img = Image.open(img_path)
  35. img_raw = img.tobytes()#图片转换为二进制数据
  36. labels = [0] * 10
  37. labels[int(value[1])] = 1
  38. # 把每张图片和标签封装到example中
  39. example = tf.train.Example(features=tf.train.Features(feature={
  40. 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
  41. 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
  42. }))
  43. # 把example进行序列化
  44. writer.write(example.SerializeToString())
  45. num_pic += 1#每完成一张图片,计数器加1
  46. print("the number of picture:", num_pic)
  47. # 关闭writer
  48. writer.close()
  49. print("write tfrecord successful")
  50. def generate_tfRecord():
  51. isExists = os.path.exists(data_path)
  52. if not isExists:
  53. os.makedirs(data_path)
  54. print('The directory was created successfully')
  55. else:
  56. print('directory already exists')
  57. write_tfRecord(tfRecord_train, image_train_path, label_train_path)
  58. write_tfRecord(tfRecord_test, image_test_path, label_test_path)
  59. # 解析tfrecords文件
  60. def read_tfRecord(tfRecord_path):
  61. # 该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据
  62. filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
  63. # 新建一个reader
  64. reader = tf.TFRecordReader()
  65. # 把读出的每个样本保存在serialized_example中进行解序列化,标签和图片的键名应该和制作tfrecords的键名相同,其中标签给出几分类。
  66. _, serialized_example = reader.read(filename_queue)
  67. # 将tf.train.Example协议内存块(protocol buffer)解析为张量
  68. features = tf.parse_single_example(serialized_example,
  69. features={
  70. 'label': tf.FixedLenFeature([10], tf.int64),# 10表示标签的分类数量
  71. 'img_raw': tf.FixedLenFeature([], tf.string)
  72. })
  73. # 将img_raw字符串转换为8位无符号整型
  74. img = tf.decode_raw(features['img_raw'], tf.uint8)
  75. # 将形状变为一行784列
  76. img.set_shape([784])
  77. img = tf.cast(img, tf.float32) * (1. / 255)
  78. # 变成0到1之间的浮点数
  79. label = tf.cast(features['label'], tf.float32)
  80. # 返回图片和标签
  81. return img, label
  82. def get_tfrecord(num, isTrain=True):
  83. if isTrain:
  84. tfRecord_path = tfRecord_train
  85. else:
  86. tfRecord_path = tfRecord_test
  87. img, label = read_tfRecord(tfRecord_path)
  88. # 随机读取一个batch的数据,打乱数据
  89. img_batch, label_batch = tf.train.shuffle_batch([img, label],
  90. batch_size=num,
  91. num_threads=2,# 线程
  92. capacity=1000,
  93. min_after_dequeue=700)
  94. # 返回的图片和标签为随机抽取的batch_size组
  95. return img_batch, label_batch
  96. def main():
  97. generate_tfRecord()
  98. if __name__ == '__main__':
  99. main()

在反向传播程序mnist_backward.py和测试程序mnist_test.py中修改读取图片和标签的接口,并使用线程协调器,方法如下:

  1. coord = tf.train.Coordinator()
  2. threads = tf.train.start_queue_runners(sess = sess,coord = coord)
  3. # 图片和标签的批获取
  4. coord.request_stop()
  5. coord.join(threads)

mnist_backward.py文件

线程协调器的代码是用################################################括起来的

  1. #mnist_backward.py
  2. import tensorflow as tf
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. import mnist_forward
  5. import os
  6. import mnist_generateds # 1
  7. BATCH_SIZE = 200
  8. LEARNING_RATE_BASE = 0.1
  9. LEARNING_RATE_DECAY = 0.99
  10. REGULARIZER = 0.0001
  11. STEPS = 50000
  12. MOVING_AVERAGE_DECAY = 0.99
  13. MODEL_SAVE_PATH = "./model/"
  14. MODEL_NAME = "mnist_model"
  15. # 手动给出训练的总样本数6万
  16. train_num_examples = 60000 # 给出数据集的数量
  17. def backward():
  18. x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
  19. y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
  20. y = mnist_forward.forward(x, REGULARIZER)
  21. global_step = tf.Variable(0, trainable=False)
  22. ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
  23. cem = tf.reduce_mean(ce)
  24. loss = cem + tf.add_n(tf.get_collection('losses'))
  25. learning_rate = tf.train.exponential_decay(
  26. LEARNING_RATE_BASE,
  27. global_step,
  28. train_num_examples / BATCH_SIZE,
  29. LEARNING_RATE_DECAY,
  30. staircase=True)
  31. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  32. ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  33. ema_op = ema.apply(tf.trainable_variables())
  34. with tf.control_dependencies([train_step, ema_op]):
  35. train_op = tf.no_op(name='train')
  36. saver = tf.train.Saver()
  37. # 一次批获取 batch_size张图片和标签
  38. ################################################
  39. img_batch, label_batch = mnist_generateds.get_tfrecord(BATCH_SIZE, isTrain=True) # 3
  40. ################################################
  41. with tf.Session() as sess:
  42. init_op = tf.global_variables_initializer()
  43. sess.run(init_op)
  44. ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
  45. if ckpt and ckpt.model_checkpoint_path:
  46. saver.restore(sess, ckpt.model_checkpoint_path)
  47. ################################################
  48. # 利用多线程提高图片和标签的批获取效率
  49. coord = tf.train.Coordinator() # 4
  50. # 启动输入队列的线程
  51. threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 5
  52. ################################################
  53. for i in range(STEPS):
  54. ################################################
  55. # 执行图片和标签的批获取
  56. xs, ys = sess.run([img_batch, label_batch]) # 6
  57. ################################################
  58. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
  59. if i % 1000 == 0:
  60. print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
  61. saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
  62. ################################################
  63. # 关闭线程协调器
  64. coord.request_stop() # 7
  65. coord.join(threads) # 8
  66. ################################################
  67. def main():
  68. backward() # 9
  69. if __name__ == '__main__':
  70. main()

mnist_test.py文件

线程协调器的代码是用################################################括起来的

  1. import time
  2. import tensorflow as tf
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. import mnist_forward
  5. import mnist_backward
  6. import mnist_generateds
  7. TEST_INTERVAL_SECS = 5
  8. # 手动给出测试的总样本数1万
  9. TEST_NUM = 10000 # 1
  10. def test():
  11. with tf.Graph().as_default() as g:
  12. x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
  13. y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
  14. y = mnist_forward.forward(x, None)
  15. ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
  16. ema_restore = ema.variables_to_restore()
  17. saver = tf.train.Saver(ema_restore)
  18. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  19. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  20. ################################################
  21. # 用函数get_tfrecord替换读取所有测试集1万张图片
  22. img_batch, label_batch = mnist_generateds.get_tfrecord(TEST_NUM, isTrain=False) # 2
  23. ################################################
  24. while True:
  25. with tf.Session() as sess:
  26. ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
  27. if ckpt and ckpt.model_checkpoint_path:
  28. saver.restore(sess, ckpt.model_checkpoint_path)
  29. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  30. ################################################
  31. # 利用多线程提高图片和标签的批获取效率
  32. coord = tf.train.Coordinator() # 3
  33. # 启动输入队列的线程
  34. threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 4
  35. # 执行图片和标签的批获取
  36. xs, ys = sess.run([img_batch, label_batch]) # 5
  37. ################################################
  38. accuracy_score = sess.run(accuracy, feed_dict={x: xs, y_: ys})
  39. print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
  40. ################################################
  41. # 关闭线程协调器
  42. coord.request_stop() # 6
  43. coord.join(threads) # 7
  44. ################################################
  45. else:
  46. print('No checkpoint file found')
  47. return
  48. time.sleep(TEST_INTERVAL_SECS)
  49. def main():
  50. test() # 8
  51. if __name__ == '__main__':
  52. main()