P26网络模型的保存与读取

/home/hcq/python/pytorch/module/01model_save

保存save

  import os

  import torch
  import torchvision
  from torch import nn  # used by the custom model in the pitfall section

  # torch.save does not create missing parent directories, so make sure the
  # checkpoint directory exists before writing to it.
  os.makedirs("../../pth", exist_ok=True)

  # Untrained VGG16. `weights=None` replaces the deprecated `pretrained=False`
  # argument (torchvision >= 0.13).
  vgg16 = torchvision.models.vgg16(weights=None)

  # Save method 1: pickle the whole module (structure + parameters).
  # Loading it later requires the model's class to be importable.
  torch.save(vgg16, "../../pth/vgg16_method1.pth")

  # Save method 2 (officially recommended): save only the parameters, as a
  # state_dict (an ordered dict mapping parameter names to tensors).
  torch.save(vgg16.state_dict(), "../../pth/vgg16_method2.pth")
  # Pitfall: saving a custom model with method 1 (whole module).
  # torch.save pickles the module by reference to its class - here
  # `__main__.Mymodule`, because this file is run as a script - so whichever
  # script loads the file must itself define or import a class with that name.
  class Mymodule(nn.Module):
      """Minimal custom model: one 3->64 channel convolution with a 3x3 kernel."""

      def __init__(self):
          super().__init__()
          self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

      def forward(self, x):
          """Apply the convolution to ``x``."""
          output = self.conv1(x)
          return output


  mymodule = Mymodule()
  # NOTE: despite "method2" in the file name, this saves the whole module
  # (save method 1). The loading script reads this exact path, so the name is
  # kept unchanged.
  torch.save(mymodule, "../../pth/mymodule_method2.pth")
  print("导出完成")  # "export finished"

加载load

  import torch
  import torchvision.models
  from torch import nn

  # Load method 1: the file holds a whole pickled module (structure + params).
  # Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
  # to unpickle arbitrary classes such as VGG. Full-model checkpoints must
  # therefore opt out explicitly. Only do this for files you trust, because
  # unpickling can execute code. (`weights_only` requires PyTorch >= 1.13.)
  model = torch.load("../../pth/vgg16_method1.pth", weights_only=False)
  print(model)
  # VGG(
  #   (features): Sequential(
  #     (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  #     (1): ReLU(inplace=True)
  # Load method 2: the file holds only a state_dict (an ordered dict of
  # parameter name -> tensor), not a model object.
  # weights_only=True restricts unpickling to tensors and plain containers,
  # which is all a state_dict needs, and is safe even for untrusted files.
  model = torch.load("../../pth/vgg16_method2.pth", weights_only=True)
  print(model)
  # OrderedDict([('features.0.weight', tensor([[[[ 0.0166, -0.0475,  0.0280],
  #           [-0.0281, -0.0843,  0.0382],
  #           [-0.0265, -0.1130, -0.0962]],

  # Loading the parameters: first build the architecture, then copy the saved
  # parameters into it. `weights=None` replaces the deprecated
  # `pretrained=False` (torchvision >= 0.13). The dict loaded above is reused,
  # so the file is not read a second time.
  vgg16 = torchvision.models.vgg16(weights=None)
  vgg16.load_state_dict(model)
  print(vgg16)
  # Pitfall: unpickling a whole-model checkpoint looks its class up by name,
  # so `Mymodule` must be defined (or imported) in this script before loading.
  # Fix 1: copy the class definition into this file.
  class Mymodule(nn.Module):
      """Same model as in the save script: one 3->64 channel, 3x3 convolution."""

      def __init__(self):
          super().__init__()
          self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

      def forward(self, x):
          """Apply the convolution to ``x``."""
          output = self.conv1(x)
          return output
  # Fix 2: import the class from the save script instead of copying it.
  # `from 01model_save import *` is a SyntaxError, because a module name in an
  # import statement cannot start with a digit. importlib can still load it,
  # assuming the save script is 01model_save.py on sys.path:
  #     import importlib
  #     Mymodule = importlib.import_module("01model_save").Mymodule
  # Beware: importing that script re-runs all of its top-level code (it saves
  # every checkpoint again), since it has no `if __name__ == "__main__":` guard.

  # This is a whole-model checkpoint. Since PyTorch 2.6, torch.load defaults to
  # weights_only=True, which rejects pickled custom classes, so opt out
  # explicitly, and only for files you trust.
  model = torch.load("../../pth/mymodule_method2.pth", weights_only=False)
  print(model)
  # If Mymodule is neither defined nor imported, the torch.load call above
  # fails with:
  # AttributeError: Can't get attribute 'Mymodule' on <module '__main__'