模型的保存与加载

介绍

本实验主要讲解了在不同环境下, 如何利用PyTorch对模型进行加载的过程. 在本实验中我们将学到torch.save(), torch.load() 和 torch.nn.Module().load_state_dict()的作用以及他们的使用方式

知识点:

  • 完整模型的保存
  • 模型参数的保存
  • 模型的加载

模型的保存与加载

模型训练的实质就是优化模型中的参数, 使模型损失最小的过程. 而模型保存其实也有两种方式, 一种是直接保存整个模型, 另一种就是保存模型的参数.

接下来我们建立一个简单的全连接网络模型

  1. import torch
  2. import torch.nn as nn
  3. class Model(nn.Module):
  4. def __init__(self, n_input_features):
  5. super(Model, self).__init__()
  6. self.linear = nn.Linear(n_input_features, 1)
  7. def forward(self, x):
  8. y_pred = torch.sigmoid(self.linear(x))
  9. return y_pred
  10. # 输入层为6个神经元节点
  11. model = Model(n_input_features=6)
  12. model

假设这个模型已经训练完毕,我们得到了此时模型的参数,那么此时我们应该怎么保存该模型呢?

整个模型的保存与加载

我们可以将整个模型直接进行保存, 使用torch.save(model, FILE)即可, 其中model为模型的变量名, FILE为想要保存保存的文件路径

  1. FILE = "model.pt"
  2. torch.save(model, FILE)
  3. print("保存成功")

接下来让我们使用torch.load(FILE)来对模型进行加载

  1. # 由于模型中已经有了结构和参数,因此我们可以直接用一个新的变量接它即可
  2. loaded_model = torch.load(FILE)
  3. # 再展示之前,必须需要告诉模型现在在做模型评估,避免模型自动梯度下降
  4. loaded_model.eval()
  5. loaded_model

我们可以使用model.parameters() 查看保存前后模型参数是否发生变化

  1. # 保存前
  2. print("保存前:")
  3. for param in model.parameters():
  4. print(param)
  5. print("=====================================")
  6. # 加载后
  7. print("保存后:")
  8. for param in loaded_model.parameters():
  9. print(param)

从结果可以很清楚的看到, 从本地加载模型完全是原来模型的复制, 也就是说torch.save(model,FILE) 函数可以很完整的保存模型

模型参数的保存与加载

由于模型除了参数之外还存在模型结构等内容, 保存整个模型的文件一般都会比只保存模型参数的文件大得多.因此, 我们在训练过程中都会选择只保存模型参数

我们可以使用model.state_dict()将模型参数转为字典对象,即每层网络结构的参数分开.

  1. print(model.state_dict())

模型参数的保存, 其实就是对上面这种字典对象的保存. 我们可以使用torch.save(model.state_dict(),FILE) 对模型参数进行保存

  1. FILE = "model.pt"
  2. torch.save(model.state_dict(), FILE)
  3. print("保存成功")

由于此时我们只保存了模型的参数, 因此在加载模型时,我们需要提前指定模型的网络结构. 如果指定的网络结构和我们定义的模型参数不匹配,则会报错

  1. # 指定网络结构
  2. loaded_model = Model(n_input_features=6)
  3. # 加载参数
  4. dicts = torch.load(FILE)
  5. loaded_model.load_state_dict(dicts)
  6. print(loaded_model.state_dict())

GPU和CPU

由于 GPU 和 CPU 的训练模型方式不同,因此保存下来的模型也存在不同。为此,面对不同环境下训练出来的模型,我们的加载方式也存在细微的差别。

由于一般的模型保存都只是保存参数,因此下面的所有代码都是以保存和加载模型参数为例。

如果保存模型在 GPU 上,加载模型在 CPU 上,那么我们的保存与加载的代码应该如下:

  1. # Save on GPU
  2. device = torch.device("cuda")
  3. model.to(device)
  4. torch.save(model.state_dict(), PATH)
  5. # Load on CPU
  6. device = torch.device('cpu')
  7. model = Model(*args, **kwargs)
  8. model.load_state_dict(torch.load(PATH, map_location=device))

如果模型保存在GPU上, 加载模型的GPU上,那我们保存与加载的代码如下:

  1. device = torch.device("cuda")
  2. model.to(device)
  3. torch.save(model.state_dict(), PATH)
  4. model = Model(*args, **kwargs)
  5. model.load_state_dict(torch.load(PATH))
  6. model.to(device)

如果保存模型在 CPU 上,加载模型在 GPU 上,那么我们的保存与加载的代码应该如下

  1. torch.save(model.state_dict(), PATH)
  2. device = torch.device("cuda")
  3. model = Model(*args, **kwargs)
  4. # 选择加载到哪一个GPU设备上
  5. model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
  6. model.to(device)

如果保存模型在 CPU 上,加载模型在 CPU 上,那么我们的保存与加载的代码应该如下:

  1. torch.save(model.state_dict(), PATH)
  2. model = Model(*args, **kwargs)
  3. model.load_state_dict(torch.load(PATH))

小结

本节我们主要学习如何保存模型