1.保存模型与加载

  1. # 保存整个网络
  2. torch.save(net, PATH)
  3. # 保存网络中的参数, 速度快,占空间少
  4. torch.save(net.state_dict(),PATH)
  5. #--------------------------------------------------
  6. #针对上面一般的保存方法,加载的方法分别是:
  7. model_dict=torch.load(PATH)
  8. model_dict=model.load_state_dict(torch.load(PATH))

然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

  1. torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
  2. 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
  3. checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

加载的方式:

  1. def load_checkpoint(model, checkpoint_PATH, optimizer):
  2. if checkpoint != None:
  3. model_CKPT = torch.load(checkpoint_PATH)
  4. model.load_state_dict(model_CKPT['state_dict'])
  5. print('loading checkpoint!')
  6. optimizer.load_state_dict(model_CKPT['optimizer'])
  7. return model, optimizer

其他的参数可以通过以字典的方式获得

但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

  1. def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
  2. if checkpoint != 'No':
  3. print("loading checkpoint...")
  4. model_dict = model.state_dict()
  5. modelCheckpoint = torch.load(checkpoint)
  6. pretrained_dict = modelCheckpoint['state_dict']
  7. # 过滤操作
  8. new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
  9. model_dict.update(new_dict)
  10. # 打印出来,更新了多少的参数
  11. print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
  12. model.load_state_dict(model_dict)
  13. print("loaded finished!")
  14. # 如果不需要更新优化器那么设置为false
  15. if loadOptimizer == True:
  16. optimizer.load_state_dict(modelCheckpoint['optimizer'])
  17. print('loaded! optimizer')
  18. else:
  19. print('not loaded optimizer')
  20. else:
  21. print('No checkpoint is included')
  22. return model, optimizer

2. 冻结部分参数,训练另一部分参数

1)添加下面一句话到模型中

  1. for p in self.parameters():
  2. p.requires_grad = False

比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

  1. class RESNET_MF(nn.Module):
  2. def __init__(self, model, pretrained):
  3. super(RESNET_MF, self).__init__()
  4. self.resnet = model(pretrained)
  5. for p in self.parameters():
  6. p.requires_grad = False
  7. self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
  8. self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
  9. self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
  10. ...


同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

  1. optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
  2. eps=1e-08, weight_decay=1e-5)

2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

查找的代码:

  1. model_dict = torch.load('net.pth.tar').state_dict()
  2. dict_name = list(model_dict)
  3. for i, p in enumerate(dict_name):
  4. print(i, p)

保存一下这个文件,可以看到大致是这个样子的:

  1. 0 gamma
  2. 1 resnet.conv1.weight
  3. 2 resnet.bn1.weight
  4. 3 resnet.bn1.bias
  5. 4 resnet.bn1.running_mean
  6. 5 resnet.bn1.running_var
  7. 6 resnet.layer1.0.conv1.weight
  8. 7 resnet.layer1.0.bn1.weight
  9. 8 resnet.layer1.0.bn1.bias
  10. 9 resnet.layer1.0.bn1.running_mean
  11. ....

同样在模型中添加这样的代码:

  1. for i,p in enumerate(net.parameters()):
  2. if i < 165:
  3. p.requires_grad = False

在优化器中添加上面的那句话可以实现参数的屏蔽