copy form https://zhuanlan.zhihu.com/p/333791572

模型的内容

GraphDef

GraphDef是Tensorflow中序列化的图结构。在tensorflow中,计算图被保存为Protobuf格式(pb)。pb可以只保存图的结构,也可以保存结构加权重。

SignatureDef

定义图结构输入输出的节点名称和属性,一般存储于.index文件中。
查看方法:
list(meta_graph.signature_def.items())

保存方式

tf.saved_model

将动态图保存成权重(./variables)、计算图(keras_metadata.pb)、权重和计算图(saved_model.pb)三种文件。

  1. # 保存
  2. model = tf.saved_model.save(
  3. obj, export_dir, signatures=None, options=None
  4. )
  5. # 读取
  6. model= tf.saved_model.load(
  7. export_dir, tags=None, options=None
  8. )
  9. # 推理
  10. infer = model.signatures["serving_default"]

freeze_graph

该函数将图和权重以常量的形式保存在一张静态图中(pb)。
其中的核心代码是:

  1. from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos
  2. output_graph_def = convert_variables_to_constants(session, input_graph_def, output_names)
  3. output_graph = 'pb_model/model.pb' # 保存地址
  4. with tf.gfile.GFile(output_graph, 'wb') as f:
  5. f.write(output_graph_def.SerializeToString())

tf.train.Saver()

  1. # 保存断点
  2. saver = tf.train.Saver()
  3. saver.save()
  4. # 加载断点
  5. saver.restore()

tensorflow通过checkpoint这一种格式文件,是将模型的结构和权重数据分开保存的
image.png
在save之后,模型会保存在ckpt文件中,checkpoint文件保存了一个目录下所有的模型文件列表,events文件是给可视化工具tensorboard用的。
和保存的模型直接相关的是以下这三个文件:

  • .data文件保存了当前参数值
  • .index文件保存了当前参数名
  • .meta文件保存了当前图结构
  • .events文件是给可视化工具tensorboard使用。
  • .pbtxt文件是以字符串存储的计算图

当你使用saver.restore()载入模型时,你用的就是这一组的三个checkpoint文件

其他说明

tf.train.CheckpointManager()

CheckpointManager是一个管理断点的工具,是Saver更高级的API,类似于tensorflow.keras.callbacks中的Checkpoint类。CheckpointManager可以设置自动存点间隔步数、最大断点数、自动存点间隔时间等参数
其中,最新的断点文件名以字符串形式储存在checkpoint文件中。
可参考:
tf.train.CheckpointManager | TensorFlow Core v2.3.1tensorflow.google.cn

  1. # 设置断点
  2. checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  3. manager = tf.train.CheckpointManager(
  4. checkpoint, directory="/tmp/model", max_to_keep=5)
  5. # 加载最新的断点
  6. status = checkpoint.restore(manager.latest_checkpoint)
  7. # 保存断点
  8. while True:
  9. # train
  10. manager.save()

查看静态图输入输出节点

可以使用Tensorflow自带工具saved_model_cli,输入的模型需要使用tf.saved_model.save或者tf.keras.models.Model实例的save属性保存的模型结构。其中需要的文件有.data(模型权重)、.index(模型的SignatureDef)和.pb(MetaGraph)。
saved_model_cli show --dir model/ --all
也可以加载静态图后,打印所有节点,逐个查看:

  1. tensor_name_list = [tensor.name for tensor in tf.compat.v1.get_default_graph().as_graph_def().node]
  2. for tensor_name in tensor_name_list:
  3. print(tensor_name,'\n')

可以将静态图保存为summary,使用TensorBoard可视化查看:
summaryWriter = tf.compat.v1.summary.FileWriter('log/', grap

自定义SignatureDef

上文说到SignatureDef是输入输出到静态图节点的映射,一般表示为字典的形式,下面是官方给的分类模型书写范例

  1. signature_def: {
  2. key : "my_classification_signature"
  3. value: {
  4. inputs: {
  5. key : "inputs"
  6. value: {
  7. name: "tf_example:0"
  8. dtype: DT_STRING
  9. tensor_shape: ...
  10. }
  11. }
  12. outputs: {
  13. key : "classes"
  14. value: {
  15. name: "index_to_string:0"
  16. dtype: DT_STRING
  17. tensor_shape: ...
  18. }
  19. }
  20. outputs: {
  21. key : "scores"
  22. value: {
  23. name: "TopKV2:0"
  24. dtype: DT_FLOAT
  25. tensor_shape: ...
  26. }
  27. }
  28. method_name: "tensorflow/serving/classify"
  29. }
  30. }

修改静态图的SignatureDef:

  1. #保存为pb模型
  2. def export_model(session, m):
  3. #只需要修改这一段,定义输入输出,其他保持默认即可
  4. model_signature = signature_def_utils.build_signature_def(
  5. inputs={"input": utils.build_tensor_info(m.a)},
  6. outputs={
  7. "output": utils.build_tensor_info(m.y)},
  8. method_name=signature_constants.PREDICT_METHOD_NAME)
  9. export_path = "pb_model/1"
  10. if os.path.exists(export_path):
  11. os.system("rm -rf "+ export_path)
  12. print("Export the model to {}".format(export_path))
  13. try:
  14. legacy_init_op = tf.group(
  15. tf.tables_initializer(), name='legacy_init_op')
  16. builder = saved_model_builder.SavedModelBuilder(export_path)
  17. builder.add_meta_graph_and_variables(
  18. session, [tag_constants.SERVING],
  19. clear_devices=True,
  20. signature_def_map={
  21. signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
  22. model_signature,
  23. },
  24. legacy_init_op=legacy_init_op)
  25. builder.save()
  26. except Exception as e:
  27. print("Fail to export saved model, exception: {}".format(e))

关于SignatureDef的编写可参考:
SignatureDefs in SavedModel for TensorFlow Serving | TFXtensorflow.google.cn

在tf中打包多个模型和函数(Synchronized)

可以用tf.function直接对一些函数和模型的操作进行封装,其中的计算会转换为tf中的图计算。需要注意的是,用这种方法进行封装,执行的时候是Synchronized的。
如果需要实现Async,请使用OpenVINO、TensorRT、OpenGL或CUDA等进行部署。

  1. @tf.function
  2. def full_model(image):
  3. x1 = func_1(image)
  4. x2 = func_2(image)
  5. return [x1,x2]
  6. full_model = full_model.get_concrete_function(tf.TensorSpec((832, 1344,3), tf.float32))
  7. frozen_func = convert_variables_to_constants_v2(full_model)
  8. frozen_func.graph.as_graph_def()
  9. layers = [op.name for op in frozen_func.graph.get_operations()]
  10. print("-" * 50)
  11. print("Frozen model layers: ")
  12. for layer in layers:
  13. print(layer)
  14. print("-" * 50)
  15. print("Frozen model inputs: ")
  16. print(frozen_func.inputs)
  17. print("Frozen model outputs: ")
  18. print(frozen_func.outputs)
  19. # Save frozen graph from frozen ConcreteFunction to hard drive
  20. tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
  21. logdir="./model",
  22. name="model.pb",
  23. as_text=False)