PyTorch
Pytorch 保存和加载模型后缀:.pt 和.pth

1、torch.save()

保存一个序列化(serialized)的目标到磁盘。函数使用了Python的pickle程序用于序列化。
模型(models),张量(tensors)和文件夹(dictionaries)都是可以用这个函数保存的目标类型。

  1. torch.save(obj, f, pickle_module=<module '...'>, pickle_protocol=2)

参数描述:

  • obj:保存对象
  • f:类文件对象 (必须实现写和刷新)或一个保存文件名的字符串
  • pickle_module:用于 pickling 元数据和对象的模块
  • pickle_protocol:指定 pickle protocol 可以覆盖默认参数

    案例:

  • 保存整个模型

    1. torch.save(model,'save.pt')
  • 只保存训练好的权重

    1. torch.save(model.state_dict(), 'save.pt')

    2、torch.load()

    用来加载模型torch.load() 使用 Python 的 解压工具(unpickling)来反序列化 pickled object 到对应存储设备上
    首先在 CPU 上对压缩对象进行反序列化并且移动到它们保存的存储设备上,如果失败了(如:由于系统中没有相应的存储设备),就会抛出一个异常。用户可以通过 register_package 进行扩展,使用自己定义的标记和反序列化方法。

    1. torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)

    参数描述:

  • **f**:类文件对象 (返回文件描述符)或一个保存文件名的字符串

  • **map_location**:一个函数或字典规定如何映射存储设备
  • **pickle_module**:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )

    案例:

    ```python torch.load(‘tensors.pt’)

Load all tensors onto the CPU

torch.load(‘tensors.pt’, map_location=torch.device(‘cpu’))

Load all tensors onto the CPU, using a function

torch.load(‘tensors.pt’, map_location=lambda storage, loc: storage)

Load all tensors onto GPU 1

torch.load(‘tensors.pt’, map_location=lambda storage, loc: storage.cuda(1))

Map tensors from GPU 1 to GPU 0

torch.load(‘tensors.pt’, map_location={‘cuda:1’:’cuda:0’})

Load tensor from io.BytesIO object

with open(‘tensor.pt’) as f: buffer = io.BytesIO(f.read()) torch.load(buffer)

  1. <a name="Krtoy"></a>
  2. ## 3、`torch.nn.Module.load_state_dict(state_dict)`
  3. **使用 state_dict 反序列化模型参数字典**。用来加载模型参数。将 state_dict 中的 parameters 和 buffers **复制到此 module 及其子节点中**。
  4. ```python
  5. torch.nn.Module.load_state_dict(state_dict, strict=True)

参数描述:

  • **state_dict**:保存 parameters 和 persistent buffers 的字典
  • **strict**:可选,bool型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。

    案例:

    1. torch.save(model,'save.pt')
    2. model.load_state_dict(torch.load("save.pt"))
    3. #model.load_state_dict()函数把加载的权重复制到模型的权重中去

    什么是state_dict?

    在PyTorch中,一个torch.nn.Module模型中的可学习参数(比如weights和biases),模型的参数通过**model.parameters()**获取。而state_dict就是一个简单的Python dictionary,其功能是将层与层的参数张量一一映射
    注意,只包含了可学习参数(卷积层、线性层等)的层和已注册的命令(registered buffers,比如batchnorm的running_mean)才有模型的state_dict入口。
    优化方法目标(torch.optim)也有state_dict,其中包含的是关于优化器状态的信息和使用到的超参数。
    因为state_dict目标是Python dictionaries,所以它们可以很轻松地实现保存、更新、变化和再存储,从而给PyTorch模型和优化器增加了大量的模块化(modularity)。

    1) torch.nn.Module.state_dict

    1. torch.nn.Module.state_dict(destination=None, prefix='', keep_vars=False)

    返回一个包含模型状态信息的字典。包含参数(weighs and biases)和持续的缓冲值(如:观测值的平均值)。只有具有可更新参数的层才会被保存在模型的 state_dict 数据结构中。
    案例:

    1. module.state_dict().keys()
    2. # ['bias', 'weight']

    2) torch.optim.Optimizer.state_dict

    1. torch.optim.Optimizer.state_dict()

    返回一个包含优化器状态信息的字典。包含两个 key:

  • **state**:字典,保存当前优化器的状态信息。不同优化器内容不同。

  • **param_groups**:字典,包含所有参数组(如:超参数)。

