TensorFlow模型文件

  • Checkpoints
  • Frozen Graph
  • SavedModel
  • HDF5

模型导出包含网络结构参数,可能是分别导出或者整合为一个独立的文件:

  • 参数和网络结构分开保存:Checkpoints、SavedModel
  • 只保存权重:HDF5
  • 参数和权重保存在一个文件:Frozen Graph、HDF5

Checkpoint

1、组成
  1. """
  2. model.ckpt-13000表示前缀,代表第13000 steps时保存的结果,加载指定checkpoint时,仅说明前缀即可。
  3. checkpoint示意:
  4. model_checkpoint_path: "model.ckpt-16329"
  5. all_model_checkpoint_paths: "model.ckpt-13000"
  6. all_model_checkpoint_paths: "model.ckpt-14000"
  7. all_model_checkpoint_paths: "model.ckpt-15000"
  8. all_model_checkpoint_paths: "model.ckpt-16000"
  9. all_model_checkpoint_paths: "model.ckpt-16329"
  10. """
  11. # 1. 参数
  12. checkpoint # 表示该目录下保存的所有的checkpoint列表
  13. model.ckpt-13000.index # 表示参数名
  14. model.ckpt-13000.data-00000-of-00001 # 表示参数值
  15. # 2. 网络结构
  16. model.ckpt-13000.meta # 表示网络结构

2、保存
  1. ## TF-1.X
  2. # 1. 默认保存所有变量
  3. saver = tf.train.Saver()
  4. # 2. 指定需要保存的向量
  5. w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
  6. w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
  7. saver = tf.train.Saver([w1,w2])
  8. # 3. 保存最新的4个模型,每2h保存一次
  9. saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
  10. sess = tf.Session()
  11. sess.run(tf.global_variables_initializer())
  12. # 1. 默认保存
  13. # saver.save(sess=sess, save_path='ckpt')
  14. saver.save(sess, 'ckpt')
  15. # 2. 迭代1000次后保存模型
  16. saver.save(sess, 'ckpt', global_step=1000)
  17. # 3. 迭代1000次后保存模型,网络结构不需要重复保存
  18. saver.save(sess, 'ckpt', global_step=1000, write_meta_graph=False)
  19. # TF-2.X
  20. # TF-2.x keras

3、导入
  1. # 1. 导入网络结构
  2. saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  3. # 2. 加载变量
  4. with tf.Session() as sess:
  5. new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  6. ## TODO:
  7. new_saver.restore(sess, tf.train.latest_checkpoint('./')) # 获取最新的ckpt