论文出版于ICLR 2021。
image.png

image.png
image.png
image.png
整体的计算过程可以总结为五步:

  1. 使用旧的权重前向传播后损失回传获得旧权重对应的梯度
  2. 使用SAM获得对应于旧权重的调整量并更新权重
  3. 使用新权重重新前向传播后损失回传获得新权重对应的梯度
  4. 将新权重调整回旧权重,但是其梯度仍使用新的梯度
  5. 使用新梯度更新旧权重

所以该算法需要两次前向传播,对于PyTorch,最基本的流程如下:

  1. # 注意这里写的过程与参考连接中的pytorch代码并不完全一样,参考连接中将一些优化器的操作封装了起来
  2. # 这里为了明晰,将其取了出来
  3. outputs = model(data) # 使用旧权重前向传播
  4. loss = self.criterion(outputs, targets)
  5. loss.backward() # 获得旧权重的梯度
  6. if use_sam:
  7. sam.finetune_weight() # 微调旧权重
  8. optimizer.zero_grad() # 梯度归零,因为旧权重的梯度不是我们想要的
  9. model.eval() # 保证BN的统计量只在第一次前向传播中计算(具体可见参考连接中pytorch代码的readme)
  10. outputs = model(data) # 使用新权重前向传播
  11. loss = self.criterion(outputs, targets)
  12. loss.backward() # 获得新权重的梯度
  13. sam.restore_weight() # 1. 恢复旧权重,2. 保留新权重的梯度
  14. model.train() # 恢复BN的状态
  15. optimizer.step() # 更新权重为w^{SAM}_{t+1}
  16. optimizer.zero_grad() # 梯度归零,准备下一次迭代

效果

暂未获得明确的具有性能提升的效果。不知道是否使用的有问题。