方式一:保存模型结构 + 模型参数

  1. torch.save(model, path) 保存模型结构和参数

    1. import torchvision
    2. vgg16 = torchvision.models.vgg16(pretrained=False)
    3. torch.save(vgg16, './checkpoints/vgg16.pth') # 保存模型结构和参数
  2. torch.load(path) 加载模型结构和参数

    1. model = torch.load("./checkpoints/vgg16.pth") # 加载模型结构和参数
    2. print(model) # 打印模型结构

陷阱:对于自定义的网络结构,直接 load 会报错没有自定义的这个网络结构的类

  • 解决方法:将网络结构代码保存在单独的文件 model.py,在加载模型的 py 文件导入 from model import *

方式二:只保存模型参数(官方推荐)

  1. torch.save(model.state_dict(), path) 保存模型参数,不保存模型结构

    1. torch.save(vgg16.state_dict(), "./checkpoints/vgg16_2.pth")
  2. torch.load(path) 加载模型参数。需事先重新创建模型结构,再加载参数

    1. vgg16 = torchvision.models.vgg16(pretrained=False) # 重新创建模型结构
    2. vgg16.load_state_dict(torch.load("./checkpoints/vgg16_2.pth")) # 加载模型参数
    3. print(vgg16)