


2.1 准备数据集

  1. wget http://download.tensorflow.org/example_images/flower_photos.tgz
  2. tar zxf flower_photos.tgz


2.2 划分数据集


  1. import glob
  2. import os
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow.python.platform import gfile
  # Root directory of the extracted flower photos; directories below it are
  # treated as classes by create_image_lists().
  INPUT_DATA = 'flower_photos/flower_photos'
  # Destination handed to np.save(), which appends ".npy" when it is missing.
  OUTPUT_FILE = 'output'
  # Percentage of images routed to the validation and test sets.
  # NOTE(review): listing lines 8-9 are missing from the original text even
  # though main() reads these names; 10/10 is the usual split for this
  # example -- confirm against the intended values.
  VALIDATION_PERCENTAGE = 10
  TEST_PERCENTAGE = 10
  10. # 读取数据并分割为训练集,验证集,测试集
  def create_image_lists(sess, testing_percentage, validation_percentage):
      """Read every JPEG under INPUT_DATA and split it into three datasets.

      Every directory that os.walk() finds below INPUT_DATA (the root itself
      is skipped) is treated as one class. Directories without JPEG files are
      ignored and do not consume a label; the others are labelled 0, 1, 2, ...
      in traversal order.

      Args:
          sess: TensorFlow session used to evaluate the decode/resize ops.
          testing_percentage: Chance, in percent, that an image lands in the
              test set.
          validation_percentage: Chance, in percent, that an image lands in
              the validation set.

      Returns:
          A 1-D NumPy object array holding six lists, in this order: training
          images, training labels, validation images, validation labels,
          testing images, testing labels. Every image is the float32 result
          of resizing the decoded JPEG to 299x299.
      """
      extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']

      training_images = []
      training_labels = []
      validation_images = []
      validation_labels = []
      testing_images = []
      testing_labels = []
      current_label = 0

      # os.walk() yields INPUT_DATA itself first; it is not a class directory.
      sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
      for sub_dir in sub_dirs[1:]:
          # Collect all image files of this class directory.
          dir_name = os.path.basename(sub_dir)
          file_list = []
          seen = set()
          for extension in extensions:
              file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
              for file_name in glob.glob(file_glob):
                  # Where glob matching is case-insensitive (e.g. Windows) the
                  # lower- and upper-case patterns hit the same file; keep one.
                  if file_name not in seen:
                      seen.add(file_name)
                      file_list.append(file_name)
          if not file_list:
              continue

          for file_name in file_list:
              # Close the handle right away instead of leaking one per image.
              with gfile.FastGFile(file_name, 'rb') as image_file:
                  image_raw_data = image_file.read()
              # NOTE(review): these ops are added to the graph for every
              # image, so the graph keeps growing; consider building them
              # once around a tf.placeholder if preprocessing becomes slow.
              image = tf.image.decode_jpeg(image_raw_data)
              if image.dtype != tf.float32:
                  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
              image = tf.image.resize_images(image, [299, 299])
              image_value = sess.run(image)

              # Draw an integer in [0, 100) to pick the destination dataset.
              chance = np.random.randint(100)
              if chance < validation_percentage:
                  validation_images.append(image_value)
                  validation_labels.append(current_label)
              elif chance < (testing_percentage + validation_percentage):
                  testing_images.append(image_value)
                  testing_labels.append(current_label)
              else:
                  training_images.append(image_value)
                  training_labels.append(current_label)
          current_label += 1

      # Shuffle images and labels with the same RNG state so that both lists
      # receive the identical permutation and stay aligned.
      state = np.random.get_state()
      print(state)
      np.random.shuffle(training_images)
      np.random.set_state(state)
      np.random.shuffle(training_labels)

      # The six lists have different lengths; fill an object array explicitly
      # because recent NumPy refuses to build ragged arrays implicitly.
      datasets = [training_images, training_labels,
                  validation_images, validation_labels,
                  testing_images, testing_labels]
      result = np.empty(len(datasets), dtype=object)
      for index, dataset in enumerate(datasets):
          result[index] = dataset
      return result
  def main():
      """Build the datasets inside a TensorFlow session and save them.

      The array returned by create_image_lists() is written with np.save()
      to the path given by OUTPUT_FILE.
      """
      session = tf.Session()
      with session:
          np.save(
              OUTPUT_FILE,
              create_image_lists(session, TEST_PERCENTAGE, VALIDATION_PERCENTAGE),
          )


  if __name__ == "__main__":
      main()

2.3 获取Inception-v3模型

  1. wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
  2. # 解压之后可以得到训练好的模型文件 inception_v3.ckpt
  3. tar zxf inception_v3_2016_08_28.tar.gz

