17.model_save
将之前下载好的vgg16_pretrained保存
保存方式1,保存模型结构+模型参数
# 保存方式1 → 保存:模型结构+模型参数torch.save(vgg16_pretrained,'vgg16_pretrained.pth')
可以看到vgg16_pretrained.pth被保存在当前路径下
然后在当前路径下建一个model_load.py,
import torch# 保存方式1 → 加载模型model = torch.load("vgg16_pretrained.pth")print(model)
可以查看加载的模型以及pretrained的参数也成功加载
保存方式2,保存模型参数
# 保存方式2 → 保存:模型参数# 以字典的方式保存模型参数 (官方推荐) 内存更小torch.save(vgg16_pretrained.state_dict(),"vgg16_pretrained_moethod2.pth")
保存在当前目录下
在model_load.py中
model = torch.load("vgg16_pretrained_method2.pth")print(model)
可以查看仅以字典的方式保存的模型参数
下面通过下载模型结构,再将以字典方式保存的模型参数放入结构中
import torchimport torchvisionvgg16 = torchvision.models.vgg16(pretrained=False) #下载 模型结构vgg16.load_state_dict(torch.load("vgg16_pretrained_method2.pth")) # 将模型参数放入模型结构print(vgg16)
实际项目中,方式1小tips
在model_save.py中写入
import torchimport torchvisionfrom torch import nnfrom torch.nn import Conv2dclass DEMO(nn.Module):def __init__(self):super(DEMO,self).__init__()self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xdemo = DEMO()torch.save(demo,"demo1.pth")
model_load.py中写入
# tips# 自己写的网络,在加载时要将网络结构(这个class)复制过来# 或在开头加上:from 17.model_save import *# 实际中文件名不能带有数字17. 这里只做示例 *代表代入这个py文件中所有的class, 也可以只导入一个class# 否则会报错class DEMO(nn.Module):def __init__(self):super(DEMO,self).__init__()self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xmodel = torch.load("demo1.pth")print(model)
————————————————————————————————————————————
# 17.model_save.pyimport torchimport torchvisionfrom torch import nn#vgg16_pretrained = torchvision.models.vgg16(pretrained=True)# 保存方式1 → 保存:模型结构+模型参数#torch.save(vgg16_pretrained,'vgg16_pretrained.pth')# 保存方式2 → 保存:模型参数# 以字典的方式保存模型参数 (官方推荐) 内存更小#torch.save(vgg16_pretrained.state_dict(),"vgg16_pretrained_method2.pth")# tips 在实际项目中,自己写的网络from torch.nn import Conv2dclass DEMO(nn.Module):def __init__(self):super(DEMO,self).__init__()self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xdemo = DEMO()torch.save(demo,"demo1.pth")
# 17.model_load.pyimport torchimport torchvision# from model_save import *# 保存方式1 → 加载 模型结构 + 模型参数#model = torch.load("vgg16_pretrained.pth")#print(model)# 保存方式2 → 加载 模型参数\from torch import nnfrom torch.nn import Conv2d'''以字典的形式保存参数#model = torch.load("vgg16_pretrained_method2.pth")#print(model)''''''下载无参数模型结构,放入保存的模型参数#vgg16 = torchvision.models.vgg16(pretrained=False) #下载 模型结构#vgg16.load_state_dict(torch.load("vgg16_pretrained_method2.pth")) # 将模型参数放入模型结构#print(vgg16)'''# tips# 自己写的网络,在加载时要将网络结构(这个class)复制过来# 或在开头加上:from 17.model_save import *# 实际中文件名不能带有数字17. 这里只做示例 *代表代入这个py文件中所有的class, 也可以只导入一个# 否则会报错class DEMO(nn.Module):def __init__(self):super(DEMO,self).__init__()self.conv1 = Conv2d(in_channels=3, out_channels=3, kernel_size= 3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xmodel = torch.load("demo1.pth")print(model)
