有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。
4.5.1 存取Tensor
我们可以直接使用save
函数和load
函数分别存储和读取Tensor
。
实质上是对pickle模块的一层封装
import torch
from torch import nn
# 存储
x = torch.ones(3)
torch.save(x, "x.pt")
# 读取
y = torch.load("x.pt")
print(y)
# Tensor 列表的存取
x = torch.zeros(3)
y = torch.ones(4)
torch.save([x, y], "xy_list.pt")
xy_list = torch.load("xy_list.pt")
print(xy_list)
# 存储并读取一个从字符串映射到 Tensor 的字典
torch.save({'x': x, 'y': y}, "xy_dict.pt")
xy_dict = torch.load("xy_dict.pt")
print(xy_dict)
运行结果
tensor([1., 1., 1.])
[tensor([0., 0., 0.]), tensor([1., 1., 1., 1.])]
{'x': tensor([0., 0., 0.]), 'y': tensor([1., 1., 1., 1.])}
4.5.2 存取模型
4.5.2.1 state_dict
在PyTorch中,Module
的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict
是一个从参数名称隐射到参数Tesnor的字典对象。
lass MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
# state_dict是一个从参数名称隐射到参数 Tesnor 的字典对象。
net = MLP()
print(net.state_dict())
for i in net.state_dict().items():
print(i)
运行结果
OrderedDict([('hidden.weight', tensor([[-0.0567, 0.1704, 0.3017],
[ 0.5385, 0.5153, -0.3220]])), ('hidden.bias', tensor([-0.4603, -0.4822])), ('output.weight', tensor([[-0.5111, -0.7002]])), ('output.bias', tensor([-0.3096]))])
('hidden.weight', tensor([[-0.0567, 0.1704, 0.3017],
[ 0.5385, 0.5153, -0.3220]]))
('hidden.bias', tensor([-0.4603, -0.4822]))
('output.weight', tensor([[-0.5111, -0.7002]]))
('output.bias', tensor([-0.3096]))
只有具有可学习参数的层(卷积层、线性层等)才有state_dict
中的条目。优化器(optim
)也有一个state_dict
,其中包含关于优化器状态以及所使用的超参数的信息。
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict())
运行结果
{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140603236014432, 140603236015072, 140602887169280, 140602887160128]}]}
4.5.2.2 两种存取模型的方式
存取模型有两种方式:仅存取模型参数和存取整个模型。前者是推荐用法。
# 仅存取模型参数 state_dict, 推荐
torch.save(net.state_dict(), "state_dict.pt")
model = MLP()
model.load_state_dict(torch.load("state_dict.pt"))
print(model)
# 存取整个模型, 即结构 + 参数
torch.save(net, "whole_model.pt")
model = torch.load("whole_model.pt")
print(model)
运行结果
MLP(
(hidden): Linear(in_features=3, out_features=2, bias=True)
(act): ReLU()
(output): Linear(in_features=2, out_features=1, bias=True)
)
MLP(
(hidden): Linear(in_features=3, out_features=2, bias=True)
(act): ReLU()
(output): Linear(in_features=2, out_features=1, bias=True)
)
/home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type MLP. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
/home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
/home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type ReLU. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
4.5.2.3 简单实例
# 简单实践
x = torch.randn(2, 3)
y = net(x)
PATH = "./test.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
y2 = net2(x)
print(y2 == y)
因为模型参数相等,因此使用同一种模型实例化后的计算结果也是相等的。
tensor([[True],
[True]])
4.5.3 其他场景
例如GPU与CPU之间的模型保存与读取、使用多块GPU的模型的存储等等,使用的时候可以参考官方文档。