模型的保存与加载
介绍
本实验主要讲解了在不同环境下, 如何利用PyTorch对模型进行加载的过程. 在本实验中我们将学到torch.save(), torch.load() 和 torch.nn.Module().load_state_dict()
的作用以及他们的使用方式
知识点:
- 完整模型的保存
- 模型参数的保存
- 模型的加载
模型的保存与加载
模型训练的实质就是优化模型中的参数, 使模型损失最小的过程. 而模型保存其实也有两种方式, 一种是直接保存整个模型, 另一种就是保存模型的参数.
接下来我们建立一个简单的全连接网络模型
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
# 输入层为6个神经元节点
model = Model(n_input_features=6)
model
假设这个模型已经训练完毕,我们得到了此时模型的参数,那么此时我们应该怎么保存该模型呢?
整个模型的保存与加载
我们可以将整个模型直接进行保存, 使用torch.save(model, FILE)即可, 其中model为模型的变量名, FILE为想要保存保存的文件路径
FILE = "model.pt"
torch.save(model, FILE)
print("保存成功")
接下来让我们使用torch.load(FILE)
来对模型进行加载
# 由于模型中已经有了结构和参数,因此我们可以直接用一个新的变量接它即可
loaded_model = torch.load(FILE)
# 再展示之前,必须需要告诉模型现在在做模型评估,避免模型自动梯度下降
loaded_model.eval()
loaded_model
我们可以使用model.parameters() 查看保存前后模型参数是否发生变化
# 保存前
print("保存前:")
for param in model.parameters():
print(param)
print("=====================================")
# 加载后
print("保存后:")
for param in loaded_model.parameters():
print(param)
从结果可以很清楚的看到, 从本地加载模型完全是原来模型的复制, 也就是说torch.save(model,FILE) 函数可以很完整的保存模型
模型参数的保存与加载
由于模型除了参数之外还存在模型结构等内容, 保存整个模型的文件一般都会比只保存模型参数的文件大得多.因此, 我们在训练过程中都会选择只保存模型参数
我们可以使用model.state_dict()将模型参数转为字典对象,即每层网络结构的参数分开.
print(model.state_dict())
模型参数的保存, 其实就是对上面这种字典对象的保存. 我们可以使用torch.save(model.state_dict(),FILE) 对模型参数进行保存
FILE = "model.pt"
torch.save(model.state_dict(), FILE)
print("保存成功")
由于此时我们只保存了模型的参数, 因此在加载模型时,我们需要提前指定模型的网络结构. 如果指定的网络结构和我们定义的模型参数不匹配,则会报错
# 指定网络结构
loaded_model = Model(n_input_features=6)
# 加载参数
dicts = torch.load(FILE)
loaded_model.load_state_dict(dicts)
print(loaded_model.state_dict())
GPU和CPU
由于 GPU 和 CPU 的训练模型方式不同,因此保存下来的模型也存在不同。为此,面对不同环境下训练出来的模型,我们的加载方式也存在细微的差别。
由于一般的模型保存都只是保存参数,因此下面的所有代码都是以保存和加载模型参数为例。
如果保存模型在 GPU 上,加载模型在 CPU 上,那么我们的保存与加载的代码应该如下:
# Save on GPU
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)
# Load on CPU
device = torch.device('cpu')
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
如果模型保存在GPU上, 加载模型的GPU上,那我们保存与加载的代码如下:
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
如果保存模型在 CPU 上,加载模型在 GPU 上,那么我们的保存与加载的代码应该如下
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = Model(*args, **kwargs)
# 选择加载到哪一个GPU设备上
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
如果保存模型在 CPU 上,加载模型在 CPU 上,那么我们的保存与加载的代码应该如下:
torch.save(model.state_dict(), PATH)
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
小结
本节我们主要学习如何保存模型