Pytorch保存模型

尽量参考Pytorch官方给出的save和load model的方法https://pytorch.org/tutorials/beginner/saving_loading_models.html

但是如果你希望使用的是这种保存方法
Save:

  1. torch.save(model, PATH)

Load:

  1. # Model class must be defined somewhere
  2. model = torch.load(PATH)
  3. model.eval()

你必须要注意到,这个所谓的保存整个模型的save方式,它是:

  1. 你在这个load的代码中也需要在load之前按照你预测用的模型同样的方式定义好同样的model类,它并不是保存一个完整的独立的模型类型能够给你load之后直接使用
  2. 对于pytorch来说,torch.save所保存的,真的仅仅是模型本身,是那些有weight bias等等的方法和它们的参数值,并不包括中间的矩阵操作过程,对于模型中间的对于数据矩阵你自定义的操作过程,pytorch是调用你代码中之前定义的model类型中所做的操作。也就是说你的pred代码中的model必须和你save model时的代码完全一致,不然会出现各种奇奇怪怪的问题。

总结

因此,综上,终于理解为什么官方更推荐save state_dict的方式,因为save entire model这个操作根本就不时真的entire model,它能够save的应该只是那些你print(model)时能打印出来的部分。