参考来源:
梯度剪裁: torch.nn.utils.clipgrad_norm()
梯度裁剪
当神经网络深度逐渐增加,网络参数量增多的时候,反向传播过程中链式法则里的梯度连乘项数便会增多,更易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁,即设置一个梯度大小的上限。本文介绍了pytorch中梯度剪裁方法的原理和使用方法。
注:为了防止混淆,本文对神经网络中的参数称为“网络参数”,其他程序相关参数成为“参数”。
pytorch 中梯度剪裁方法为 torch.nn.utils.clipgrad_norm(parameters, max_norm, norm_type=2)
。三个参数:
**parameters**
:希望实施梯度裁剪的可迭代网络参数**max_norm**
:该组网络参数梯度的范数上限**norm_type**
:范数类型
官方对该方法的描述为:
“对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。”
“Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.”
torch.nn.utils.clipgrad_norm()
的使用应该在 loss.backward()
之后,optimizer.step()
之前。