一 模型保存和恢复

1.1 保存为pb文件和从pb文件恢复

下面第一段代码从最新的 ckpt 恢复变量，将计算图冻结（变量转为常量）后保存为 pb 文件；第二段代码从 pb 文件加载计算图，并取出输入和输出张量用于预测。

  # Export the latest TextCNN checkpoint in `ckpt_dir` as a frozen GraphDef
  # written to ./pb/model.pb.
  # Assumes TF 1.x, with `tf`, `graph_util` (presumably tf.graph_util /
  # tensorflow.python.framework.graph_util), `TextCNN`, `w2v_model` and
  # `ckpt_dir` already in scope.
  ckpt = tf.train.latest_checkpoint(ckpt_dir)
  if ckpt is None:
      # latest_checkpoint returns None when ckpt_dir holds no checkpoint; fail
      # here with a clear message instead of deep inside saver.restore().
      raise ValueError('No checkpoint found in {!r}'.format(ckpt_dir))

  # Rebuild the model with the same hyper-parameters used for training so
  # that the variable names match those stored in the checkpoint.
  cnn = TextCNN(
      w2v_model,
      sequence_length=30,
      num_classes=2,
      vocab_size=352218,
      embedding_size=300,
      filter_sizes=[2, 3, 4],
      num_filters=128,
      l2_reg_lambda=0.0)

  # Ops kept in the frozen graph; nodes they do not depend on are pruned.
  output_node_names = ['output/predictions', 'output/scores']

  saver = tf.train.Saver()
  # A context-managed session is closed once the graph has been frozen.
  with tf.Session() as sess:
      # restore() assigns every variable managed by the Saver, so running
      # global_variables_initializer() first would only be wasted work.
      saver.restore(sess, ckpt)
      # Replace variables with constants holding their restored values.
      constant_graph = graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_node_names)

  # Drop nodes only needed for training (e.g. Identity / CheckNumerics).
  constant_graph = graph_util.remove_training_nodes(constant_graph)

  # Create the output directory if needed (no-op when it already exists).
  tf.gfile.MakeDirs('./pb')
  with tf.gfile.GFile('./pb/model.pb', mode='wb') as f:
      f.write(constant_graph.SerializeToString())
  # Load a frozen GraphDef from `pb_file` (path defined by the caller) and
  # fetch the model's input and output tensors. Assumes TF 1.x.
  graph_def = tf.GraphDef()
  # GFile replaces the deprecated FastGFile; read in binary mode because the
  # file is a serialized protobuf.
  with tf.gfile.GFile(pb_file, 'rb') as model_f:
      graph_def.ParseFromString(model_f.read())

  # Import into the default graph with no name prefix (name='') so tensor
  # names are unchanged; return_elements hands back the imported tensors
  # directly instead of looking each one up afterwards.
  input_x, dropout_prob, predictions, scores = tf.import_graph_def(
      graph_def,
      name='',
      return_elements=['input_x:0',
                       'dropout_keep_prob:0',
                       'output/predictions:0',
                       'output/scores:0'])

  # Session over the default graph, left open so callers can run
  # `predictions` / `scores`, feeding `input_x` and `dropout_prob`
  # (presumably 1.0 at inference -- confirm against TextCNN).
  isess = tf.InteractiveSession()

1.2 ckpt的保存和恢复