模型保存
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
'''保存方式1:模型结构+模型参数'''
torch.save(vgg16, "vgg16_method1.pth") # 保存模型的参数和结构
'''保存方式2:模型参数(官方推荐)'''
torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 只保存模型的参数
# 陷阱
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
model = Model()
torch.save(model, "model_method1.pth")
模型加载
import torch
import torchvision
from model_save import Model
'''方式1保存->加载'''
model = torch.load("vgg16_method1.pth")
print(model)
'''方式2保存->加载'''
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_method2.pth"))
print(model)
# 陷阱:记得导入模型
model = torch.load("model_method1.pth")
print(model)