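In deep learning, one of the most tedious things is waiting: training often takes a long time. If a model has already been trained once and has to be retrained every time it is needed, that extra wait is wasted. Adding a few lines of code to save the model right after training avoids this; the next time the model is needed, it can simply be loaded.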

Saving/loading the model itself


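Using the earlier image-recognition example built on a PyTorch built-in dataset ("利用内置数据集一览Pytorch实现图像识别"), the snippets below show how to save and load a model.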

Saving the model

    # ... model training finished
    torch.save(model, '/mnt/main/pkl_checkpoint_file/plain_CNNet.pkl')  # saving is slow

Loading the model and testing it

    # Imports and dataset loading omitted; device, test_loader, criterion and classes come from that part
    import torch

    # The class that defines the network must still be importable: the file stores a reference to it, not its code.
    # On PyTorch 2.6 and later, also pass weights_only=False here (only for files you trust).
    model = torch.load('/mnt/main/pkl_checkpoint_file/plain_CNNet.pkl')
    model = model.to(device)
    print('--------------Network structure-----------')
    print(model)
    # Evaluate the model
    eval_loss = 0.0
    num_correct = 0
    total = 0
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    model.eval()
    with torch.no_grad():
        for img, label in test_loader:
            img, label = img.to(device), label.to(device)
            out = model(img)
            # Accumulate the loss (one mean value per batch)
            loss = criterion(out, label)
            eval_loss += loss.item()
            # Accumulate the number of correct predictions
            _, pred = out.max(1)
            c = (pred == label)
            num_correct += c.sum().item()
            total += label.size(0)
            # Per-class counts over every sample in the batch
            for i in range(label.size(0)):
                class_correct[label[i].item()] += c[i].item()
                class_total[label[i].item()] += 1
    print("total:{}".format(total))
    print("len(test_loader):{}".format(len(test_loader)))
    print("average loss per batch:{}".format(eval_loss / len(test_loader)))
    print("overall accuracy:{}%".format(100 * num_correct / total))
    for i in range(10):
        print("accuracy of {}:{}%".format(classes[i], 100 * class_correct[i] / class_total[i]))
    print("----------------")
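The save-then-load round trip for a whole model can be sketched on a much smaller scale. This is a minimal sketch, not the article's network: the `nn.Sequential` model, the temporary file name, and the random input are all illustrative assumptions, and the `weights_only` argument assumes PyTorch 1.13 or newer.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained network. It is built only from
# torch.nn classes, so pickle can always find the class definitions on load.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    expected = model(x)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "whole_model.pkl")  # illustrative file name
    torch.save(model, path)  # pickles the module object, parameters included
    # weights_only=False allows arbitrary objects to be unpickled,
    # so it should only be used on files from a trusted source.
    restored = torch.load(path, map_location="cpu", weights_only=False)

restored.eval()
with torch.no_grad():
    actual = restored(x)

print(type(restored).__name__)
print(torch.equal(expected, actual))
```

If the model were an instance of a custom class such as `CNNNet`, the loading script would also need that class to be importable under the same module path, which is one reason this approach is less portable than saving parameters only.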

Output

[Screenshot of the program output; image not preserved]
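The per-class accuracy that the test loop accumulates sample by sample can also be computed in a vectorized way. The following is a minimal sketch on hand-made tensors; `num_classes`, `labels`, and `preds` are illustrative values, not data from the article.

```python
import torch

num_classes = 3  # illustrative; the article's dataset has 10 classes
labels = torch.tensor([0, 1, 2, 2, 1, 0, 2])
preds = torch.tensor([0, 2, 2, 2, 1, 1, 2])

correct = preds == labels
# Number of samples per class, and number of correctly predicted samples per class
class_total = torch.bincount(labels, minlength=num_classes)
class_correct = torch.bincount(labels[correct], minlength=num_classes)
per_class_acc = class_correct.float() / class_total.float()

print(class_total.tolist())    # [2, 2, 3]
print(class_correct.tolist())  # [1, 1, 3]
print(per_class_acc.tolist())  # [0.5, 0.5, 1.0]
```

In a real evaluation loop the two `bincount` results would be summed over all batches before dividing, so that the final ratio covers the whole test set.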

Saving/loading only the model parameters

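The advantage of saving only the parameters and not the structure is that both saving and loading are fast. Because the structure is not saved, however, the network must be defined again before loading; the parameters are then loaded into this freshly constructed, untrained network.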
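To make "parameters without structure" concrete, it may help to look at what a `state_dict` contains. The sketch below uses a hypothetical `TinyNet` class (not the article's network) and simply lists the entries: an ordered mapping from parameter names to tensors, with no information about how the layers are wired together.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer network, used only to inspect its state_dict."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = self.conv(x).mean(dim=(2, 3))  # global average over height and width
        return self.fc(x)


net = TinyNet()
state = net.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# conv.weight (4, 3, 3, 3)
# conv.bias (4,)
# fc.weight (2, 4)
# fc.bias (2,)
```

Nothing in this mapping records the `forward` logic, which is why the class has to be defined again before the parameters can be loaded.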

Example (the form that 莫烦 (Morvan) usually uses):

    # net1 is assumed to be an already trained torchvision resnet18
    torch.save(net1.state_dict(), 'net_params.pkl')
    net2 = torchvision.models.resnet18()
    net2.load_state_dict(torch.load('net_params.pkl'))
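The same three-step pattern can be exercised end to end on a small network to check that the second network really ends up with the first one's parameters. This is a sketch under stated assumptions: `build_net`, the file name, and the input are hypothetical, and `weights_only=True` assumes PyTorch 1.13 or newer.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_net():
    """Hypothetical architecture; it must be rebuilt identically before loading."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


net1 = build_net()  # stands in for the trained network
net2 = build_net()  # fresh network with its own random initialization

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "net_params.pkl")
    torch.save(net1.state_dict(), path)  # only the tensors are written
    # A state_dict holds plain tensors, so the restricted loader is sufficient.
    state = torch.load(path, map_location="cpu", weights_only=True)

# load_state_dict reports any parameter names that did not match.
result = net2.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)

x = torch.randn(5, 4)
net1.eval()
net2.eval()
with torch.no_grad():
    same_output = torch.equal(net1(x), net2(x))
print(same_output)
```

Empty `missing_keys` and `unexpected_keys` lists indicate that the rebuilt architecture matches the saved parameters name for name.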

The same image-recognition example is used below to show how to save only the model parameters and load them again.

Saving the model parameters

    # ... model training finished
    torch.save(model.state_dict(), '/mnt/main/pkl_checkpoint_file/plain_CNNet_params.pkl')  # saving is very fast

Loading the model parameters

    # Imports and dataset loading omitted; device, test_loader, criterion and classes come from that part
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Define the network structure (it must match the structure used during training)
    class CNNNet(nn.Module):
        def __init__(self):
            super(CNNNet, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            #self.aap = nn.AdaptiveAvgPool2d(1)
            self.fc1 = nn.Linear(1296, 128)
            self.fc2 = nn.Linear(128, 10)
            #self.fc3 = nn.Linear(36,10)

        def forward(self, x):
            x = self.pool1(F.relu(self.conv1(x)))
            x = self.pool2(F.relu(self.conv2(x)))
            #x = self.aap(x)
            #x = x.view(x.shape[0],-1)
            #x = self.fc3(x)
            x = x.view(-1, 36*6*6)
            #print("x.shape:{}".format(x.shape))
            x = F.relu(self.fc2(F.relu(self.fc1(x))))
            return x

    model2 = CNNNet()
    model2 = model2.to(device)
    print('--------------Network structure-----------')
    print(model2)
    # Load the saved parameters into the freshly constructed network
    model2.load_state_dict(torch.load('/mnt/main/pkl_checkpoint_file/plain_CNNet_params.pkl', map_location=device))
    # Evaluate the model
    eval_loss = 0.0
    num_correct = 0
    total = 0
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    model2.eval()
    with torch.no_grad():
        for img, label in test_loader:
            img, label = img.to(device), label.to(device)
            out = model2(img)
            # Accumulate the loss (one mean value per batch)
            loss = criterion(out, label)
            eval_loss += loss.item()
            # Accumulate the number of correct predictions
            _, pred = out.max(1)
            c = (pred == label)
            num_correct += c.sum().item()
            total += label.size(0)
            # Per-class counts over every sample in the batch
            for i in range(label.size(0)):
                class_correct[label[i].item()] += c[i].item()
                class_total[label[i].item()] += 1
    print("total:{}".format(total))
    print("len(test_loader):{}".format(len(test_loader)))
    print("average loss per batch:{}".format(eval_loss / len(test_loader)))
    print("overall accuracy:{}%".format(100 * num_correct / total))
    for i in range(10):
        print("accuracy of {}:{}%".format(classes[i], 100 * class_correct[i] / class_total[i]))
    print("----------------")
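The value 1296 used for `fc1` (and the matching `36*6*6` in `forward`) follows from the layer sizes. The arithmetic can be checked with a few lines of plain Python. This sketch assumes 32×32 input images, which the article does not state explicitly; `conv_out` is a hypothetical helper implementing the usual output-size formula for unpadded convolution and pooling windows.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling window along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                    # assumed input height/width
size = conv_out(size, 5)     # conv1: 5x5 kernel, stride 1 -> 28
size = conv_out(size, 2, 2)  # pool1: 2x2 window, stride 2 -> 14
size = conv_out(size, 3)     # conv2: 3x3 kernel, stride 1 -> 12
size = conv_out(size, 2, 2)  # pool2: 2x2 window, stride 2 -> 6

flat_features = 36 * size * size  # 36 output channels from conv2
print(size, flat_features)        # 6 1296
```

With a different input resolution the flattened size changes, and a saved `state_dict` whose `fc1.weight` has 1296 input features would no longer fit the rebuilt network.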

Output

[Screenshot of the program output; image not preserved]

[Reference: CSDN post on the two ways to save a trained PyTorch model](https://blog.csdn.net/baishuiniyaonulia/article/details/100039845)