2.4 完成迁移学习

  1. import glob
  2. import os
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow.python.platform import gfile
  # Root directory of the extracted flower photos; directories below it are
  # treated as classes by create_image_lists().
  INPUT_DATA = 'flower_photos/flower_photos'
  # Destination handed to np.save(), which appends ".npy" when it is missing.
  OUTPUT_FILE = 'output/flower_processed_data'
  # Percentage of images routed to the validation and test sets.
  # NOTE(review): listing lines 8-9 are missing from the original text even
  # though main() reads these names; 10/10 is the usual split for this
  # example -- confirm against the intended values.
  VALIDATION_PERCENTAGE = 10
  TEST_PERCENTAGE = 10
  10. # 读取数据并分割为训练集,验证集,测试集
  def create_image_lists(sess, testing_percentage, validation_percentage):
      """Read every JPEG under INPUT_DATA and split it into three datasets.

      Every directory that os.walk() finds below INPUT_DATA (the root itself
      is skipped) is treated as one class. Directories without JPEG files are
      ignored and do not consume a label; the others are labelled 0, 1, 2, ...
      in traversal order. The running dataset sizes are printed after each
      class directory has been processed.

      Args:
          sess: TensorFlow session used to evaluate the decode/resize ops.
          testing_percentage: Chance, in percent, that an image lands in the
              test set.
          validation_percentage: Chance, in percent, that an image lands in
              the validation set.

      Returns:
          A 1-D NumPy object array holding six lists, in this order: training
          images, training labels, validation images, validation labels,
          testing images, testing labels. Every image is the float32 result
          of resizing the decoded JPEG to 299x299.
      """
      extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']

      training_images = []
      training_labels = []
      validation_images = []
      validation_labels = []
      testing_images = []
      testing_labels = []
      current_label = 0

      # os.walk() yields INPUT_DATA itself first; it is not a class directory.
      sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
      for sub_dir in sub_dirs[1:]:
          # Collect all image files of this class directory.
          dir_name = os.path.basename(sub_dir)
          file_list = []
          seen = set()
          for extension in extensions:
              file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
              for file_name in glob.glob(file_glob):
                  # Where glob matching is case-insensitive (e.g. Windows) the
                  # lower- and upper-case patterns hit the same file; keep one.
                  if file_name not in seen:
                      seen.add(file_name)
                      file_list.append(file_name)
          if not file_list:
              continue

          for file_name in file_list:
              # Close the handle right away instead of leaking one per image.
              with gfile.FastGFile(file_name, 'rb') as image_file:
                  image_raw_data = image_file.read()
              # NOTE(review): these ops are added to the graph for every
              # image, so the graph keeps growing; consider building them
              # once around a tf.placeholder if preprocessing becomes slow.
              image = tf.image.decode_jpeg(image_raw_data)
              if image.dtype != tf.float32:
                  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
              image = tf.image.resize_images(image, [299, 299])
              image_value = sess.run(image)

              # Draw an integer in [0, 100) to pick the destination dataset.
              chance = np.random.randint(100)
              if chance < validation_percentage:
                  validation_images.append(image_value)
                  validation_labels.append(current_label)
              elif chance < (testing_percentage + validation_percentage):
                  testing_images.append(image_value)
                  testing_labels.append(current_label)
              else:
                  training_images.append(image_value)
                  training_labels.append(current_label)

          # Progress report once per class directory.
          print("training data size: %d, validation data size: %d, testing data size: %d" % (len(training_images), len(validation_images), len(testing_images)))
          current_label += 1

      # Shuffle images and labels with the same RNG state so that both lists
      # receive the identical permutation and stay aligned.
      state = np.random.get_state()
      print(state)
      np.random.shuffle(training_images)
      np.random.set_state(state)
      np.random.shuffle(training_labels)

      # The six lists have different lengths; fill an object array explicitly
      # because recent NumPy refuses to build ragged arrays implicitly.
      datasets = [training_images, training_labels,
                  validation_images, validation_labels,
                  testing_images, testing_labels]
      result = np.empty(len(datasets), dtype=object)
      for index, dataset in enumerate(datasets):
          result[index] = dataset
      return result
  def main():
      """Build the datasets inside a TensorFlow session and save them.

      The array returned by create_image_lists() is written with np.save()
      to the path given by OUTPUT_FILE.
      """
      # np.save() does not create missing directories; make sure the target
      # directory exists up front so the run cannot fail at the very end,
      # after all images have already been processed.
      output_dir = os.path.dirname(OUTPUT_FILE)
      if output_dir and not os.path.isdir(output_dir):
          os.makedirs(output_dir)

      with tf.Session() as sess:
          processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
          np.save(OUTPUT_FILE, processed_data)


  if __name__ == '__main__':
      main()