模型保存

  1. import torch
  2. import torchvision
  3. from torch import nn
  4. vgg16 = torchvision.models.vgg16(pretrained=False)
  5. '''保存方式1:模型结构+模型参数'''
  6. torch.save(vgg16, "vgg16_method1.pth") # 保存模型的参数和结构
  7. '''保存方式2:模型参数(官方推荐)'''
  8. torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 只保存模型的参数
  9. # 陷阱
  10. class Model(nn.Module):
  11. def __init__(self):
  12. super(Model, self).__init__()
  13. self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
  14. def forward(self, x):
  15. x = self.conv1(x)
  16. return x
  17. model = Model()
  18. torch.save(model, "model_method1.pth")

模型加载

  1. import torch
  2. import torchvision
  3. from model_save import Model
  4. '''方式1保存->加载'''
  5. model = torch.load("vgg16_method1.pth")
  6. print(model)
  7. '''方式2保存->加载'''
  8. model = torchvision.models.vgg16(pretrained=False)
  9. model.load_state_dict(torch.load("vgg16_method2.pth"))
  10. print(model)
  11. # 陷阱:记得导入模型
  12. model = torch.load("model_method1.pth")
  13. print(model)