案例:

  1. from __future__ import print_function, division
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.optim import lr_scheduler
  6. import numpy as np
  7. import torchvision
  8. from torchvision import datasets, models, transforms
  9. import matplotlib.pyplot as plt
  10. import time
  11. import os
  12. import copy
  13. # 定义模型
  14. class TheModelClass(nn.Module):
  15. def __init__(self):
  16. super(TheModelClass, self).__init__()
  17. self.conv1 = nn.Conv2d(3, 6, 5)
  18. self.pool = nn.MaxPool2d(2, 2)
  19. self.conv2 = nn.Conv2d(6, 16, 5)
  20. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  21. self.fc2 = nn.Linear(120, 84)
  22. self.fc3 = nn.Linear(84, 10)
  23. def forward(self, x):
  24. x = self.pool(F.relu(self.conv1(x)))
  25. x = self.pool(F.relu(self.conv2(x)))
  26. x = x.view(-1, 16 * 5 * 5)
  27. x = F.relu(self.fc1(x))
  28. x = F.relu(self.fc2(x))
  29. x = self.fc3(x)
  30. return x
  31. # 初始化模型
  32. model = TheModelClass()
  33. # 初始化优化器
  34. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  35. # 打印模型的 state_dict
  36. print("Model's state_dict:")
  37. for param_tensor in model.state_dict():
  38. print(param_tensor, "\t", model.state_dict()[param_tensor].size())
  39. # 打印优化器的 state_dict
  40. print("Optimizer's state_dict:")
  41. for var_name in optimizer.state_dict():
  42. print(var_name, "\t", optimizer.state_dict()[var_name])

输出:

  1. Model's state_dict:
  2. conv1.weight torch.Size([6, 3, 5, 5])
  3. conv1.bias torch.Size([6])
  4. conv2.weight torch.Size([16, 6, 5, 5])
  5. conv2.bias torch.Size([16])
  6. fc1.weight torch.Size([120, 400])
  7. fc1.bias torch.Size([120])
  8. fc2.weight torch.Size([84, 120])
  9. fc2.bias torch.Size([84])
  10. fc3.weight torch.Size([10, 84])
  11. fc3.bias torch.Size([10])
  12. Optimizer's state_dict:
  13. state {}
  14. param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

4、保存/加载

4.1 state_dict(推荐)

  1. # 保存
  2. torch.save(model.state_dict(), PATH)
  3. # 加载
  4. model = TheModelClass(*args, **kwargs)
  5. model.load_state_dict(torch.load(PATH))
  6. model.eval()

4.2 整个模型

  1. # 保存
  2. torch.save(model, PATH)
  3. # 加载
  4. # 模型类必须在别的地方定义
  5. model = torch.load(PATH)
  6. model.eval()

这种保存/加载模型的过程使用了最直观的语法,所用代码量少。这使用Python的pickle保存所有模块。
这种方法的缺点是,保存模型的时候,序列化的数据被绑定到了特定的类和确切的目录。这是因为pickle不保存模型类本身,而是保存这个类的路径,并且在加载的时候会使用。因此,当在其他项目里使用或者重构的时候,加载模型的时候会出错
一般来说,PyTorch的模型以.pt或者.pth文件格式保存。
一定要记住在评估模式的时候调用**model.eval()**来固定dropout和批次归一化。否则会产生不一致的推理结果。

4.3 保存加载用于推理的常规Checkpoint/或继续训练

  1. # 保存
  2. torch.save({
  3. 'epoch': epoch,
  4. 'model_state_dict': model.state_dict(),
  5. 'optimizer_state_dict': optimizer.state_dict(),
  6. 'loss': loss,
  7. ...
  8. }, PATH)
  9. # 加载
  10. model = TheModelClass(*args, **kwargs)
  11. optimizer = TheOptimizerClass(*args, **kwargs)
  12. checkpoint = torch.load(PATH)
  13. model.load_state_dict(checkpoint['model_state_dict'])
  14. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  15. epoch = checkpoint['epoch']
  16. loss = checkpoint['loss']
  17. model.eval()
  18. # - 或者 -
  19. model.train()

在保存用于推理或者继续训练的常规检查点的时候,除了模型的**state_dict**之外,还必须保存其他参数保存优化器的state_dict也非常重要,因为它包含了模型在训练时候优化器的缓存和参数。除此之外,还可以保存停止训练时epoch数,最新的模型损失,额外的torch.nn.Embedding层等。
要保存多个组件,则将它们放到一个字典中,然后使用**torch.save()**序列化这个字典。一般来说,使用.tar文件格式来保存这些检查点
加载各个组件,首先初始化模型和优化器,然后使用**torch.load()**加载保存的字典,然后可以直接查询字典中的值来获取保存的组件
同样,评估模型的时候一定不要忘了调用**model.eval()**

4.4 保存多个模型到一个文件

  1. # 保存
  2. torch.save({
  3. 'modelA_state_dict': modelA.state_dict(),
  4. 'modelB_state_dict': modelB.state_dict(),
  5. 'optimizerA_state_dict': optimizerA.state_dict(),
  6. 'optimizerB_state_dict': optimizerB.state_dict(),
  7. ...
  8. }, PATH)
  9. # 加载
  10. modelA = TheModelAClass(*args, **kwargs)
  11. modelB = TheModelBClass(*args, **kwargs)
  12. optimizerA = TheOptimizerAClass(*args, **kwargs)
  13. optimizerB = TheOptimizerBClass(*args, **kwargs)
  14. checkpoint = torch.load(PATH)
  15. modelA.load_state_dict(checkpoint['modelA_state_dict'])
  16. modelB.load_state_dict(checkpoint['modelB_state_dict'])
  17. optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
  18. optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
  19. modelA.eval()
  20. modelB.eval()
  21. # - 或者 -
  22. modelA.train()
  23. modelB.train()

