参考来源:
知乎:PyTorch | 保存和加载模型
原文:https://pytorch.org/tutorials/beginner/saving_loading_models.html
完整代码:https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py
CSDN:pytorch中 model.to(device) 和 map_location=device 的区别
CSDN:Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

简介

本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数:

  1. torch.save:把序列化的对象保存到硬盘。它利用了 Python 的 pickle 来实现序列化。模型、张量以及字典都可以用该函数进行保存;
  2. torch.load:采用 pickle 将反序列化的对象从存储中加载进来。
  3. torch.nn.Module.load_state_dict:采用一个反序列化的 state_dict 加载一个模型的参数字典。

本文主要内容如下:

  • 什么是状态字典(state_dict)?
  • 预测时加载和保存模型
  • 加载和保存一个通用的检查点(Checkpoint)
  • 在同一个文件保存多个模型
  • 采用另一个模型的参数来预热模型(Warmstaring Model)
  • 不同设备下保存和加载模型

1. 什么是状态字典(state_dict)

PyTorch 中,一个模型(torch.nn.Module)的可学习参数(也就是权重和偏置值)是包含在模型参数(模型的参数通过 model.parameters() 获取)中的,一个状态字典(state_dict)就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。
模型的状态字典(**state_dict**)只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(比如,**batchnorm****running_mean**)。优化器对象(**torch.optim**)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数。
由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。

1.1 torch.nn.Module.state_dict()

  1. torch.nn.Module.state_dict(destination=None, prefix='', keep_vars=False)

返回一个包含模型状态信息的字典。包含参数(权重和偏置)和持续的缓冲值(如:观测值的平均值)。只有具有可更新参数的层才会被保存在模型的 state_dict 数据结构中。

1.2 torch.optim.Optimizer.state_dict()

  1. torch.optim.Optimizer.state_dict()

返回一个包含优化器状态信息的字典。包含两个 key:

  • state:字典,保存当前优化器的状态信息。不同优化器内容不同。
  • param_groups:字典,包含所有参数组(eg:超参数)。

下面是一个简单的使用例子,例子来自:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. # 定义模型
  5. class TheModelClass(nn.Module):
  6. def __init__(self):
  7. super(TheModelClass, self).__init__()
  8. self.conv1 = nn.Conv2d(3, 6, 5)
  9. self.pool = nn.MaxPool2d(2, 2)
  10. self.conv2 = nn.Conv2d(6, 16, 5)
  11. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  12. self.fc2 = nn.Linear(120, 84)
  13. self.fc3 = nn.Linear(84, 10)
  14. def forward(self, x):
  15. x = self.pool(F.relu(self.conv1(x)))
  16. x = self.pool(F.relu(self.conv2(x)))
  17. x = x.view(-1, 16 * 5 * 5)
  18. x = F.relu(self.fc1(x))
  19. x = F.relu(self.fc2(x))
  20. x = self.fc3(x)
  21. return x
  22. # 初始化模型
  23. model = TheModelClass()
  24. # 初始化优化器
  25. optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  26. # 打印模型的 state_dict/状态字典
  27. print("Model's state_dict:")
  28. for param_tensor in model.state_dict():
  29. print(param_tensor, "\t", model.state_dict()[param_tensor].size())
  30. # 打印优化器的 state_dict/状态字典
  31. print("Optimizer's state_dict:")
  32. for var_name in optimizer.state_dict():
  33. print(var_name, "\t", optimizer.state_dict()[var_name])
  34. """
  35. 输出结果:
  36. Model's state_dict:
  37. conv1.weight torch.Size([6, 3, 5, 5])
  38. conv1.bias torch.Size([6])
  39. conv2.weight torch.Size([16, 6, 5, 5])
  40. conv2.bias torch.Size([16])
  41. fc1.weight torch.Size([120, 400])
  42. fc1.bias torch.Size([120])
  43. fc2.weight torch.Size([84, 120])
  44. fc2.bias torch.Size([84])
  45. fc3.weight torch.Size([10, 84])
  46. fc3.bias torch.Size([10])
  47. Optimizer's state_dict:
  48. state {}
  49. param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
  50. """

