序列化与反序列化

序列化与反序列化一般描述的是内存硬盘之间的数据转换关系。
内存中,模型中的数据是以对象的形式存储的,而在硬盘中是以二进制序列的形式进行保存的。

所以序列化就是将内存中模型的对象以二进制的形式存储在硬盘中;反序列化就是将二进制序列转化成内存中的对象从而可以使用模型。
序列化是为了能够长久地保存模型。

torch.save

主要参数

  • obj:对象
  • f:输出路径

    torch.load

    主要参数:

  • f:文件路径

  • map_location:指定存放位置,cpu or gpu。

    • 因为gpu模式下训练保存的模型不能采用default加载进来。如果是cpu就不用关心

      模型保存与加载的两种方式

      法1:保存整个Module

      torch.save(net,path)
      参数:
  • net:保存的网络

  • path:保存路径

因为这里保存的是整个Module(包括8个有序字典),所以对应的加载方式就是直接将Module加载出来即可:

  1. path_model = "./model.pkl"
  2. net_load = torch.load(path_model)

就加载好了net

法2:保存模型参数(官方推荐)

state_dict=net.state_dict()
torch.save(state_dict,path)
参数

  • state_dict:保存的对象,这里放的就是通过net.state_dict()获得的可学习参数
  • path:保存路径

因为这里只保存了可学习参数,所以对应的加载方式为:

  1. path_state_dict = "./model_state_dict.pkl"
  2. state_dict_load = torch.load(path_state_dict) # 这里只加载了可学习参数
  3. net_new = LeNet2(classes=2019) # 新建一个网络
  4. net_new.load_state_dict(state_dict_load) # 这里将可学习参数加载进模型中

模型断点续训练

image.png
首先我们知道Pytorch的模型训练流程包括以上五个模块。其中数据会变动的只有模型(可学习参数)和优化器(Momentum中之前的更新量)。

因此我们在模型保存的时候只需要保存模型和优化器中的参数。也就是构建一个如下的字典:
image.png
其中epoch记录这是第几个epoch的参数信息

代码示例

模型保存

这里每隔checkpoint_interval个epoch保存一次模型,注意写在epoch的循环里:

  1. if (epoch+1) % checkpoint_interval == 0:
  2. checkpoint = {"model_state_dict": net.state_dict(),
  3. "optimizer_state_dic": optimizer.state_dict(),
  4. "loss": loss,
  5. "epoch": epoch}
  6. path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)
  7. torch.save(checkpoint, path_checkpoint)

断点恢复

  1. path_checkpoint = "./checkpoint_4_epoch.pkl"
  2. checkpoint = torch.load(path_checkpoint)
  3. net.load_state_dict(checkpoint['model_state_dict']) # 加载Module权重
  4. optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器参数
  5. start_epoch = checkpoint['epoch']
  6. scheduler.last_epoch = start_epoch # 学习率调整时记录上一个epoch
  7. # ============================ step 5/5 训练 =====================================
  8. train_curve = list()
  9. valid_curve = list()
  10. for epoch in range(start_epoch + 1, MAX_EPOCH):
  11. loss_mean = 0.
  12. correct = 0.
  13. total = 0.
  14. net.train()
  15. for i, data in enumerate(train_loader):
  16. # forward
  17. inputs, labels = data
  18. outputs = net(inputs)
  19. # backward
  20. optimizer.zero_grad()
  21. loss = criterion(outputs, labels)
  22. loss.backward()
  23. # update weights
  24. optimizer.step()
  25. # 统计分类情况
  26. _, predicted = torch.max(outputs.data, 1)
  27. total += labels.size(0)
  28. correct += (predicted == labels).squeeze().sum().numpy()
  29. # 打印训练信息
  30. loss_mean += loss.item()
  31. train_curve.append(loss.item())
  32. if (i+1) % log_interval == 0:
  33. loss_mean = loss_mean / log_interval
  34. print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
  35. epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
  36. loss_mean = 0.
  37. scheduler.step() # 更新学习率