4.5 使用其他模型来预热当前模型

  1. # 保存
  2. torch.save(modelA.state_dict(), PATH)
  3. # 加载
  4. modelB = TheModelBClass(*args, **kwargs)
  5. modelB.load_state_dict(torch.load(PATH), strict=False)

迁移学习或者训练新的复杂模型时,加载部分模型是很常见的。利用经过训练的参数,即使只有少数参数可用,也将有助于预热训练过程,并且使模型更快收敛
在加载部分模型参数进行预训练的时候,很可能会碰到键不匹配的情况(模型权重都是按键值对的形式保存并加载回来的)。因此,无论是缺少键还是多出键的情况,都可以通过在load_state_dict()函数中设定strict参数为False来忽略不匹配的键
如果想将某一层的参数加载到其他层,但是有些键不匹配,那么修改state_dict中参数的key可以解决这个问题。

4.6 跨设备保存与加载模型

1) GPU上保存,CPU上加载

  1. # 保存
  2. torch.save(model.state_dict(), PATH)
  3. # 加载
  4. device = torch.device('cpu')
  5. model = TheModelClass(*args, **kwargs)
  6. model.load_state_dict(torch.load(PATH, map_location=device))

当在CPU上加载一个GPU上训练的模型时,在torch.load()中指定**map_location=torch.device('cpu')**,此时,map_location动态地将tensors的底层存储重新映射到CPU设备上。
上述代码只有在模型是在一块GPU上训练时才有效,如果模型在多个GPU上训练,那么在CPU上加载时,会得到类似如下错误

KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因是在使用多GPU训练并保存模型时,模型的参数名都带上了module前缀,因此可以在加载模型时,把key中的这个前缀去掉

  1. # 原始通过DataParallel保存的文件
  2. state_dict = torch.load('myfile.pth.tar')
  3. # 创建一个不包含`module.`的新OrderedDict
  4. from collections import OrderedDict
  5. new_state_dict = OrderedDict()
  6. for k, v in state_dict.items():
  7. name = k[7:] # 去掉 `module.`
  8. new_state_dict[name] = v
  9. # 加载参数
  10. model.load_state_dict(new_state_dict)

2) GPU上保存,GPU上加载

  1. # 保存
  2. torch.save(model.state_dict(), PATH)
  3. # 加载
  4. device = torch.device("cuda")
  5. model = TheModelClass(*args, **kwargs)
  6. model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # 选择希望使用的GPU
  7. model.to(device)

在把GPU上训练的模型加载到GPU上时,只需要使用**model.to(torch.devie('cuda'))**将初始化的模型转换为CUDA优化模型。同时确保在模型所有的输入上使用**.to(torch.device('cuda'))**
注意,调用my_tensor.to(device)返回一份在GPU上的my_tensor的拷贝不会覆盖原本的my_tensor,因此要记得手动将tensor重写my_tensor = my_tensor.to(torch.device('cuda'))

3) CPU上保存,GPU上加载

  1. # 保存
  2. torch.save(model.state_dict(), PATH)
  3. # 加载
  4. device = torch.device("cuda")
  5. model = TheModelClass(*args, **kwargs)
  6. model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
  7. # Choose whatever GPU device number you want
  8. model.to(device)
  9. # Make sure to call input = input.to(device) on any input tensors that you feed to the model

在 GPU 上加载 CPU 训练保存的模型时,将 torch.load() 函数的 map_location 参数设置为 **cuda:device_id**。这种方式将模型加载到指定设备。下一步,确保调用 **model.to(torch.device('cuda'))** 将模型参数 tensor 转换为 cuda tensor。最后,确保模型输入使用 **.to(torch.device('cuda'))** 为 cuda 优化模型准备数据。
注意:调用 my_tensor.to(device) 会在 GPU 上返回 my_tensor 的新副本,不会覆盖 my_tensor。因此,使用 my_tensor = my_tensor.to(torch.device('cuda')) 手动覆盖

4.7 保存torch.nn.DataParallel模型

  1. # 保存
  2. torch.save(model.state_dict(), PATH)
  3. # 加载
  4. # Load to whatever device you want

torch.nn.DataParallel 是支持模型使用 GPU 并行的封装器。要保存一个一般的 DataParallel 模型, 请保存 **model.module.state_dict()**。这种方式,可以灵活地以任何方式加载模型到任何设备上。