一 模型保存和恢复
1.1 保存为pb文件和从pb文件恢复
# Restore the TextCNN weights from the latest checkpoint in `ckpt_dir`, fold the
# variables into constants, and write the frozen graph to ./pb/model.pb.
# `tf`, `graph_util`, `TextCNN`, `w2v_model` and `ckpt_dir` come from elsewhere.
ckpt = tf.train.latest_checkpoint(ckpt_dir)
# latest_checkpoint returns None when no checkpoint exists; fail with a clear
# message instead of an obscure error from saver.restore.
if ckpt is None:
    raise ValueError('No checkpoint found in %r' % (ckpt_dir,))

# Rebuild the model graph; the variables it creates must match those stored in
# the checkpoint (presumably the same hyper-parameters used in training — confirm).
cnn = TextCNN(
    w2v_model,
    sequence_length=30,
    num_classes=2,
    vocab_size=352218,
    embedding_size=300,
    filter_sizes=[2, 3, 4],
    num_filters=128,
    l2_reg_lambda=0.0,
)

issue = tf.InteractiveSession()  # session handle (name kept as in the original notes)
issue.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(issue, ckpt)

# Replace variables with constant ops holding their current values, keeping only
# the subgraph required to compute the two named output ops.
constant_graph = graph_util.convert_variables_to_constants(
    issue, issue.graph_def, ['output/predictions', 'output/scores'])
# Drop nodes that are only useful during training (e.g. Identity, CheckNumerics).
constant_graph = graph_util.remove_training_nodes(constant_graph)

# Make sure the output directory exists; MakeDirs succeeds if it already does.
tf.gfile.MakeDirs('./pb')
with tf.gfile.GFile('./pb/model.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# Load a frozen graph from `pb_file` into the default graph and look up the
# tensors used for inference. `tf` and `pb_file` come from elsewhere.
graph_def = tf.GraphDef()
# tf.gfile.FastGFile is deprecated; tf.gfile.GFile is the supported replacement.
with tf.gfile.GFile(pb_file, 'rb') as model_f:
    graph_def.ParseFromString(model_f.read())
# name='' imports the nodes without the default 'import/' prefix, so the
# tensor names below match the names in the original graph.
_ = tf.import_graph_def(graph_def, name='')

default_graph = tf.get_default_graph()
input_x = default_graph.get_tensor_by_name('input_x:0')
dropout_prob = default_graph.get_tensor_by_name('dropout_keep_prob:0')
predictions = default_graph.get_tensor_by_name('output/predictions:0')
scores = default_graph.get_tensor_by_name('output/scores:0')

# Interactive session on the default graph, which now contains the imported model.
isess = tf.InteractiveSession()
1.2 ckpt的保存和恢复