方式一:保存模型结构 + 模型参数
torch.save(model, path)
保存模型结构和参数import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16, './checkpoints/vgg16.pth') # 保存模型结构和参数
torch.load(path)
加载模型结构和参数model = torch.load("./checkpoints/vgg16.pth") # 加载模型结构和参数
print(model) # 打印模型结构
陷阱:对于自定义的网络结构,直接 load 会报错没有自定义的这个网络结构的类
- 解决方法:将网络结构代码保存在单独的文件 model.py,在加载模型的 py 文件导入
from model import *
方式二:只保存模型参数(官方推荐)
torch.save(model.state_dict(), path)
保存模型参数,不保存模型结构torch.save(vgg16.state_dict(), "./checkpoints/vgg16_2.pth")
torch.load(path)
加载模型参数。需事先重新创建模型结构,再加载参数vgg16 = torchvision.models.vgg16(pretrained=False) # 重新创建模型结构
vgg16.load_state_dict(torch.load("./checkpoints/vgg16_2.pth")) # 加载模型参数
print(vgg16)