一、导入 Python 模块 & 加载数据
%matplotlib inline%config InlineBackend.figure_format = 'retina'import matplotlib.pyplot as pltimport torchfrom torch import nnfrom torch import optimimport torch.nn.functional as Ffrom torchvision import datasets, transforms#自定义模块import helperimport fc_model#Define a transform to normalize the datatransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))])#Download and load the training datatrainset = datasets.FashioniNIST('~/.pytorch/F_MNIST_data', download=True,train=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)#Download and load the test datatestset = datasets.FashioniNIST ('~/.pytorch/F_MNIST_data', download=True,train=False, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
二、建立模型 & 训练
# Build the model with the project-local fc_model module:
# 784 inputs, 10 outputs, hidden layers of 512, 256 and 128 units.
# NOTE(review): 784 presumably is a flattened 28x28 Fashion-MNIST image — confirm.
model = fc_model.Network(784, 10, [512, 256, 128])

# NLLLoss expects log-probabilities as input; presumably Network.forward
# ends with log_softmax — confirm in fc_model.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model for 2 epochs, using testloader as the validation data.
fc_model.train(model, trainloader, testloader, criterion, optimizer, epochs=2)

print(model)
print(model.state_dict().keys())
三、存储/加载模型
1. 存储模型(参数)
字典 checkpoint:保存重建网络所需的全部信息,包括网络结构的各层维度和模型参数:
- 网络结构
- input
- output
- hidden layers
- `.state_dict()`:模型参数(weights、bias)
# Store everything needed to rebuild the network next to its parameters:
# the layer sizes (so the same architecture can be recreated before loading)
# and the state_dict (weights and biases).
checkpoint = {
    'input_size': 784,
    'output_size': 10,
    # out_features of each hidden layer is that layer's width
    # (presumably these are nn.Linear layers — confirm in fc_model).
    'hidden_layers': [each.out_features for each in model.hidden_layers],
    'state_dict': model.state_dict(),
}

torch.save(checkpoint, 'checkpoint.pth')

注意:`nn.Linear` 层的属性 `in_features`、`out_features` 分别是该层的输入维度和输出维度;保存时用 `out_features` 提取各隐藏层的大小。
2. 加载模型(参数)
新建模型的结构(各层维度)必须与保存时的模型完全一致,否则 `load_state_dict` 加载参数时会报错。
def load_checkpoint(filepath, map_location=None):
    """Rebuild an fc_model.Network from a checkpoint saved with torch.save.

    Args:
        filepath: Path to a checkpoint dict with the keys 'input_size',
            'output_size', 'hidden_layers' and 'state_dict'.
        map_location: Forwarded to torch.load, e.g. 'cpu' to load a
            checkpoint saved on a GPU on a CPU-only machine. The default
            None keeps torch.load's default behavior.

    Returns:
        The reconstructed model with the saved parameters loaded.
    """
    checkpoint = torch.load(filepath, map_location=map_location)
    # Recreate the exact saved architecture from the stored sizes
    # ('hidden_layers' was built from each layer's out_features).
    # load_state_dict raises if the architecture does not match.
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    return model


# Load the model back from disk.
model = load_checkpoint('checkpoint.pth')
print(model)

