1.迁移学习介绍
迁移学习,就是将一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。本文的例子是利用ImageNet数据集上训练好的Inception-v3模型来解决一个新的图像分类问题。
一般来说,在数据量充足的情况下,迁移学习的效果不如完全重新训练。但是迁移学习所需要的训练时间和训练样本数远远小于训练完整的模型。
2.基于Tensorflow的迁移学习
2.1 准备数据集
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz
文件包含5个子目录,每个目录代表一种花的名称。平均每一种花有734张图片,每一张都是RGB色彩模式的,大小也不相同。
2.2 划分数据集
将数据集划分为训练集,验证集,测试集
import glob
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
INPUT_DATA = 'flower_photos/flower_photos'
OUTPUT_FILE = 'output'
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
# 读取数据并分割为训练集,验证集,测试集
def create_image_lists(sess, testing_percentage, validation_percentage):
    """Read all flower photos under INPUT_DATA and split them into
    training / validation / testing sets.

    Each image is decoded, converted to float32 and resized to 299x299
    (the Inception-v3 input size), then randomly assigned to one of the
    three sets according to the given percentages.

    Args:
        sess: active tf.Session used to evaluate the decode/resize ops.
        testing_percentage: int in [0, 100], share of images for testing.
        validation_percentage: int in [0, 100], share for validation.

    Returns:
        np.asarray of [training_images, training_labels,
                       validation_images, validation_labels,
                       testing_images, testing_labels].
    """
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    # os.walk yields INPUT_DATA itself first; skip that root entry.
    is_root_dir = True
    # Initialize the three data sets.
    training_images = []
    training_labels = []
    validation_images = []
    validation_labels = []
    testing_images = []
    testing_labels = []
    current_label = 0
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue
        # Collect every image file in this sub-directory (one class per dir).
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            continue
        for file_name in file_list:
            # NOTE(review): new decode/resize ops are added to the graph on
            # every iteration, so the graph grows with the number of images;
            # acceptable for a one-off preprocessing script.
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()
            image = tf.image.decode_jpeg(image_raw_data)
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            image = tf.image.resize_images(image, [299, 299])
            image_value = sess.run(image)
            # Randomly route the image into one of the three sets.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
        current_label += 1
    # Shuffle images and labels with the same RNG state so pairs stay aligned.
    # BUG FIX: original called `np.randome.get_state()` (typo), which raises
    # AttributeError at runtime; the module is `np.random`.
    state = np.random.get_state()
    print(state)
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)
    return np.asarray(
        [training_images, training_labels, validation_images, validation_labels, testing_images, testing_labels])
def main():
    """Preprocess the flower photos and persist the split data sets."""
    with tf.Session() as sess:
        data_sets = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        np.save(OUTPUT_FILE, data_sets)


if __name__ == '__main__':
    main()
2.3 获取Inception-v3模型
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
# 解压之后可以得到训练好的Inception-v3模型文件
tar zxf inception_v3_2016_08_28.tar.gz
2.4 完成迁移学习
import glob
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
INPUT_DATA = 'flower_photos/flower_photos'
OUTPUT_FILE = 'output/flower_processed_data'
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
# 读取数据并分割为训练集,验证集,测试集
def create_image_lists(sess, testing_percentage, validation_percentage):
    """Read all flower photos under INPUT_DATA and split them into
    training / validation / testing sets.

    Each image is decoded, converted to float32 and resized to 299x299
    (the Inception-v3 input size), then randomly assigned to one of the
    three sets according to the given percentages.

    Args:
        sess: active tf.Session used to evaluate the decode/resize ops.
        testing_percentage: int in [0, 100], share of images for testing.
        validation_percentage: int in [0, 100], share for validation.

    Returns:
        np.asarray of [training_images, training_labels,
                       validation_images, validation_labels,
                       testing_images, testing_labels].
    """
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    # os.walk yields INPUT_DATA itself first; skip that root entry.
    is_root_dir = True
    # Initialize the three data sets.
    training_images = []
    training_labels = []
    validation_images = []
    validation_labels = []
    testing_images = []
    testing_labels = []
    current_label = 0
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue
        # Collect every image file in this sub-directory (one class per dir).
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            continue
        for file_name in file_list:
            # NOTE(review): new decode/resize ops are added to the graph on
            # every iteration, so the graph grows with the number of images;
            # acceptable for a one-off preprocessing script.
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()
            image = tf.image.decode_jpeg(image_raw_data)
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            image = tf.image.resize_images(image, [299, 299])
            image_value = sess.run(image)
            # Randomly route the image into one of the three sets.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
        current_label += 1
    # BUG FIX: the size summary used to be printed inside the per-image loop,
    # emitting one line per processed image; report it once, after the split.
    print("training data size: %d, validation data size: %d, testing data size: %d" % (len(training_images), len(validation_images), len(testing_images)))
    # Shuffle images and labels with the same RNG state so pairs stay aligned.
    state = np.random.get_state()
    print(state)
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)
    return np.asarray(
        [training_images, training_labels, validation_images, validation_labels, testing_images, testing_labels])
def main():
    """Build the processed flower data set and save it to OUTPUT_FILE."""
    with tf.Session() as sess:
        processed = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        np.save(OUTPUT_FILE, processed)


if __name__ == '__main__':
    main()