迁移学习是一种在已经训练好的模型基础上,通过微调最后几层神经网络来快速训练自定义模型的技术。
本例程展示如何基于 TensorFlow,通过迁移学习快速训练出自定义的分类模型。

数据准备

在主机目录 /home/pi/Lepi_Data/ros/transfer_learning 下新建文件夹“分类测试”,然后在“分类测试”目录下新建不同的分类目录,比如“分类1”、“分类2”,向每个分类目录中添加不少于 20 张对应分类的图片,然后即可运行下面这段示例程序:

  1. #!coding:utf-8
  2. import tensorflow as tf
  3. from tensorflow import keras
  4. import os
  5. from tensorflow.keras.models import load_model
  6. import tensorflow.python.keras.backend as K
  7. from tensorflow.keras.preprocessing.image import img_to_array
  8. from sklearn.model_selection import train_test_split
  9. from tensorflow.keras.models import Model
  10. from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
  11. from tensorflow.keras.optimizers import SGD
  12. import cv2
  13. import numpy as np
  14. def get_labels(data_dir):
  15. """
  16. 加载标签
  17. """
  18. labels = []
  19. files = os.listdir(data_dir)
  20. files.sort()
  21. for file in files:
  22. if not os.path.isdir(os.path.join(data_dir, file)):
  23. continue
  24. catname_full, _ = os.path.splitext(file)
  25. # catname = catname_full.split('_')[-1]
  26. # if len(os.listdir(os.path.join(data_dir,file)))>=1:
  27. labels.append(catname_full)
  28. print(labels)
  29. labels.sort()
  30. return labels
  31. def load_img_from_dir(data_dir, target_size=(112, 112), max_num=100): # From Directory
  32. """
  33. 从目录加载图片并缩放至指定尺寸
  34. """
  35. x_load = []
  36. y_load = []
  37. labels = get_labels(data_dir)
  38. dirs = labels
  39. for cat in dirs: # load directory
  40. files_dir = os.path.join(data_dir, cat)
  41. files = os.listdir(files_dir)
  42. for file in files[:max_num]:
  43. file_path = os.path.join(files_dir, file)
  44. try:
  45. cv_img = cv2.imread(file_path)
  46. # cv_img=cv2.imdecode(np.fromfile(file_path,dtype=np.uint8),-1)
  47. x = img_to_array(cv2.resize(cv_img, target_size))
  48. # x = K.expand_dims(x, axis=0)
  49. # x = preprocess_input(x)
  50. x_load.append(x)
  51. y_load.append(labels.index(cat)) # directory name as label
  52. except Exception as e:
  53. print(e)
  54. continue
  55. return np.array(x_load), np.array(y_load)
  56. def prepress_labels(labels):
  57. """
  58. # one-hot编码 把类别id转换为表示当前类别的向量,比如0 1 2 =》 [[1 0 0] [0 1 0] [0 0 1]]
  59. """
  60. from tensorflow.keras import utils
  61. labels = utils.to_categorical(labels)
  62. return labels
  63. class AccuracyLogger(keras.callbacks.Callback):
  64. """
  65. AccuracyLogger 类, 用来记录训练进度
  66. Attributes:
  67. callback: function 回调函数
  68. epoch: int 回合
  69. batch: int 批次
  70. """
  71. def __init__(self, callback=None):
  72. self.callback = callback
  73. self.epoch = 0
  74. self.batch = 0
  75. def set_model(self, model):
  76. """
  77. set_model 函数, 在训练之前由父模型调用,告诉回调函数是哪个模型在调用它
  78. Keyword arguments::
  79. model: model 训练模型
  80. """
  81. self.model = model #
  82. def on_epoch_begin(self, epoch, logs={}):
  83. """
  84. on_epoch_begin 函数, epoch开始的时候自动调用
  85. Keyword arguments::
  86. epoch: int 回合数
  87. """
  88. self.epoch = epoch
  89. def on_batch_end(self, batch, logs={}):
  90. """
  91. on_batch_end 函数, 在batch训练结束后自动调用
  92. Keyword arguments::
  93. batch: int 批次数
  94. """
  95. self.batch = batch
  96. print('epoch:%d batch:%d' % (self.epoch, self.batch))
  97. # print(logs)
  98. if self.callback is not None:
  99. self.callback(self.epoch, self.batch, logs)
  100. class ImageClassifier:
  101. """
  102. ImageClassifier 类, 用来训练分类器
  103. Attributes:
  104. data_root: str 数据根目录
  105. FC_NUMS: int 回合
  106. TRAIN_LAYERS: int 批次
  107. model: model 模型
  108. busy: bool 是否正在执行操作
  109. ns: str 单次训练命名空间
  110. """
  111. def __init__(self, model_path=None, data_root=os.path.expanduser('~')+'/Lepi_Data/ros/transfer_learning'):
  112. self.data_root = data_root
  113. self.FC_NUMS = 64
  114. self.TRAIN_LAYERS = 2
  115. self.IMAGE_SIZE = 112
  116. self.model = None
  117. self.busy = False
  118. self.ns = None
  119. def load_model(self):
  120. """
  121. on_batch_end 函数, 在batch训练结束后自动调用
  122. Keyword arguments::
  123. batch: int 批次数
  124. """
  125. self.busy = True
  126. data_dir = os.path.join(self.data_root, self.ns)
  127. path = os.path.join(data_dir, 'model.h5')
  128. try:
  129. if self.model is None:
  130. self.model = load_model(path) # 加载训练模型
  131. else:
  132. K.clear_session()
  133. self.model = load_model(path)
  134. except Exception as e:
  135. print(e)
  136. finally:
  137. self.busy = False
  138. def preprocess_input(self, cv_img):
  139. cv_img = cv2.resize(
  140. cv_img, (self.IMAGE_SIZE, self.IMAGE_SIZE)).astype('float32')
  141. # cv_img = cv2.normalize(cv_img, None, 0, 1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32FC3)
  142. # cv_img = cv_img.reshape(-1, self.IMAGE_SIZE, self.IMAGE_SIZE, 3)
  143. input_data = np.expand_dims(cv_img, axis=0)
  144. return input_data
  145. def get_base_model(self, name='vgg16'):
  146. print('using model : ', name)
  147. from tensorflow.keras.applications.vgg16 import VGG16
  148. # 采用VGG16为基本模型,include_top为False,表示FC层是可自定义的,抛弃模型中的FC层;该模型会在~/.keras/models下载基本模型
  149. base_model = VGG16(input_shape=(
  150. self.IMAGE_SIZE, self.IMAGE_SIZE, 3), include_top=False, weights='imagenet')
  151. # self.preprocess_input = preprocess_input
  152. return base_model
  153. def download_model(self, name='vgg16'):
  154. """
  155. 下载模型
  156. """
  157. self.get_base_model(name)
  158. def train(self, data_dir, epochs=3, callback=None, model_name='vgg16'):
  159. """
  160. train 函数, 训练模型
  161. Keyword arguments::
  162. epochs: int 训练次数
  163. """
  164. if self.busy:
  165. return 1
  166. self.busy = True
  167. label_names = get_labels(data_dir)
  168. self.NUM_CLASSES = len(label_names)
  169. base_model = self.get_base_model(model_name)
  170. x_data, y_label = load_img_from_dir(data_dir, target_size=(
  171. self.IMAGE_SIZE, self.IMAGE_SIZE), max_num=30)
  172. for i in range(x_data.shape[0]):
  173. x_data[i] = self.preprocess_input(x_data[i])
  174. print(x_data.shape)
  175. print(x_data[0].shape)
  176. x_data = x_data.reshape(
  177. x_data.shape[0], self.IMAGE_SIZE, self.IMAGE_SIZE, 3)
  178. y_label_one_hot = prepress_labels(y_label)
  179. # 验证应该使用从未见过的图片
  180. train_x, test_x, train_y, test_y = train_test_split(x_data, y_label_one_hot, random_state=0,
  181. test_size=0.3)
  182. # 自定义FC层以基本模型的输入为卷积层的最后一层
  183. x = base_model.output
  184. x = GlobalAveragePooling2D()(x)
  185. x = Dense(self.FC_NUMS, activation='relu')(x)
  186. prediction = Dense(self.NUM_CLASSES, activation='softmax')(x)
  187. # 构造完新的FC层,加入custom层
  188. model = Model(inputs=base_model.input, outputs=prediction)
  189. # 获取模型的层数
  190. print("layer nums:", len(model.layers))
  191. # 除了FC层,靠近FC层的一部分卷积层可参与参数训练,
  192. # 一般来说,模型结构已经标明一个卷积块包含的层数,
  193. # 在这里我们选择TRAIN_LAYERS为3,表示最后一个卷积块和FC层要参与参数训练
  194. for layer in model.layers:
  195. layer.trainable = False
  196. for layer in model.layers[-self.TRAIN_LAYERS:]:
  197. layer.trainable = True
  198. for layer in model.layers:
  199. print("layer.trainable:", layer.trainable)
  200. # 预编译模型
  201. model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
  202. loss='categorical_crossentropy', metrics=['accuracy'])
  203. model.summary()
  204. model.fit(train_x, train_y,
  205. validation_data=(test_x, test_y),
  206. # model.fit(x_data, y_label_one_hot,
  207. # validation_split=0.4,
  208. callbacks=[AccuracyLogger(callback)],
  209. epochs=epochs, batch_size=4,
  210. # steps_per_epoch=1,validation_steps =1 ,
  211. verbose=1, shuffle=True)
  212. self.model = model
  213. model.save(os.path.join(data_dir, 'model.h5'))
  214. self.label_names = label_names
  215. self.dump_label_name(label_names)
  216. # self.convert_tflite()
  217. # self.session = K.get_session()
  218. # self.graph = tf.get_default_graph()
  219. self.busy = False
  220. def dump_label_name(self, dirs):
  221. """
  222. 保存训练标签
  223. """
  224. with open(os.path.join(self.data_root, self.ns, "labelmap.txt"), "w") as f:
  225. # pickle.dump(dirs, f, protocol=2)
  226. to_write = []
  227. for item in dirs:
  228. to_write.append(item+'\n')
  229. f.writelines(to_write)
  230. def load_label_name(self):
  231. """
  232. 加载训练标签
  233. """
  234. import numpy as np
  235. with open(os.path.join(self.data_root, self.ns, "labelmap.txt"), "r") as f:
  236. # label_names = np.array(pickle.load(f))
  237. label_names = [line.strip() for line in f.readlines()]
  238. self.label_names = label_names
  239. return label_names
  240. def evaluate(self, data_dir):
  241. """
  242. 模型评估函数
  243. """
  244. x_data, y_label = load_img_from_dir(data_dir, target_size=(
  245. self.IMAGE_SIZE, self.IMAGE_SIZE), max_num=60)
  246. for i in range(x_data.shape[0]):
  247. x_data[i] = self.preprocess_input(x_data[i])
  248. x_data = x_data.reshape(
  249. x_data.shape[0], self.IMAGE_SIZE, self.IMAGE_SIZE, 3)
  250. y_label_one_hot = prepress_labels(y_label)
  251. # 验证应该使用从未见过的图片
  252. train_x, test_x, train_y, test_y = train_test_split(x_data, y_label_one_hot, random_state=0,
  253. test_size=0.5)
  254. # 开始评估模型效果 # verbose=0为不输出日志信息
  255. score = model.evaluate(test_x, test_y, verbose=1, steps=1)
  256. print('Test loss:', score[0])
  257. print('Test accuracy:', score[1]) # 准确度
  258. def predict(self, cv_img=None, path=None):
  259. """
  260. predict 函数, 执行预测
  261. Keyword arguments::
  262. cv_img: image 输入opencv图像
  263. path: str 图片路径
  264. """
  265. if path is not None:
  266. cv_img = cv2.imread(path)
  267. # label_names = get_labels(data_dir)
  268. # rs_img_f32 = cv2.resize(cv_img, (self.IMAGE_SIZE, self.IMAGE_SIZE)).astype('float32')
  269. if cv_img is None:
  270. print('no cv image provided')
  271. return None
  272. input_data = self.preprocess_input(cv_img)
  273. if self.model is not None:
  274. result = self.model.predict(input_data, steps=1)
  275. # print(label_names)
  276. print("result:", result)
  277. return self.label_names, result[0]
  278. else:
  279. print('your model is not ready')
  280. return None
  281. if __name__ == '__main__':
  282. """
  283. 训练测试
  284. """
  285. IC = ImageClassifier()
  286. IC.ns = '分类测试'
  287. epoch_total = 3
  288. def pub_training_logs(epoch, batch, logs):
  289. # logs {'loss': 0.33773628, 'accuracy': 0.71428573, 'batch': 6, 'size': 4}
  290. print(logs)
  291. msg = '第%d/%d轮, 批次: %d, 损失: %.2f, 准确率: %.2f' % (
  292. epoch+1, epoch_total, batch, logs['loss'], logs['accuracy'])
  293. print(msg)
  294. IC.train(os.path.join(IC.data_root, '分类测试'), epoch_total,
  295. callback=pub_training_logs, model_name='vgg16')