1.迁移学习介绍
迁移学习,就是将一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。本文的例子是利用ImageNet数据集上训练好的Inception-v3模型来解决一个新的图像分类问题。
一般来说,在数据量充足的情况下,迁移学习的效果不如完全重新训练。但是迁移学习所需要的训练时间和训练样本数远远小于训练完整的模型。
2.基于Tensorflow的迁移学习
2.1 准备数据集
# Download the flower photo archive, then unpack it in the current directory.
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz
文件包含5个子目录,每个目录代表一种花的名称。平均每一种花有734张图片,每一张都是RGB色彩模式的,大小也不相同。
2.2 划分数据集
将数据集划分为训练集,验证集,测试集
import glob
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Root directory that holds one sub-directory of JPEG images per class.
INPUT_DATA = 'flower_photos/flower_photos'
# Destination handed to np.save for the processed data.
OUTPUT_FILE = 'output'
# Percentage of images routed to the validation and the test split.
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10


def _pack_datasets(parts):
    """Store each list in *parts* as one element of a 1-D object array.

    Filling the array slot by slot avoids np.asarray on lists of different
    lengths, which recent NumPy versions reject as a ragged sequence.
    """
    packed = np.empty(len(parts), dtype=object)
    for index, part in enumerate(parts):
        packed[index] = part
    return packed


def create_image_lists(sess, testing_percentage, validation_percentage):
    """Read the images under INPUT_DATA and split them into three datasets.

    Every JPEG is decoded, converted to float32 and resized to 299x299, then
    assigned to the validation, test or training split by a random draw.

    Args:
        sess: TensorFlow session used to evaluate the preprocessing ops.
        testing_percentage: Percentage of images sent to the test split.
        validation_percentage: Percentage of images sent to the validation
            split.

    Returns:
        A 1-D object array with six entries, in this order: training images,
        training labels, validation images, validation labels, testing
        images, testing labels.
    """
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    is_root_dir = True

    # Containers for the three splits.
    training_images = []
    training_labels = []
    validation_images = []
    validation_labels = []
    testing_images = []
    testing_labels = []
    current_label = 0

    # Build the preprocessing ops once and feed each file through them;
    # creating them per image would add new nodes to the graph every time.
    raw_jpeg = tf.placeholder(tf.string)
    image = tf.image.decode_jpeg(raw_jpeg)
    if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize_images(image, [299, 299])

    for sub_dir in sub_dirs:
        # os.walk yields the root directory first; it is not a class folder.
        if is_root_dir:
            is_root_dir = False
            continue

        # Collect every image file of this sub-directory.
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            # Directories without images do not consume a label id.
            continue

        for file_name in file_list:
            # Close the file handle as soon as the bytes are read.
            with gfile.FastGFile(file_name, 'rb') as image_file:
                image_raw_data = image_file.read()
            image_value = sess.run(image, feed_dict={raw_jpeg: image_raw_data})

            # Pick the split for this image at random.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
        current_label += 1

    # Shuffle images and labels from the same RNG state so that both lists
    # receive the same permutation and stay aligned.
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    return _pack_datasets([training_images, training_labels,
                           validation_images, validation_labels,
                           testing_images, testing_labels])


def main():
    """Process the whole dataset and save the result with np.save."""
    with tf.Session() as sess:
        processed_data = create_image_lists(
            sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        np.save(OUTPUT_FILE, processed_data)


if __name__ == '__main__':
    main()
2.3 获取Inception-v3模型
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
# Unpack the archive to obtain the pre-trained model files.
tar zxf inception_v3_2016_08_28.tar.gz
2.4 完成迁移学习
import glob
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Root directory that holds one sub-directory of JPEG images per class.
INPUT_DATA = 'flower_photos/flower_photos'
# Destination handed to np.save for the processed data.
OUTPUT_FILE = 'output/flower_processed_data'
# Percentage of images routed to the validation and the test split.
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10


def _pack_datasets(parts):
    """Store each list in *parts* as one element of a 1-D object array.

    Filling the array slot by slot avoids np.asarray on lists of different
    lengths, which recent NumPy versions reject as a ragged sequence.
    """
    packed = np.empty(len(parts), dtype=object)
    for index, part in enumerate(parts):
        packed[index] = part
    return packed


def create_image_lists(sess, testing_percentage, validation_percentage):
    """Read the images under INPUT_DATA and split them into three datasets.

    Every JPEG is decoded, converted to float32 and resized to 299x299, then
    assigned to the validation, test or training split by a random draw.
    The running split sizes are printed after each class directory.

    Args:
        sess: TensorFlow session used to evaluate the preprocessing ops.
        testing_percentage: Percentage of images sent to the test split.
        validation_percentage: Percentage of images sent to the validation
            split.

    Returns:
        A 1-D object array with six entries, in this order: training images,
        training labels, validation images, validation labels, testing
        images, testing labels.
    """
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    is_root_dir = True

    # Containers for the three splits.
    training_images = []
    training_labels = []
    validation_images = []
    validation_labels = []
    testing_images = []
    testing_labels = []
    current_label = 0

    # Build the preprocessing ops once and feed each file through them;
    # creating them per image would add new nodes to the graph every time.
    raw_jpeg = tf.placeholder(tf.string)
    image = tf.image.decode_jpeg(raw_jpeg)
    if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize_images(image, [299, 299])

    for sub_dir in sub_dirs:
        # os.walk yields the root directory first; it is not a class folder.
        if is_root_dir:
            is_root_dir = False
            continue

        # Collect every image file of this sub-directory.
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            # Directories without images do not consume a label id.
            continue

        for file_name in file_list:
            # Close the file handle as soon as the bytes are read.
            with gfile.FastGFile(file_name, 'rb') as image_file:
                image_raw_data = image_file.read()
            image_value = sess.run(image, feed_dict={raw_jpeg: image_raw_data})

            # Pick the split for this image at random.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)

        # Report progress once per processed class directory.
        print("training data size: %d, validation data size: %d, testing data size: %d" % (
            len(training_images), len(validation_images), len(testing_images)))
        current_label += 1

    # Shuffle images and labels from the same RNG state so that both lists
    # receive the same permutation and stay aligned.
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    return _pack_datasets([training_images, training_labels,
                           validation_images, validation_labels,
                           testing_images, testing_labels])


def main():
    """Process the whole dataset and save the result with np.save."""
    # np.save does not create missing directories, so make sure the target
    # directory exists before spending time on preprocessing.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    with tf.Session() as sess:
        processed_data = create_image_lists(
            sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        np.save(OUTPUT_FILE, processed_data)


if __name__ == '__main__':
    main()
