序列化语义

保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)只保存和加载模型参数:

  1. torch.save(the_model.state_dict(), PATH)

然后:

  1. the_model = TheModelClass(*args, **kwargs)
  2. the_mdel.load_state_dict(torch.load(PATH))

第二个方法是保存并加载整个模型:

  1. torch.save(the_model, PATH)

然后:

  1. the_model = torch.load(PATH)

然而,在这种情况下,序列化的数据会与特定的类结构和准确的目录结构相绑定,所以在其他项目中使用或经大量重构之后,这些结构可能会以各种方式被破坏。