1.迁移学习介绍

迁移学习,就是将一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。本文的例子是利用ImageNet数据集上训练好的Inception-v3模型来解决一个新的图像分类问题。
一般来说,在数据量充足的情况下,迁移学习的效果不如完全重新训练。但是迁移学习所需要的训练时间和训练样本数远远小于训练完整的模型。

2.基于Tensorflow的迁移学习

2.1 准备数据集

  1. wget http://download.tensorflow.org/example_images/flower_photos.tgz
  2. tar zxf flower_photos.tgz

文件包含5个子目录,每个目录代表一种花的名称。平均每一种花有734张图片,每一张都是RGB色彩模式的,大小也不相同。

2.2 划分数据集

将数据集划分为训练集,验证集,测试集

  1. import glob
  2. import os
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow.python.platform import gfile
  6. INPUT_DATA = 'flower_photos/flower_photos'
  7. OUTPUT_FILE = 'output'
  8. VALIDATION_PERCENTAGE = 10
  9. TEST_PERCENTAGE = 10
  10. # 读取数据并分割为训练集,验证集,测试集
  11. def create_image_lists(sess, testing_percentage, validation_percentage):
  12. sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
  13. is_root_dir = True
  14. # 初始化各个数据集
  15. training_images = []
  16. training_labels = []
  17. validation_images = []
  18. validation_labels = []
  19. testing_images = []
  20. testing_labels = []
  21. current_label = 0
  22. for sub_dir in sub_dirs:
  23. if is_root_dir:
  24. is_root_dir = False
  25. continue
  26. # 获取一个子目录的所有图片文件
  27. extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
  28. file_list = []
  29. dir_name = os.path.basename(sub_dir)
  30. for extension in extensions:
  31. file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
  32. file_list.extend(glob.glob(file_glob))
  33. if not file_list:
  34. continue
  35. for file_name in file_list:
  36. image_raw_data = gfile.FastGFile(file_name, 'rb').read()
  37. image = tf.image.decode_jpeg(image_raw_data)
  38. if image.dtype != tf.float32:
  39. image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  40. image = tf.image.resize_images(image, [299, 299])
  41. image_value = sess.run(image)
  42. chance = np.random.randint(100)
  43. if chance < validation_percentage:
  44. validation_images.append(image_value)
  45. validation_labels.append(current_label)
  46. elif chance < (testing_percentage + validation_percentage):
  47. testing_images.append(image_value)
  48. testing_labels.append(current_label)
  49. else:
  50. training_images.append(image_value)
  51. training_labels.append(current_label)
  52. current_label += 1
  53. state = np.randome.get_state()
  54. print(state)
  55. np.random.shuffle(training_images)
  56. np.random.set_state(state)
  57. np.random.shuffle(training_labels)
  58. return np.asarray(
  59. [training_images, training_labels, validation_images, validation_labels, testing_images, testing_labels])
  60. def main():
  61. with tf.Session() as sess:
  62. processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
  63. np.save(OUTPUT_FILE, processed_data)
  64. if __name__ == '__main__':
  65. main()

2.3 获取Inception-v3模型

  1. wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
  2. # 解压之后可以得到训练好的模型文件
  3. tar zxf inception_v3_2016_08_28.tar.gz

2.4 完成迁移学习

  1. import glob
  2. import os
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow.python.platform import gfile
  6. INPUT_DATA = 'flower_photos/flower_photos'
  7. OUTPUT_FILE = 'output/flower_processed_data'
  8. VALIDATION_PERCENTAGE = 10
  9. TEST_PERCENTAGE = 10
  10. # 读取数据并分割为训练集,验证集,测试集
  11. def create_image_lists(sess, testing_percentage, validation_percentage):
  12. sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
  13. is_root_dir = True
  14. # 初始化各个数据集
  15. training_images = []
  16. training_labels = []
  17. validation_images = []
  18. validation_labels = []
  19. testing_images = []
  20. testing_labels = []
  21. current_label = 0
  22. for sub_dir in sub_dirs:
  23. if is_root_dir:
  24. is_root_dir = False
  25. continue
  26. # 获取一个子目录的所有图片文件
  27. extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
  28. file_list = []
  29. dir_name = os.path.basename(sub_dir)
  30. for extension in extensions:
  31. file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
  32. file_list.extend(glob.glob(file_glob))
  33. if not file_list:
  34. continue
  35. for file_name in file_list:
  36. image_raw_data = gfile.FastGFile(file_name, 'rb').read()
  37. image = tf.image.decode_jpeg(image_raw_data)
  38. if image.dtype != tf.float32:
  39. image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  40. image = tf.image.resize_images(image, [299, 299])
  41. image_value = sess.run(image)
  42. chance = np.random.randint(100)
  43. if chance < validation_percentage:
  44. validation_images.append(image_value)
  45. validation_labels.append(current_label)
  46. elif chance < (testing_percentage + validation_percentage):
  47. testing_images.append(image_value)
  48. testing_labels.append(current_label)
  49. else:
  50. training_images.append(image_value)
  51. training_labels.append(current_label)
  52. print("training data size: %d, validation data size: %d, testing data size: %d" % (len(training_images), len(validation_images), len(testing_images)))
  53. current_label += 1
  54. state = np.random.get_state()
  55. print(state)
  56. np.random.shuffle(training_images)
  57. np.random.set_state(state)
  58. np.random.shuffle(training_labels)
  59. return np.asarray(
  60. [training_images, training_labels, validation_images, validation_labels, testing_images, testing_labels])
  61. def main():
  62. with tf.Session() as sess:
  63. processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
  64. np.save(OUTPUT_FILE, processed_data)
  65. if __name__ == '__main__':
  66. main()