迁移学习是一种在已经训练好的模型基础上,通过微调最后几层神经网络来快速训练自定义模型的技术。
本例程展示如何基于tensorflow通过迁移学习快速训练出自定义的分类模型
数据准备
在主机目录 /home/pi/Lepi_Data/ros/transfer_learning 下新建文件夹“分类测试”,然后在“分类测试”目录下新建不同的分类目录,比如“分类1”、“分类2”,向每个分类目录中添加不少于20张对应分类的图片,然后可以运行下面这段示例程序:
#!coding:utf-8
import tensorflow as tf
from tensorflow import keras
import os
from tensorflow.keras.models import load_model
import tensorflow.python.keras.backend as K
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.optimizers import SGD
import cv2
import numpy as np


def get_labels(data_dir):
    """Return the sorted names of the sub-directories of ``data_dir``.

    Every sub-directory is treated as one class; plain files are ignored.
    The directory name is used verbatim as the label so that it can be
    joined back onto ``data_dir`` to locate the class images.

    Args:
        data_dir: str, directory that contains one sub-directory per class.

    Returns:
        list of str, class names in ascending order.
    """
    labels = [
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    ]
    labels.sort()
    print(labels)
    return labels


def load_img_from_dir(data_dir, target_size=(112, 112), max_num=100):
    """Load images from a directory tree and resize them to ``target_size``.

    Args:
        data_dir: str, directory that contains one sub-directory per class.
        target_size: tuple passed to ``cv2.resize`` as the output size.
        max_num: int, maximum number of files considered per class.

    Returns:
        tuple ``(x, y)`` of numpy arrays: ``x`` holds the resized images and
        ``y`` holds, for each image, the index of its class in the list
        returned by ``get_labels``.
    """
    x_load = []
    y_load = []
    labels = get_labels(data_dir)
    for class_index, cat in enumerate(labels):
        files_dir = os.path.join(data_dir, cat)
        # Sort so that the ``max_num`` slice picks the same files every run.
        for file_name in sorted(os.listdir(files_dir))[:max_num]:
            file_path = os.path.join(files_dir, file_name)
            try:
                cv_img = cv2.imread(file_path)
                if cv_img is None:
                    # cv2.imread signals an unreadable file by returning None.
                    print('skip unreadable image: %s' % file_path)
                    continue
                x = img_to_array(cv2.resize(cv_img, target_size))
            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(e)
                continue
            x_load.append(x)
            y_load.append(class_index)  # directory position is the label id
    return np.array(x_load), np.array(y_load)


def prepress_labels(labels):
    """One-hot encode class ids.

    For example ``0 1 2`` becomes ``[[1 0 0] [0 1 0] [0 0 1]]``.

    Args:
        labels: array-like of integer class ids.

    Returns:
        The result of ``tensorflow.keras.utils.to_categorical(labels)``.
    """
    from tensorflow.keras import utils
    return utils.to_categorical(labels)


