有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。

4.5.1 存取Tensor

我们可以直接使用save函数和load函数分别存储和读取Tensor
实质上是对pickle模块的一层封装

  1. import torch
  2. from torch import nn
  3. # 存储
  4. x = torch.ones(3)
  5. torch.save(x, "x.pt")
  6. # 读取
  7. y = torch.load("x.pt")
  8. print(y)
  9. # Tensor 列表的存取
  10. x = torch.zeros(3)
  11. y = torch.ones(4)
  12. torch.save([x, y], "xy_list.pt")
  13. xy_list = torch.load("xy_list.pt")
  14. print(xy_list)
  15. # 存储并读取一个从字符串映射到 Tensor 的字典
  16. torch.save({'x': x, 'y': y}, "xy_dict.pt")
  17. xy_dict = torch.load("xy_dict.pt")
  18. print(xy_dict)

运行结果

  1. tensor([1., 1., 1.])
  2. [tensor([0., 0., 0.]), tensor([1., 1., 1., 1.])]
  3. {'x': tensor([0., 0., 0.]), 'y': tensor([1., 1., 1., 1.])}

存取Tensor.py

4.5.2 存取模型

4.5.2.1 state_dict

在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象。

  1. lass MLP(nn.Module):
  2. def __init__(self):
  3. super(MLP, self).__init__()
  4. self.hidden = nn.Linear(3, 2)
  5. self.act = nn.ReLU()
  6. self.output = nn.Linear(2, 1)
  7. def forward(self, x):
  8. a = self.act(self.hidden(x))
  9. return self.output(a)
  10. # state_dict是一个从参数名称隐射到参数 Tesnor 的字典对象。
  11. net = MLP()
  12. print(net.state_dict())
  13. for i in net.state_dict().items():
  14. print(i)

运行结果

  1. OrderedDict([('hidden.weight', tensor([[-0.0567, 0.1704, 0.3017],
  2. [ 0.5385, 0.5153, -0.3220]])), ('hidden.bias', tensor([-0.4603, -0.4822])), ('output.weight', tensor([[-0.5111, -0.7002]])), ('output.bias', tensor([-0.3096]))])
  3. ('hidden.weight', tensor([[-0.0567, 0.1704, 0.3017],
  4. [ 0.5385, 0.5153, -0.3220]]))
  5. ('hidden.bias', tensor([-0.4603, -0.4822]))
  6. ('output.weight', tensor([[-0.5111, -0.7002]]))
  7. ('output.bias', tensor([-0.3096]))

只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

  1. optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  2. print(optimizer.state_dict())

运行结果

  1. {'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140603236014432, 140603236015072, 140602887169280, 140602887160128]}]}

4.5.2.2 两种存取模型的方式

存取模型有两种方式:仅存取模型参数和存取整个模型。前者是推荐用法。

  1. # 仅存取模型参数 state_dict, 推荐
  2. torch.save(net.state_dict(), "state_dict.pt")
  3. model = MLP()
  4. model.load_state_dict(torch.load("state_dict.pt"))
  5. print(model)
  6. # 存取整个模型, 即结构 + 参数
  7. torch.save(net, "whole_model.pt")
  8. model = torch.load("whole_model.pt")
  9. print(model)

运行结果

  1. MLP(
  2. (hidden): Linear(in_features=3, out_features=2, bias=True)
  3. (act): ReLU()
  4. (output): Linear(in_features=2, out_features=1, bias=True)
  5. )
  6. MLP(
  7. (hidden): Linear(in_features=3, out_features=2, bias=True)
  8. (act): ReLU()
  9. (output): Linear(in_features=2, out_features=1, bias=True)
  10. )
  11. /home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type MLP. It won't be checked for correctness upon loading.
  12. "type " + obj.__name__ + ". It won't be checked "
  13. /home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.
  14. "type " + obj.__name__ + ". It won't be checked "
  15. /home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type ReLU. It won't be checked for correctness upon loading.
  16. "type " + obj.__name__ + ". It won't be checked "

4.5.2.3 简单实例

  1. # 简单实践
  2. x = torch.randn(2, 3)
  3. y = net(x)
  4. PATH = "./test.pt"
  5. torch.save(net.state_dict(), PATH)
  6. net2 = MLP()
  7. net2.load_state_dict(torch.load(PATH))
  8. y2 = net2(x)
  9. print(y2 == y)

因为模型参数相等,因此使用同一种模型实例化后的计算结果也是相等的。

  1. tensor([[True],
  2. [True]])

4.5.3 其他场景

例如GPU与CPU之间的模型保存与读取、使用多块GPU的模型的存储等等,使用的时候可以参考官方文档。