模型可以在训练期间和训练完成后进行保存,本文介绍 Tensorflow 2.0 中如何保存和恢复模型。

在训练期间保存模型

在训练期间保存模型可以直接从中断的地方开始训练,模型以 checkpoints 形式保存。使用 tf.keras.callbacks.ModelCheckpoint 在训练的过程中和结束时回调保存的模型。

Checkpoint 回调用法

创建一个只在训练期间保存权重的 tf.keras.callbacks.ModelCheckpoint 回调:

  1. checkpoint_path = "training_1/cp.ckpt" # 设置保存路径
  2. checkpoint_dir = os.path.dirname(checkpoint_path)
  3. # 创建一个保存模型权重的回调
  4. cp_callback = tf.keras.callbacks.ModelCheckpoint(
  5. filepath=checkpoint_path,
  6. save_weights_only=True,
  7. verbose=1
  8. )
  9. # 使用回调训练模型
  10. model.fit(
  11. train_images,
  12. train_labels,
  13. epochs=10,
  14. validation_data=(test_images,test_labels),
  15. callbacks=[cp_callback]
  16. )

Checkpoint 回调选项

回调提供了几个选项,可为 checkpoint 指定唯一名称并调整 checkpoint 的保存频率。例如,下面的代码训练了一个新模型,训练过程中每五个 epochs 保存一次模型的权重。

  1. # 在文件名中包含 epoch (使用 `str.format`)
  2. checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
  3. checkpoint_dir = os.path.dirname(checkpoint_path)
  4. # 创建一个回调
  5. cp_callback = tf.keras.callbacks.ModelCheckpoint(
  6. filepath=checkpoint_path,
  7. verbose=1,
  8. save_weights_only=True,
  9. period=5 # 每 5 个 epochs 保存一次模型的权重
  10. )
  11. # 创建一个新的模型实例
  12. model = create_model()
  13. # 使用 `checkpoint_path` 格式保存权重
  14. model.save_weights(checkpoint_path.format(epoch=0))
  15. # 使用回调训练模型
  16. model.fit(
  17. train_images,
  18. train_labels,
  19. epochs=50,
  20. callbacks=[cp_callback],
  21. validation_data=(test_images,test_labels),
  22. verbose=0
  23. )

获取最新的 checkpoint,并加载保存的权重:

  1. # 获取最新的 checkpoint
  2. latest = tf.train.latest_checkpoint(checkpoint_dir)
  3. # 创建一个新的模型实例
  4. model = create_model()
  5. # 加载以前保存的权重
  6. model.load_weights(latest)

保存整个网络模型

保存整个网络模型即保存了网络结构、优化器、权重等模型细节,可以在无需原始代码的情况下恢复整个网络。

将模型保存为 HDF5 文件

HDF5 格式保存整个网络模型,需要安装相关依赖。

  1. pip install -q pyyaml h5py # 需要以 HDF5 格式保存模型

下面为网络模型的保存与恢复:

  1. model.save("the_save_model.h5")
  2. model = tf.keras.models.load_model("the_save_model.h5")

保存网络全模型为 SavedModel 文件

注意:这种保存模型的方法是实验性的,在将来的版本中可能有所改变。

  1. tf.keras.experimental.export_saved_model(model, 'saved_model')
  2. model = keras.experimental.load_from_saved_model('saved_model')

仅保存网络结构

  1. config = model.get_config()
  2. model = tf.keras.Model.from_config(config)

仅保存网络参数

  1. weights = model.get_weights()
  2. model.set_weights(weights)
  3. # 可以把结构和参数保存结合起来
  4. config = model.get_config()
  5. weights = model.get_weights()
  6. model = tf.keras.Model.from_config(config)
  7. model.set_weights(weights)

仅保存网络权重

  1. model.save_weights('weight_tf_savedmodel')
  2. model.save_weights('weight_tf_savedmodel_h5', save_format='h5')