2. torch.save()torch.load()torch.nn.Module.load_state_dict()

2.1 torch.save() [source]

保存一个序列化(serialized)的目标到磁盘。函数使用了 Python 的 pickle 程序用于序列化。模型(models),张量(tensors)和文件夹(dictionaries)都是可以用这个函数保存的目标类型。

  1. torch.save(obj, f, pickle_module=<module '...'>, pickle_protocol=2)

参数描述

  • obj:保存对象。
  • f:类文件对象 (必须实现写和刷新)或一个保存文件名的字符串。
  • pickle_modul:用于 pickling 元数据和对象的模块。
  • pickle_protocol:指定 pickle protocal 可以覆盖默认参数。

2.2 torch.load() [source]

用来加载模型。torch.load() 使用 Python 的 解压工具(unpickling)来反序列化 pickled object 到对应存储设备上。首先在 CPU 上对压缩对象进行反序列化并且移动到它们保存的存储设备上,如果失败了(如:由于系统中没有相应的存储设备),就会抛出一个异常。用户可以通过 register_package 进行扩展,使用自己定义的标记和反序列化方法。

  1. torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)

参数描述

  • f:类文件对象 (返回文件描述符)或一个保存文件名的字符串。
  • map_location:一个函数或字典规定如何映射存储设备。
  • pickle_module:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )。

例子:

  1. torch.load('tensors.pt')
  2. # Load all tensors onto the CPU
  3. torch.load('tensors.pt', map_location=torch.device('cpu'))
  4. # Load all tensors onto the CPU, using a function
  5. torch.load('tensors.pt', map_location=lambda storage, loc: storage)
  6. # Load all tensors onto GPU 1
  7. torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
  8. # Map tensors from GPU 1 to GPU 0
  9. torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
  10. # Load tensor from io.BytesIO object
  11. with open('tensor.pt') as f:
  12. buffer = io.BytesIO(f.read())
  13. torch.load(buffer)

2.3 torch.nn.Module.load_state_dict(state_dict) [source]

使用 state_dict 反序列化模型参数字典。用来加载模型参数。将 state_dict 中的 parametersbuffers 复制到此 module 及其子节点中。

  1. torch.nn.Module.load_state_dict(state_dict, strict=True)

参数描述

  • state_dict:保存 parameterspersistent buffers 的字典
  • strict:可选,bool 型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。
    • strict=True,要求预训练权重层数的键值与新构建的模型中的权重层数名称完全吻合;如果新构建的模型在层数上进行了部分微调,则就会报错:说 key 对应不上。
    • strict=False 表示忽略不匹配的网络层参数,新构建网络中 model.state_dict() 有与 state_dict 匹配层的键值就进行使用,没有的就默认初始化。

例子

  1. torch.save(model,'save.pt')
  2. model.load_state_dict(torch.load("save.pt")) #model.load_state_dict()函数把加载的权重复制到模型的权重中去

3. 预测时加载和保存模型

3.1 加载/保存状态字典(推荐做法)

保存:

  1. torch.save(model.state_dict(), PATH)

加载:

  1. model = TheModelClass(*args, **kwargs)
  2. model.load_state_dict(torch.load(PATH))
  3. model.eval()

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save() 来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。
通常会用 .pt 或者 .pth 后缀来保存模型。
记住

  1. 在进行预测之前,必须调用 model.eval() 方法来将 dropoutbatch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。
  2. load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)

3.2 加载/保存整个模型

保存:

  1. torch.save(model, PATH)

加载:

  1. # Model class must be defined somewhere
  2. model = torch.load(PATH)
  3. model.eval()

保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

3. 加载和保存一个通用的检查点(Checkpoint)

