TensorFlow 1.4 checkpoint 转 SavedModel

PHVM 开源项目的模型转换（checkpoint → SavedModel）

  # Rebuild the PHVM model with the vocabulary sizes taken from the dataset,
  # restore the best checkpoint into model.sess, then export the session's
  # graph and variables as a TF 1.x SavedModel.
  config = Config.config
  dataset = Dataset.EPWDataset()
  model = PHVM.PHVM(len(dataset.vocab.id2featCate), len(dataset.vocab.id2featVal), len(dataset.vocab.id2word),
                    len(dataset.vocab.id2category),
                    key_wordvec=None, val_wordvec=None, tgt_wordvec=dataset.vocab.id2vec,
                    type_vocab_size=len(dataset.vocab.id2type))
  # `args` is presumably the argparse namespace of the surrounding script — confirm.
  # best_model_dir is appended without a separator, so it presumably carries its
  # own leading "/" (verify in Config); plain concatenation is kept for that reason.
  best_checkpoint_dir = config.checkpoint_dir + "/" + args.model_name + config.best_model_dir
  # restore_model(model, best_checkpoint_dir) takes exactly these two arguments.
  model_utils.restore_model(model, best_checkpoint_dir)
  # Conversion details: write the graph and restored variables of model.sess.
  export_dir = 'export_dir'
  # NOTE(review): TF 1.x SavedModelBuilder refuses an export_dir that already
  # exists; delete it or use a fresh (e.g. versioned) directory before re-running.
  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  # TODO(review): no signature_def_map is passed, so the export has no serving
  # signatures; TF Serving needs one built from PHVM's input/output tensors.
  builder.add_meta_graph_and_variables(model.sess, [tf.saved_model.tag_constants.SERVING])
  builder.save()
  # model_utils.py
  def restore_model(model, best_checkpoint_dir):
      """Restore the newest checkpoint found in ``best_checkpoint_dir`` into ``model.sess``.

      Args:
          model: PHVM model exposing ``best_saver`` (presumably a
              ``tf.train.Saver`` — confirm in PHVM) and ``sess``, the session
              the variables are restored into.
          best_checkpoint_dir: directory holding the best-model checkpoints.

      Raises:
          ValueError: if ``best_checkpoint_dir`` contains no checkpoint.
      """
      saver = model.best_saver  # best_saver is defined on the PHVM model
      latest_ckpt = tf.train.latest_checkpoint(best_checkpoint_dir)
      # latest_checkpoint returns None when no checkpoint state is found; fail
      # with the directory name instead of Saver.restore's generic ValueError.
      if latest_ckpt is None:
          raise ValueError("No checkpoint found in %s" % best_checkpoint_dir)
      saver.restore(model.sess, latest_ckpt)