一、Python模块 & data

  1. %matplotlib inline
  2. %config InlineBackend.figure_format = 'retina'
  3. import matplotlib.pyplot as plt
  4. import torch
  5. from torch import nn
  6. from torch import optim
  7. import torch.nn.functional as F
  8. from torchvision import datasets, transforms
  9. #自定义模块
  10. import helper
  11. import fc_model
  12. #Define a transform to normalize the data
  13. transform = transforms.Compose([transforms.ToTensor(),
  14. transforms.Normalize((0.5, ), (0.5, ))])
  15. #Download and load the training data
  16. trainset = datasets.FashioniNIST('~/.pytorch/F_MNIST_data', download=True,
  17. train=True, transform=transform)
  18. trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
  19. #Download and load the test data
  20. testset = datasets.FashioniNIST ('~/.pytorch/F_MNIST_data', download=True,
  21. train=False, transform=transform)
  22. testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

二、建立模型 & 训练

  1. #建立模型 自定义模块fc_model
  2. model = fc_model.Network(784, 10, [512, 256, 128])
  3. criterion = nn.NLLLoss()
  4. optimizer = optim.Adam(model.parameters(), lr=0.001)
  5. #训练模型
  6. fc_model.tranin(model, trainloader, testloader, criterion, optimizer, epochs=2)
  • print(model)PyTorch 专题 | 输出-模型存储与加载 - 图1
  • print(model.state_dict().keys())PyTorch 专题 | 输出-模型存储与加载 - 图2

三、存储/加载模型

1. 存储模型(参数)

字典checkpoint:保存记录维度的信息

  1. 网络结构
  • input
  • output
  • hidden layers
  • .state_dict() 参数(weights, bias)
  1. checkpoint = {'input_size': 784,
  2. 'output_size': 10,
  3. 'hidden_layers': [each.out_features for each in model.hidden_layers],
  4. 'state_dict': model.state_dict()}
  5. torch.save(checkpoint, 'checkpoint.pth')}

PyTorch 专题 | 输出-模型存储与加载 - 图3
注意:属性in_features, out_features

2. 加载模型(参数)

加载模型的参数必须与存储好的模型一致,否则加载错误

  1. def load_checkpoint(filepath):
  2. checkpoint = torch.load(filepath)
  3. model = fc_model.Network(checkpoint['input_size'],
  4. checkpoint['output_size'],
  5. checkpoint['hidden_layers']) #.out_features提取维度信息
  6. model.load_state_dict(checkpoint['state_dict'])
  7. return model
  8. #加载模型
  9. model = load_checkpoint('checkpoint.pth')
  10. print (model)

PyTorch 专题 | 输出-模型存储与加载 - 图4