模型可以在训练期间和训练完成后进行保存,本文介绍 Tensorflow 2.0 中如何保存和恢复模型。
在训练期间保存模型
在训练期间保存模型可以直接从中断的地方开始训练,模型以 checkpoints
形式保存。使用 tf.keras.callbacks.ModelCheckpoint
在训练的过程中和结束时回调保存的模型。
Checkpoint 回调用法
创建一个只在训练期间保存权重的 tf.keras.callbacks.ModelCheckpoint
回调:
checkpoint_path = "training_1/cp.ckpt" # 设置保存路径
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个保存模型权重的回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
verbose=1
)
# 使用回调训练模型
model.fit(
train_images,
train_labels,
epochs=10,
validation_data=(test_images,test_labels),
callbacks=[cp_callback]
)
Checkpoint 回调选项
回调提供了几个选项,可为 checkpoint 指定唯一名称并调整 checkpoint 的保存频率。例如,下面的代码训练了一个新模型,训练过程中每五个 epochs 保存一次模型的权重。
# 在文件名中包含 epoch (使用 `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
period=5 # 每 5 个 epochs 保存一次模型的权重
)
# 创建一个新的模型实例
model = create_model()
# 使用 `checkpoint_path` 格式保存权重
model.save_weights(checkpoint_path.format(epoch=0))
# 使用回调训练模型
model.fit(
train_images,
train_labels,
epochs=50,
callbacks=[cp_callback],
validation_data=(test_images,test_labels),
verbose=0
)
获取最新的 checkpoint,并加载保存的权重:
# 获取最新的 checkpoint
latest = tf.train.latest_checkpoint(checkpoint_dir)
# 创建一个新的模型实例
model = create_model()
# 加载以前保存的权重
model.load_weights(latest)
保存整个网络模型
保存整个网络模型即保存了网络结构、优化器、权重等模型细节,可以在无需原始代码的情况下恢复整个网络。
将模型保存为 HDF5
文件
以 HDF5
格式保存整个网络模型,需要安装相关依赖。
pip install -q pyyaml h5py # 需要以 HDF5 格式保存模型
下面为网络模型的保存与恢复:
model.save("the_save_model.h5")
model = tf.keras.models.load_model("the_save_model.h5")
保存网络全模型为 SavedModel 文件
注意:这种保存模型的方法是实验性的,在将来的版本中可能有所改变。
tf.keras.experimental.export_saved_model(model, 'saved_model')
model = keras.experimental.load_from_saved_model('saved_model')
仅保存网络结构
config = model.get_config()
model = tf.keras.Model.from_config(config)
仅保存网络参数
weights = model.get_weights()
model.set_weights(weights)
# 可以把结构和参数保存结合起来
config = model.get_config()
weights = model.get_weights()
model = tf.keras.Model.from_config(config)
model.set_weights(weights)
仅保存网络权重
model.save_weights('weight_tf_savedmodel')
model.save_weights('weight_tf_savedmodel_h5', save_format='h5')