class AccuracyLogger(keras.callbacks.Callback):
    """Keras callback that reports training progress after every batch.

    Attributes:
        callback: optional function called as ``callback(epoch, batch, logs)``.
        epoch: int, index of the epoch most recently started.
        batch: int, index of the batch most recently finished.
    """

    def __init__(self, callback=None):
        # Explicit two-argument form keeps this valid on old interpreters too.
        super(AccuracyLogger, self).__init__()
        self.callback = callback
        self.epoch = 0
        self.batch = 0

    def set_model(self, model):
        """Record the model this callback is attached to.

        Args:
            model: the model being trained.
        """
        # Delegate so the base class stores the model in its own way.
        super(AccuracyLogger, self).set_model(model)

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the index of the epoch that is starting.

        Args:
            epoch: int, epoch index.
            logs: unused.
        """
        self.epoch = epoch

    def on_batch_end(self, batch, logs=None):
        """Print progress and forward it to the user callback, if any.

        Args:
            batch: int, batch index.
            logs: dict of metrics for this batch; treated as empty when None.
        """
        if logs is None:
            logs = {}
        self.batch = batch
        print('epoch:%d batch:%d' % (self.epoch, self.batch))
        if self.callback is not None:
            self.callback(self.epoch, self.batch, logs)


class ImageClassifier:
    """Train and run an image classifier built on a pre-trained VGG16 base.

    Attributes:
        data_root: str, root directory that holds the training namespaces.
        FC_NUMS: int, number of units in the added hidden Dense layer.
        TRAIN_LAYERS: int, number of trailing layers left trainable.
        IMAGE_SIZE: int, width and height of the model input.
        model: the trained or loaded model, or None.
        busy: bool, True while a load or training run is in progress.
        ns: str, name of the current namespace (sub-directory of data_root).
        label_names: list of str, class names matching the model outputs.
    """

    def __init__(self, model_path=None, data_root=os.path.expanduser('~')+'/Lepi_Data/ros/transfer_learning'):
        # NOTE: ``model_path`` is accepted for compatibility but is not used.
        self.data_root = data_root
        self.FC_NUMS = 64
        self.TRAIN_LAYERS = 2
        self.IMAGE_SIZE = 112
        self.model = None
        self.busy = False
        self.ns = None
        # Initialised here so predict() works before train()/load_label_name().
        self.label_names = []

    def load_model(self):
        """Load ``model.h5`` from ``<data_root>/<ns>`` into ``self.model``.

        Loading errors are printed and otherwise ignored; ``self.model`` is
        then left unchanged.
        """
        # Build the path before raising the busy flag so a bad ``ns`` cannot
        # leave the flag stuck.
        path = os.path.join(self.data_root, self.ns, 'model.h5')
        self.busy = True
        try:
            if self.model is not None:
                # Drop the previous model's state before loading a new one.
                K.clear_session()
            self.model = load_model(path)
        except Exception as e:
            print(e)
        finally:
            self.busy = False

    def preprocess_input(self, cv_img):
        """Turn an OpenCV image into a single-image float32 batch.

        Args:
            cv_img: image array accepted by ``cv2.resize``.

        Returns:
            numpy array: the image resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``,
            cast to float32, with a leading batch axis added. No value
            normalisation is applied.
        """
        resized = cv2.resize(cv_img, (self.IMAGE_SIZE, self.IMAGE_SIZE))
        return np.expand_dims(resized.astype('float32'), axis=0)

    def get_base_model(self, name='vgg16'):
        """Build the VGG16 base network without its fully connected top.

        Args:
            name: str, only printed; VGG16 is always used.

        Returns:
            The VGG16 model with ImageNet weights and ``include_top=False``.
        """
        print('using model : ', name)
        from tensorflow.keras.applications.vgg16 import VGG16
        # include_top=False drops the original FC layers so custom ones can be
        # attached; the weights are downloaded by Keras when not yet cached.
        base_model = VGG16(input_shape=(self.IMAGE_SIZE, self.IMAGE_SIZE, 3),
                           include_top=False, weights='imagenet')
        return base_model

    def download_model(self, name='vgg16'):
        """Fetch the base model by building it once; the result is discarded."""
        self.get_base_model(name)

    def train(self, data_dir, epochs=3, callback=None, model_name='vgg16'):
        """Train a classifier on the images under ``data_dir``.

        The model is saved to ``<data_dir>/model.h5`` and the label list is
        written through ``dump_label_name``.

        Args:
            data_dir: str, directory with one sub-directory per class.
            epochs: int, number of training epochs.
            callback: optional function ``callback(epoch, batch, logs)``
                invoked after every batch.
            model_name: str, forwarded to ``get_base_model``.

        Returns:
            1 when another operation is already running, otherwise None.
        """
        if self.busy:
            return 1
        self.busy = True
        try:
            label_names = get_labels(data_dir)
            self.NUM_CLASSES = len(label_names)
            base_model = self.get_base_model(model_name)
            x_data, y_label = load_img_from_dir(
                data_dir, target_size=(self.IMAGE_SIZE, self.IMAGE_SIZE), max_num=30)
            for i in range(x_data.shape[0]):
                x_data[i] = self.preprocess_input(x_data[i])
            print(x_data.shape)
            print(x_data[0].shape)
            x_data = x_data.reshape(
                x_data.shape[0], self.IMAGE_SIZE, self.IMAGE_SIZE, 3)
            y_label_one_hot = prepress_labels(y_label)
            # Validation should use images the model has never seen.
            train_x, test_x, train_y, test_y = train_test_split(
                x_data, y_label_one_hot, random_state=0, test_size=0.3)

            # Custom head attached to the output of the base model.
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(self.FC_NUMS, activation='relu')(x)
            prediction = Dense(self.NUM_CLASSES, activation='softmax')(x)
            model = Model(inputs=base_model.input, outputs=prediction)
            print("layer nums:", len(model.layers))

            # Freeze everything, then unfreeze only the last TRAIN_LAYERS
            # layers; with TRAIN_LAYERS = 2 these are the two Dense layers
            # added above.
            for layer in model.layers:
                layer.trainable = False
            for layer in model.layers[-self.TRAIN_LAYERS:]:
                layer.trainable = True
            for layer in model.layers:
                print("layer.trainable:", layer.trainable)

            model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
                          loss='categorical_crossentropy', metrics=['accuracy'])
            model.summary()
            model.fit(train_x, train_y,
                      validation_data=(test_x, test_y),
                      callbacks=[AccuracyLogger(callback)],
                      epochs=epochs, batch_size=4,
                      verbose=1, shuffle=True)
            self.model = model
            model.save(os.path.join(data_dir, 'model.h5'))
            self.label_names = label_names
            self.dump_label_name(label_names)
        finally:
            # Always release the flag, even when training raised.
            self.busy = False

    def dump_label_name(self, dirs):
        """Write the label names to ``<data_root>/<ns>/labelmap.txt``.

        Args:
            dirs: iterable of str, one label per output line.
        """
        with open(os.path.join(self.data_root, self.ns, "labelmap.txt"), "w") as f:
            f.writelines([item + '\n' for item in dirs])

    def load_label_name(self):
        """Read the label names from ``<data_root>/<ns>/labelmap.txt``.

        Returns:
            list of str, also stored in ``self.label_names``.
        """
        with open(os.path.join(self.data_root, self.ns, "labelmap.txt"), "r") as f:
            label_names = [line.strip() for line in f.readlines()]
        self.label_names = label_names
        return label_names

    def evaluate(self, data_dir):
        """Evaluate the current model on half of the images in ``data_dir``.

        Args:
            data_dir: str, directory with one sub-directory per class.

        Returns:
            The value returned by ``self.model.evaluate`` (loss first, then
            accuracy), or None when no model is available.
        """
        if self.model is None:
            print('your model is not ready')
            return None
        x_data, y_label = load_img_from_dir(
            data_dir, target_size=(self.IMAGE_SIZE, self.IMAGE_SIZE), max_num=60)
        for i in range(x_data.shape[0]):
            x_data[i] = self.preprocess_input(x_data[i])
        x_data = x_data.reshape(
            x_data.shape[0], self.IMAGE_SIZE, self.IMAGE_SIZE, 3)
        y_label_one_hot = prepress_labels(y_label)
        # Only the held-out half of the split is used for scoring.
        train_x, test_x, train_y, test_y = train_test_split(
            x_data, y_label_one_hot, random_state=0, test_size=0.5)
        score = self.model.evaluate(test_x, test_y, verbose=1, steps=1)
        print('Test loss:', score[0])
        print('Test accuracy:', score[1])
        return score

    def predict(self, cv_img=None, path=None):
        """Classify one image.

        Args:
            cv_img: OpenCV image; ignored when ``path`` is given.
            path: str, image file to read with ``cv2.imread``.

        Returns:
            tuple ``(label_names, scores)`` where ``scores`` is the model
            output for the image, or None when no image or no model is
            available.
        """
        if path is not None:
            cv_img = cv2.imread(path)
        if cv_img is None:
            print('no cv image provided')
            return None
        input_data = self.preprocess_input(cv_img)
        if self.model is None:
            print('your model is not ready')
            return None
        result = self.model.predict(input_data, steps=1)
        print("result:", result)
        return self.label_names, result[0]


if __name__ == '__main__':
    # Training smoke test.
    IC = ImageClassifier()
    IC.ns = '分类测试'
    epoch_total = 3

    def pub_training_logs(epoch, batch, logs):
        # logs example: {'loss': 0.33773628, 'accuracy': 0.71428573, 'batch': 6, 'size': 4}
        print(logs)
        msg = '第%d/%d轮, 批次: %d, 损失: %.2f, 准确率: %.2f' % (
            epoch+1, epoch_total, batch, logs['loss'], logs['accuracy'])
        print(msg)

    IC.train(os.path.join(IC.data_root, '分类测试'), epoch_total,
             callback=pub_training_logs, model_name='vgg16')
