17.model_save

将之前下载好的vgg16_pretrained保存

保存方式1,保存模型结构+模型参数

  1. # 保存方式1 → 保存:模型结构+模型参数
  2. torch.save(vgg16_pretrained,'vgg16_pretrained.pth')

可以看到vgg16_pretrained.pth被保存在当前路径下

然后在当前路径下建一个model_load.py,

  1. import torch
  2. # 保存方式1 → 加载模型
  3. model = torch.load("vgg16_pretrained.pth")
  4. print(model)

可以查看加载的模型以及pretrained的参数也成功加载

保存方式2,保存模型参数

  1. # 保存方式2 → 保存:模型参数
  2. # 以字典的方式保存模型参数 (官方推荐) 内存更小
  3. torch.save(vgg16_pretrained.state_dict(),"vgg16_pretrained_moethod2.pth")

保存在当前目录下

在model_load.py中

  1. model = torch.load("vgg16_pretrained_method2.pth")
  2. print(model)

可以查看仅以字典的方式保存的模型参数

下面通过下载模型结构,再将以字典方式保存的模型参数放入结构中

  1. import torch
  2. import torchvision
  3. vgg16 = torchvision.models.vgg16(pretrained=False) #下载 模型结构
  4. vgg16.load_state_dict(torch.load("vgg16_pretrained_method2.pth")) # 将模型参数放入模型结构
  5. print(vgg16)

实际项目中,方式1小tips

在model_save.py中写入

  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torch.nn import Conv2d
  5. class DEMO(nn.Module):
  6. def __init__(self):
  7. super(DEMO,self).__init__()
  8. self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)
  9. def forward(self, x):
  10. x = self.conv1(x)
  11. return x
  12. demo = DEMO()
  13. torch.save(demo,"demo1.pth")

model_load.py中写入

  1. # tips
  2. # 自己写的网络,在加载时要将网络结构(这个class)复制过来
  3. # 或在开头加上:from 17.model_save import *
  4. # 实际中文件名不能带有数字17. 这里只做示例 *代表代入这个py文件中所有的class, 也可以只导入一个class
  5. # 否则会报错
  6. class DEMO(nn.Module):
  7. def __init__(self):
  8. super(DEMO,self).__init__()
  9. self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)
  10. def forward(self, x):
  11. x = self.conv1(x)
  12. return x
  13. model = torch.load("demo1.pth")
  14. print(model)

————————————————————————————————————————————

  1. # 17.model_save.py
  2. import torch
  3. import torchvision
  4. from torch import nn
  5. #vgg16_pretrained = torchvision.models.vgg16(pretrained=True)
  6. # 保存方式1 → 保存:模型结构+模型参数
  7. #torch.save(vgg16_pretrained,'vgg16_pretrained.pth')
  8. # 保存方式2 → 保存:模型参数
  9. # 以字典的方式保存模型参数 (官方推荐) 内存更小
  10. #torch.save(vgg16_pretrained.state_dict(),"vgg16_pretrained_method2.pth")
  11. # tips 在实际项目中,自己写的网络
  12. from torch.nn import Conv2d
  13. class DEMO(nn.Module):
  14. def __init__(self):
  15. super(DEMO,self).__init__()
  16. self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)
  17. def forward(self, x):
  18. x = self.conv1(x)
  19. return x
  20. demo = DEMO()
  21. torch.save(demo,"demo1.pth")
  1. # 17.model_load.py
  2. import torch
  3. import torchvision
  4. # from model_save import *
  5. # 保存方式1 → 加载 模型结构 + 模型参数
  6. #model = torch.load("vgg16_pretrained.pth")
  7. #print(model)
  8. # 保存方式2 → 加载 模型参数\
  9. from torch import nn
  10. from torch.nn import Conv2d
  11. '''
  12. 以字典的形式保存参数
  13. #model = torch.load("vgg16_pretrained_method2.pth")
  14. #print(model)
  15. '''
  16. '''
  17. 下载无参数模型结构,放入保存的模型参数
  18. #vgg16 = torchvision.models.vgg16(pretrained=False) #下载 模型结构
  19. #vgg16.load_state_dict(torch.load("vgg16_pretrained_method2.pth")) # 将模型参数放入模型结构
  20. #print(vgg16)
  21. '''
  22. # tips
  23. # 自己写的网络,在加载时要将网络结构(这个class)复制过来
  24. # 或在开头加上:from 17.model_save import *
  25. # 实际中文件名不能带有数字17. 这里只做示例 *代表代入这个py文件中所有的class, 也可以只导入一个
  26. # 否则会报错
  27. class DEMO(nn.Module):
  28. def __init__(self):
  29. super(DEMO,self).__init__()
  30. self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)
  31. def forward(self, x):
  32. x = self.conv1(x)
  33. return x
  34. model = torch.load("demo1.pth")
  35. print(model)