checkpoint 转为.pb
saver.saver() 保存得到 checkpoint 文件
在TensorFlow中模型的保存和调用,相信大家都不会陌生,使用关键语句saver = tf.train.Saver()和saver.save()就可以完成。
但是,不知道大家是否了解,tensorflow通过checkpoint这一种格式文件,是将模型的结构和权重数据分开保存的,这就造成了一些使用场景下的不方便。
所以,我们需要一种方式将模型结构和权重数据合并在一个文件中,tensorflow提供了freeze_graph函数和pb文件格式,来解决这一问题。
这些模型文件是做什么的
在save之后,模型会保存在ckpt文件中,checkpoint文件保存了一个目录下所有的模型文件列表,events文件是给可视化工具tensorboard用的。
和保存的模型直接相关的是以下这三个文件:
- .data文件保存了当前参数值
- .index文件保存了当前参数名
- .meta文件保存了当前图结构
当你使用saver.restore()载入模型时,你用的就是这一组的三个checkpoint文件。
但是,当我们需要将模型和权重整合成一个文件时,我们就需要以下的操作了。
freeze_graph生成PB文件
tensorflow提供了freeze_graph这个函数来生成pb文件。以下的代码块可以完成将checkpoint文件转换成pb文件的操作:
- 载入你的模型结构,
- 提供checkpoint文件地址
- 使用tf.train.writegraph保存图,这个图会提供给freeze_graph使用
- 使用freeze_graph生成pb文件 ```python import tensorflow as tf from tensorflow.python.tools import freeze_graph
network是你自己定义的模型
import network
模型的checkpoint文件地址
ckpt_path = “./ckpt_model/model-20190403-164504.ckpt-205000”
def main(): tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='input')
# flow是模型的输出
flow = network(x)
# 设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
flow = tf.cast(flow, tf.int8, 'out')
with tf.Session() as sess:
# 保存图,在./pb_model文件夹中生成model.pb文件
# model.pb文件将作为input_graph给到接下来的freeze_graph函数
tf.train.write_graph(sess.graph_def, './pb_model', 'model.pb')
#把图和参数结构一起
freeze_graph.freeze_graph(
input_graph='./pb_model/model.pb',
input_saver='',
input_binary=False,
input_checkpoint=ckpt_path,
output_node_names='out',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
output_graph='./pb_model/frozen_model.pb',
clear_devices=False,
initializer_nodes=''
)
print("done")
if name == ‘main‘:
main()
```
在以上的程序运行之后,./pb_model/文件夹中就会出现frozen_model.pb文件,这是我们可以使用的模型结构和权重整合过的pb文件。
freeze_graph总共有11个参数,以下逐一介绍下,供大家参考:
- input_graph:模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分。我们的例子中,使用了二进制的pb文件,对应input_binary就是False
- input_saver:Saver解析器,主要用于版本不兼容时使用。通常为空,为空时用当前版本的Saver
- input_binary:配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认值是False
- input_checkpoint:checkpoint文件地址
- output_node_names:输出节点的名字,有多个时用逗号分开,我们的输出节点是’out’,这是我们使用flow = tf.cast(flow, tf.int8, ‘out’)将模型的输出节点命名为out。如果没有这一步的操作,我们可以找到模型的输出节点名是什么,并且在这一参数中对应。
- restore_op_name:从模型恢复节点的名字,一般使用默认:save/restore_all
- filename_tensor_name:一般使用默认:save/Const:0
- output_graph:用来保存整合后的模型输出文件,即pb文件的保存地址
- clear_devices:指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认),默认True
- initializer_nodes:默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
- variable_names_blacklist:默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。