TensorFlow 1.4：checkpoint 转 SavedModel
PHVM 开源项目的模型转换
# Restore the best PHVM checkpoint into the model's session, then export that
# session as a TensorFlow SavedModel. `Config`, `Dataset`, `PHVM`,
# `model_utils`, `args` and `tf` are expected to be imported/defined elsewhere.
config = Config.config
dataset = Dataset.EPWDataset()
model = PHVM.PHVM(
    len(dataset.vocab.id2featCate),
    len(dataset.vocab.id2featVal),
    len(dataset.vocab.id2word),
    len(dataset.vocab.id2category),
    key_wordvec=None,
    val_wordvec=None,
    tgt_wordvec=dataset.vocab.id2vec,
    type_vocab_size=len(dataset.vocab.id2type),
)

# Directory holding the best checkpoint. No separator is inserted before
# `config.best_model_dir`, so it presumably already starts with "/" — verify
# against Config before changing this to os.path.join.
best_checkpoint_dir = config.checkpoint_dir + "/" + args.model_name + config.best_model_dir
# restore_model (model_utils.py) takes exactly (model, best_checkpoint_dir);
# passing a third argument raises TypeError.
model_utils.restore_model(model, best_checkpoint_dir)

# Conversion details: export the restored session as a SavedModel.
export_dir = 'export_dir'
# NOTE: in TF 1.x SavedModelBuilder raises if `export_dir` already exists
# (non-empty, in later 1.x releases), so clear old exports before re-running.
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
# NOTE(review): no signature_def_map is passed, so the exported MetaGraph has
# no serving signatures, and TensorFlow Serving cannot call it as-is.
# TODO: build one (e.g. with
# tf.saved_model.signature_def_utils.predict_signature_def) once the model's
# input/output tensors are identified.
builder.add_meta_graph_and_variables(model.sess, [tf.saved_model.tag_constants.SERVING])
builder.save()
# model_utils.py
def restore_model(model, best_checkpoint_dir):
    """Restore ``model.sess`` from the newest checkpoint in ``best_checkpoint_dir``.

    Args:
        model: Object exposing ``best_saver`` (presumably a ``tf.train.Saver``
            created by the PHVM model) and ``sess`` (the session whose
            variables are restored).
        best_checkpoint_dir: Directory searched with
            ``tf.train.latest_checkpoint`` for the most recent checkpoint.

    Raises:
        ValueError: If no checkpoint is found in ``best_checkpoint_dir``.
    """
    # best_saver is defined in the PHVM model.
    saver = model.best_saver
    latest_ckpt = tf.train.latest_checkpoint(best_checkpoint_dir)
    # latest_checkpoint returns None when the directory holds no checkpoint
    # state; fail here with the directory named instead of handing None to
    # saver.restore.
    if latest_ckpt is None:
        raise ValueError("No checkpoint found in %r" % (best_checkpoint_dir,))
    saver.restore(model.sess, latest_ckpt)