在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。

读写Tensor

我们可以把Tensor存储到文件当中:

  1. import torch
  2. from torch import nn
  3. x = torch.ones(3)
  4. torch.save(x, 'x.pt')

在同级目录下可以看到x.pt文件,我们可以把它再读回到内存中:

  1. x2 = torch.load('x.pt')
  2. print(x2)
  3. 结果:
  4. tensor([1., 1., 1.])

还可以存储一个Tensor列表并读回内存:

  1. y = torch.zeros(4)
  2. torch.save([x, y], 'xy.pt')
  3. xy_list = torch.load('xy.pt')
  4. print(xy_list)
  5. 结果:
  6. [tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

存储并读取一个从字符串映射到Tensor的字典:

  1. torch.save({'x': x, 'y': y}, 'xy_dict.pt')
  2. xy = torch.load('xy_dict.pt')
  3. print(xy)
  4. 结果:
  5. {'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

读写模型

state_dict

Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象。

拿之前的多层感知机作为例子:

  1. class MLP(nn.Module):
  2. def __init__(self):
  3. super(MLP, self).__init__()
  4. self.hidden = nn.Linear(3, 2)
  5. self.act = nn.ReLU()
  6. self.output = nn.Linear(2, 1)
  7. def forward(self, x):
  8. a = self.act(self.hidden(x))
  9. return self.output(a)
  10. net = MLP()
  11. print(net.state_dict())
  12. optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  13. print(optimizer.state_dict())
  14. 结果:
  15. OrderedDict([('hidden.weight', tensor([[ 0.4458, -0.2268, 0.2473],
  16. [-0.3065, 0.0547, 0.0892]])), ('hidden.bias', tensor([0.0920, 0.5045])), ('output.weight', tensor([[0.0630, 0.2834]])), ('output.bias', tensor([0.5416]))])
  17. {'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}

只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

保存与加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

  1. 仅保存和加载模型参数(state_dict);
  2. 保存和加载整个模型。

保存和加载state_dict(推荐)

  1. # 保存
  2. torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth
  3. # 加载
  4. model = TheModelClass(*args, **kwargs)
  5. model.load_state_dict(torch.load(PATH))

特别要注意的是,在加载模型文件前先要对欲加载的模型进行实例化,然后再通过对象.load_state_dict(dict)方法载入模型,该方法不能直接传入文件名,需要先通过torch.load(path)方法导入字典。

例:

  1. # 保存
  2. torch.save(net.state_dict(), 'test.pt')
  3. model = MLP()
  4. # 载入
  5. model.load_state_dict(torch.load('test.pt'))
  6. print(model)
  7. 结果:
  8. MLP(
  9. (hidden): Linear(in_features=3, out_features=2, bias=True)
  10. (act): ReLU()
  11. (output): Linear(in_features=2, out_features=1, bias=True)
  12. )

要注意,要载入哪一个模型就尽量选择哪个模型类进行实例化,但结构完全相同的两个不同的网络类可以共享模型文件,参数或网络结构不一样的两个网络则不行。

如:

  1. class MLP1(nn.Module):
  2. def __init__(self):
  3. super(MLP, self).__init__()
  4. self.hidden = nn.Linear(3, 2)
  5. self.act = nn.ReLU()
  6. self.output = nn.Linear(2, 1)
  7. def forward(self, x):
  8. a = self.act(self.hidden(x))
  9. return self.output(a)
  10. # 保存
  11. torch.save(net.state_dict(), 'test.pt')
  12. model = MLP1()
  13. # 载入
  14. model.load_state_dict(torch.load('test.pt'))
  15. print(model)
  16. 结果:
  17. MLP1(
  18. (hidden): Linear(in_features=3, out_features=2, bias=True)
  19. (act): ReLU()
  20. (output): Linear(in_features=2, out_features=1, bias=True)
  21. )

保存和加载整个模型

  1. # 保存
  2. torch.save(model, PATH)
  3. # 加载
  4. model = torch.load(PATH)

例如我们可以看看用MLP类实例化的对象运算得到的结果与从模型文件载入生成的对象得到的结果是否相等:

  1. x = torch.rand(2, 3)
  2. torch.save(net, 'MLP.pt')
  3. model = torch.load('MLP.pt')
  4. print(model(x) == net(x))
  5. 结果:
  6. tensor([[True],
  7. [True]])

因为这netnet2都有同样的模型参数,那么对同一个输入X的计算结果将会是一样的。上面的输出也验证了这一点。 此外,还有一些其他使用场景,例如GPU与CPU之间的模型保存与读取、使用多块GPU的模型的存储等等,使用的时候可以参考官方文档