模型可以在训练期间和训练完成后进行保存,本文介绍 Tensorflow 2.0 中如何保存和恢复模型。
在训练期间保存模型
在训练期间保存模型可以直接从中断的地方开始训练,模型以 checkpoints 形式保存。使用 tf.keras.callbacks.ModelCheckpoint 在训练的过程中和结束时回调保存的模型。
Checkpoint 回调用法
创建一个只在训练期间保存权重的 tf.keras.callbacks.ModelCheckpoint 回调:
checkpoint_path = "training_1/cp.ckpt" # 设置保存路径checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个保存模型权重的回调cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1)# 使用回调训练模型model.fit(train_images,train_labels,epochs=10,validation_data=(test_images,test_labels),callbacks=[cp_callback])
Checkpoint 回调选项
回调提供了几个选项,可为 checkpoint 指定唯一名称并调整 checkpoint 的保存频率。例如,下面的代码训练了一个新模型,训练过程中每五个 epochs 保存一次模型的权重。
# 在文件名中包含 epoch (使用 `str.format`)checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个回调cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,verbose=1,save_weights_only=True,period=5 # 每 5 个 epochs 保存一次模型的权重)# 创建一个新的模型实例model = create_model()# 使用 `checkpoint_path` 格式保存权重model.save_weights(checkpoint_path.format(epoch=0))# 使用回调训练模型model.fit(train_images,train_labels,epochs=50,callbacks=[cp_callback],validation_data=(test_images,test_labels),verbose=0)
获取最新的 checkpoint,并加载保存的权重:
# 获取最新的 checkpointlatest = tf.train.latest_checkpoint(checkpoint_dir)# 创建一个新的模型实例model = create_model()# 加载以前保存的权重model.load_weights(latest)
保存整个网络模型
保存整个网络模型即保存了网络结构、优化器、权重等模型细节,可以在无需原始代码的情况下恢复整个网络。
将模型保存为 HDF5 文件
以 HDF5 格式保存整个网络模型,需要安装相关依赖。
pip install -q pyyaml h5py # 需要以 HDF5 格式保存模型
下面为网络模型的保存与恢复:
model.save("the_save_model.h5")model = tf.keras.models.load_model("the_save_model.h5")
保存网络全模型为 SavedModel 文件
注意:这种保存模型的方法是实验性的,在将来的版本中可能有所改变。
tf.keras.experimental.export_saved_model(model, 'saved_model')model = keras.experimental.load_from_saved_model('saved_model')
仅保存网络结构
config = model.get_config()model = tf.keras.Model.from_config(config)
仅保存网络参数
weights = model.get_weights()model.set_weights(weights)# 可以把结构和参数保存结合起来config = model.get_config()weights = model.get_weights()model = tf.keras.Model.from_config(config)model.set_weights(weights)
仅保存网络权重
model.save_weights('weight_tf_savedmodel')model.save_weights('weight_tf_savedmodel_h5', save_format='h5')
