模型保存
import torchimport torchvisionfrom torch import nnvgg16 = torchvision.models.vgg16(pretrained=False)'''保存方式1:模型结构+模型参数'''torch.save(vgg16, "vgg16_method1.pth") # 保存模型的参数和结构'''保存方式2:模型参数(官方推荐)'''torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 只保存模型的参数# 陷阱class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) def forward(self, x): x = self.conv1(x) return xmodel = Model()torch.save(model, "model_method1.pth")
模型加载
import torchimport torchvisionfrom model_save import Model'''方式1保存->加载'''model = torch.load("vgg16_method1.pth")print(model)'''方式2保存->加载'''model = torchvision.models.vgg16(pretrained=False)model.load_state_dict(torch.load("vgg16_method2.pth"))print(model)# 陷阱:记得导入模型model = torch.load("model_method1.pth")print(model)