参考来源:
梯度剪裁: torch.nn.utils.clipgrad_norm()

梯度裁剪

当神经网络深度逐渐增加,网络参数量增多的时候,反向传播过程中链式法则里的梯度连乘项数便会增多,更易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁,即设置一个梯度大小的上限。本文介绍了pytorch中梯度剪裁方法的原理和使用方法。

注:为了防止混淆,本文对神经网络中的参数称为“网络参数”,其他程序相关参数成为“参数”。

pytorch 中梯度剪裁方法为 torch.nn.utils.clipgrad_norm(parameters, max_norm, norm_type=2)。三个参数:

  • **parameters**:希望实施梯度裁剪的可迭代网络参数
  • **max_norm**:该组网络参数梯度的范数上限
  • **norm_type**:范数类型

官方对该方法的描述为:
“对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。”

“Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.”

torch.nn.utils.clipgrad_norm() 的使用应该在 loss.backward() 之后,optimizer.step() 之前。