保存的示例代码:

  1. torch.save({
  2. 'epoch': epoch,
  3. 'model_state_dict': model.state_dict(),
  4. 'optimizer_state_dict': optimizer.state_dict(),
  5. 'loss': loss,
  6. ...
  7. }, PATH)

加载的示例代码:

  1. model = TheModelClass(*args, **kwargs)
  2. optimizer = TheOptimizerClass(*args, **kwargs)
  3. checkpoint = torch.load(PATH)
  4. model.load_state_dict(checkpoint['model_state_dict'])
  5. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  6. epoch = checkpoint['epoch']
  7. loss = checkpoint['loss']
  8. model.eval()
  9. # - or -
  10. model.train()

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。
上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar
加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。
加载完后,根据后续步骤,调用 model.eval() 用于预测,调用 model.train() 用于恢复训练。

5. 在同一个文件保存多个模型

保存模型的示例代码:

  1. torch.save({
  2. 'modelA_state_dict': modelA.state_dict(),
  3. 'modelB_state_dict': modelB.state_dict(),
  4. 'optimizerA_state_dict': optimizerA.state_dict(),
  5. 'optimizerB_state_dict': optimizerB.state_dict(),
  6. ...
  7. }, PATH)

加载模型的示例代码:

  1. modelA = TheModelAClass(*args, **kwargs)
  2. modelB = TheModelBClass(*args, **kwargs)
  3. optimizerA = TheOptimizerAClass(*args, **kwargs)
  4. optimizerB = TheOptimizerBClass(*args, **kwargs)
  5. checkpoint = torch.load(PATH)
  6. modelA.load_state_dict(checkpoint['modelA_state_dict'])
  7. modelB.load_state_dict(checkpoint['modelB_state_dict'])
  8. optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
  9. optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
  10. modelA.eval()
  11. modelB.eval()
  12. # - or -
  13. modelA.train()
  14. modelB.train()

当我们希望保存的是一个包含多个网络模型 torch.nn.Modules 的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 **state_dict** 和对应优化器的 **state_dict**除此之外,还可以继续保存其他相同的信息。
加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar 作为后缀名。

6. 采用另一个模型的参数来预热模型(Warmstaring Model)

保存模型的示例代码:

  1. torch.save(modelA.state_dict(), PATH)

加载模型的示例代码:

  1. modelB = TheModelBClass(*args, **kwargs)
  2. modelB.load_state_dict(torch.load(PATH), strict=False)

在迁移学习或者训练新的复杂模型时,加载部分模型是很常见的。利用经过训练的参数,即使只有少数参数可用,也将有助于预热训练过程,并且使模型更快收敛。在之前迁移学习教程中也介绍了可以通过预训练模型来微调,加快模型训练速度和提高模型的精度。
在加载部分模型参数进行预训练的时候,很可能会碰到键不匹配的情况(模型权重都是按键值对的形式保存并加载回来的)。因此,无论是缺少键还是多出键的情况,都可以通过在 load_state_dict() 函数中设定 strict 参数为 False 来忽略不匹配的键。
如果想将某一层的参数加载到其他层,但是有些键不匹配,那么修改 state_dict 中参数的 key 可以解决这个问题。在希望加载参数名不一样的参数时,也可以这样处理,修改 state_dict 中参数的 key ,这样参数名字匹配了就可以成功加载。

7. 不同设备下保存和加载模型

7.1 model.to(device)map_location=device 的区别

在已训练并保存在 CPU 上的 GPU 上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是 **model.to(device)****map_location=device** 两个参数,简介一下两者的不同。

  • torch.load() 函数中的参数 **map_location** 设置为 **cuda:device_id** 。这会将模型加载到给定的 GPU 设备。
  • 调用 **model.to(torch.device('cuda'))** 将模型的参数张量转换为**CUDA**张量,无论在cpu上训练还是 gpu 上训练,保存的模型参数都是参数张量不是 cuda 张量,因此,cpu 设备上不需要使用**torch.to(torch.device("cpu"))**

