checkpoint 转为.pb

saver.saver() 保存得到 checkpoint 文件

在TensorFlow中模型的保存和调用,相信大家都不会陌生,使用关键语句saver = tf.train.Saver()和saver.save()就可以完成。
但是,不知道大家是否了解,tensorflow通过checkpoint这一种格式文件,是将模型的结构和权重数据分开保存的,这就造成了一些使用场景下的不方便。
所以,我们需要一种方式将模型结构和权重数据合并在一个文件中,tensorflow提供了freeze_graph函数和pb文件格式,来解决这一问题。
这些模型文件是做什么的
tensorflow模型checkpoint转为.pb - 图1
在save之后,模型会保存在ckpt文件中,checkpoint文件保存了一个目录下所有的模型文件列表,events文件是给可视化工具tensorboard用的。
和保存的模型直接相关的是以下这三个文件:

  • .data文件保存了当前参数值
  • .index文件保存了当前参数名
  • .meta文件保存了当前图结构

当你使用saver.restore()载入模型时,你用的就是这一组的三个checkpoint文件。
但是,当我们需要将模型和权重整合成一个文件时,我们就需要以下的操作了。

freeze_graph生成PB文件

tensorflow提供了freeze_graph这个函数来生成pb文件。以下的代码块可以完成将checkpoint文件转换成pb文件的操作:

  1. 载入你的模型结构,
  2. 提供checkpoint文件地址
  3. 使用tf.train.writegraph保存图,这个图会提供给freeze_graph使用
  4. 使用freeze_graph生成pb文件 ```python import tensorflow as tf from tensorflow.python.tools import freeze_graph

network是你自己定义的模型

import network

模型的checkpoint文件地址

ckpt_path = “./ckpt_model/model-20190403-164504.ckpt-205000”

def main(): tf.reset_default_graph()

  1. x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='input')
  2. # flow是模型的输出
  3. flow = network(x)
  4. # 设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
  5. flow = tf.cast(flow, tf.int8, 'out')
  6. with tf.Session() as sess:
  7. # 保存图,在./pb_model文件夹中生成model.pb文件
  8. # model.pb文件将作为input_graph给到接下来的freeze_graph函数
  9. tf.train.write_graph(sess.graph_def, './pb_model', 'model.pb')
  10. #把图和参数结构一起
  11. freeze_graph.freeze_graph(
  12. input_graph='./pb_model/model.pb',
  13. input_saver='',
  14. input_binary=False,
  15. input_checkpoint=ckpt_path,
  16. output_node_names='out',
  17. restore_op_name='save/restore_all',
  18. filename_tensor_name='save/Const:0',
  19. output_graph='./pb_model/frozen_model.pb',
  20. clear_devices=False,
  21. initializer_nodes=''
  22. )
  23. print("done")

if name == ‘main‘: main() ``` 在以上的程序运行之后,./pb_model/文件夹中就会出现frozen_model.pb文件,这是我们可以使用的模型结构和权重整合过的pb文件。
freeze_graph总共有11个参数,以下逐一介绍下,供大家参考:

  1. input_graph:模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分。我们的例子中,使用了二进制的pb文件,对应input_binary就是False
  2. input_saver:Saver解析器,主要用于版本不兼容时使用。通常为空,为空时用当前版本的Saver
  3. input_binary:配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认值是False
  4. input_checkpoint:checkpoint文件地址
  5. output_node_names:输出节点的名字,有多个时用逗号分开,我们的输出节点是’out’,这是我们使用flow = tf.cast(flow, tf.int8, ‘out’)将模型的输出节点命名为out。如果没有这一步的操作,我们可以找到模型的输出节点名是什么,并且在这一参数中对应。
  6. restore_op_name:从模型恢复节点的名字,一般使用默认:save/restore_all
  7. filename_tensor_name:一般使用默认:save/Const:0
  8. output_graph:用来保存整合后的模型输出文件,即pb文件的保存地址
  9. clear_devices:指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认),默认True
  10. initializer_nodes:默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
  11. variable_names_blacklist:默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。