7.2 在 GPU 上保存模型,在 CPU 上加载模型

保存模型的示例代码:

  1. torch.save(model.state_dict(), PATH)

加载模型的示例代码:

  1. device = torch.device('cpu')
  2. model = TheModelClass(*args, **kwargs)
  3. model.load_state_dict(torch.load(PATH, map_location=device))

在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load() 的时候,设置参数 map_location ,指定采用的设备是 torch.device('cpu'),此时,map_location 动态地将 tensors 的底层存储重新映射到 CPU 设备上。

上述代码只有在模型是在一块 GPU 上训练时才有效,如果模型在多个 GPU 上训练,那么在 CPU 上加载时,会得到类似如下错误:

  1. KeyError: unexpected key module.conv1.weight in state_dict

原因是在使用多 GPU 训练并保存模型时,模型的参数名都带上了 module 前缀,因此可以在加载模型时,把 key 中的这个前缀去掉:

  1. # 原始通过DataParallel保存的文件
  2. state_dict = torch.load('myfile.pth.tar')
  3. # 创建一个不包含`module.`的新OrderedDict
  4. from collections import OrderedDict
  5. new_state_dict = OrderedDict()
  6. for k, v in state_dict.items():
  7. name = k[7:] # 去掉 `module.`
  8. new_state_dict[name] = v
  9. # 加载参数
  10. model.load_state_dict(new_state_dict)

7.3 在 GPU 上保存模型,在 GPU 上加载模型

保存模型的示例代码:

  1. torch.save(model.state_dict(), PATH)

加载模型的示例代码:

  1. device = torch.device('cuda')
  2. model = TheModelClass(*args, **kwargs)
  3. model.load_state_dict(torch.load(PATH)
  4. model.to(device)
  5. # Make sure to call input = input.to(device) on any input tensors that you feed to the model/确保在您提供给模型的任何输入张量上调用 input = input.to(device)

在 GPU 上训练和加载模型,调用 torch.load() 加载模型后,还需要采用 model.to(torch.device('cuda')),将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device) 。同时确保在模型所有的输入上使用 .to(torch.device('cuda')) 。注意,调用 my_tensor.to(device) 会返回一份在 GPU 上的 my_tensor 的拷贝。不会覆盖原本的 my_tensor ,因此要记得手动将 tensor 重写:my_tensor=my_tensor.to(torch.device('cuda'))

7.4 在 CPU 上保存,在 GPU 上加载模型

保存模型的示例代码:

  1. torch.save(model.state_dict(), PATH)

加载模型的示例代码:

  1. device = torch.device("cuda")
  2. model = TheModelClass(*args, **kwargs)
  3. model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
  4. model.to(device)
  5. # Make sure to call input = input.to(device) on any input tensors that you feed to the model/确保在您提供给模型的任何输入张量上调用 input = input.to(device)

在 GPU 上加载 CPU 训练保存的模型时,将 torch.load() 函数的 map_location 参数设置为 "cuda:device_id"。这种方式将模型加载到指定设备。下一步,确保调用 model.to(torch.device(‘cuda’)) 将模型参数 tensor 转换为 cuda tensor。最后,确保模型输入使用 .to(torch.device(‘cuda’)) 为 cuda 优化模型准备数据。
注意,调用 my_tensor.to(device) 会返回一份在 GPU 上的 my_tensor 的拷贝。不会覆盖原本的 my_tensor ,因此要记得手动将 tensor 重写:my_tensor=my_tensor.to(torch.device('cuda'))

7.5 保存 torch.nn.DataParallel 模型

保存模型的示例代码:

  1. torch.save(model.module.state_dict(), PATH)

torch.nn.DataParallel 是支持模型使用 GPU 并行的封装器。要保存一个一般的 DataParallel 模型, 请保存 **model.module.state_dict()**。这种方式,可以灵活地以任何方式加载模型到任何设备上。
加载模型的代码也是一样的,采用 torch.load() ,并可以放到指定的 GPU 显卡上。