在前面两节中,如果不裁剪梯度,模型将无法正常训练。为了深刻理解这一现象,本节将介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。

我们在3.14节(正向传播、反向传播和计算图)中介绍了神经网络中梯度计算与存储的一般思路,并强调正向传播和反向传播相互依赖。正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

6.6.1 定义模型

简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射(6.6 通过时间反向传播 - 图1%3Dx#card=math&code=%5Cphi%28x%29%3Dx))。设时间步 6.6 通过时间反向传播 - 图2 的输入为单样本 6.6 通过时间反向传播 - 图3,标签为 6.6 通过时间反向传播 - 图4,那么隐藏状态 6.6 通过时间反向传播 - 图5的计算表达式为

6.6 通过时间反向传播 - 图6

其中6.6 通过时间反向传播 - 图76.6 通过时间反向传播 - 图8是隐藏层权重参数。设输出层权重参数6.6 通过时间反向传播 - 图9,时间步6.6 通过时间反向传播 - 图10的输出层变量6.6 通过时间反向传播 - 图11计算为

6.6 通过时间反向传播 - 图12

设时间步6.6 通过时间反向传播 - 图13的损失为6.6 通过时间反向传播 - 图14#card=math&code=%5Cell%28%5Cboldsymbol%7Bo%7D_t%2C%20y_t%29)。时间步数为6.6 通过时间反向传播 - 图15的损失函数6.6 通过时间反向传播 - 图16定义为

6.6 通过时间反向传播 - 图17.%0A#card=math&code=L%20%3D%20%5Cfrac%7B1%7D%7BT%7D%20%5Csum_%7Bt%3D1%7D%5ET%20%5Cell%20%28%5Cboldsymbol%7Bo%7D_t%2C%20y_t%29.%0A)

我们将6.6 通过时间反向传播 - 图18称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。

6.6.2 模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态6.6 通过时间反向传播 - 图19的计算依赖模型参数6.6 通过时间反向传播 - 图206.6 通过时间反向传播 - 图21、上一时间步隐藏状态6.6 通过时间反向传播 - 图22以及当前时间步输入6.6 通过时间反向传播 - 图23

6.6_rnn-bptt.svg

6.6.3 方法

刚刚提到,图6.3中的模型的参数是 6.6 通过时间反向传播 - 图25, 6.6 通过时间反向传播 - 图266.6 通过时间反向传播 - 图27。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度6.6 通过时间反向传播 - 图286.6 通过时间反向传播 - 图296.6 通过时间反向传播 - 图30
根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们依然采用3.14节中表达链式法则的运算符prod。

首先,目标函数有关各时间步输出层变量的梯度6.6 通过时间反向传播 - 图31很容易计算:

6.6 通过时间反向传播 - 图32%7D%7BT%20%5Ccdot%20%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D.%0A#card=math&code=%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%20%3D%20%20%5Cfrac%7B%5Cpartial%20%5Cell%20%28%5Cboldsymbol%7Bo%7D_t%2C%20y_t%29%7D%7BT%20%5Ccdot%20%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D.%0A)

下面,我们可以计算目标函数有关模型参数6.6 通过时间反向传播 - 图33的梯度6.6 通过时间反向传播 - 图34。根据图6.3,6.6 通过时间反向传播 - 图35通过6.6 通过时间反向传播 - 图36依赖6.6 通过时间反向传播 - 图37。依据链式法则,

6.6 通过时间反向传播 - 图38%20%0A%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%20%5Cboldsymbol%7Bh%7D_t%5E%5Ctop.%0A#card=math&code=%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bqh%7D%7D%20%0A%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Ctext%7Bprod%7D%5Cleft%28%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bqh%7D%7D%5Cright%29%20%0A%3D%20%5Csum_%7Bt%3D1%7D%5ET%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%20%5Cboldsymbol%7Bh%7D_t%5E%5Ctop.%0A)

其次,我们注意到隐藏状态之间也存在依赖关系。
在图6.3中,6.6 通过时间反向传播 - 图39只通过6.6 通过时间反向传播 - 图40依赖最终时间步6.6 通过时间反向传播 - 图41的隐藏状态6.6 通过时间反向传播 - 图42。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度6.6 通过时间反向传播 - 图43。依据链式法则,我们得到

6.6 通过时间反向传播 - 图44%20%3D%20%5Cboldsymbol%7BW%7D%7Bqh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_T%7D.%0A#card=math&code=%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_T%7D%20%3D%20%5Ctext%7Bprod%7D%5Cleft%28%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_T%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_T%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_T%7D%20%5Cright%29%20%3D%20%5Cboldsymbol%7BW%7D%7Bqh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_T%7D.%0A)

接下来对于时间步6.6 通过时间反向传播 - 图45, 在图6.3中,6.6 通过时间反向传播 - 图46通过6.6 通过时间反向传播 - 图476.6 通过时间反向传播 - 图48依赖6.6 通过时间反向传播 - 图49。依据链式法则,
目标函数有关时间步6.6 通过时间反向传播 - 图50的隐藏状态的梯度6.6 通过时间反向传播 - 图51需要按照时间步从大到小依次计算:

6.6 通过时间反向传播 - 图52%20%2B%20%5Ctext%7Bprod%7D%20(%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7Dt%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20)%20%3D%20%5Cboldsymbol%7BW%7D%7Bhh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D%7Bt%2B1%7D%7D%20%2B%20%5Cboldsymbol%7BW%7D%7Bqh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7Dt%7D%0A#card=math&code=%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20%0A%3D%20%5Ctext%7Bprod%7D%20%28%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D%7Bt%2B1%7D%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D%7Bt%2B1%7D%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%29%20%2B%20%5Ctext%7Bprod%7D%20%28%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20%29%20%3D%20%5Cboldsymbol%7BW%7D%7Bhh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D%7Bt%2B1%7D%7D%20%2B%20%5Cboldsymbol%7BW%7D%7Bqh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_t%7D%0A)

将上面的递归公式展开,对任意时间步6.6 通过时间反向传播 - 图53,我们可以得到目标函数有关隐藏状态梯度的通项公式

6.6 通过时间反向传播 - 图54%7D%5E%7BT-i%7D%20%5Cboldsymbol%7BW%7D%7Bqh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D%7BT%2Bt-i%7D%7D.%0A#card=math&code=%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7Dt%7D%20%0A%3D%20%5Csum%7Bi%3Dt%7D%5ET%20%7B%5Cleft%28%5Cboldsymbol%7BW%7D%7Bhh%7D%5E%5Ctop%5Cright%29%7D%5E%7BT-i%7D%20%5Cboldsymbol%7BW%7D%7Bqh%7D%5E%5Ctop%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bo%7D_%7BT%2Bt-i%7D%7D.%0A)

由上式中的指数项可见,当时间步数 6.6 通过时间反向传播 - 图55 较大或者时间步 6.6 通过时间反向传播 - 图56 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含6.6 通过时间反向传播 - 图57项的梯度,例如隐藏层中模型参数的梯度6.6 通过时间反向传播 - 图586.6 通过时间反向传播 - 图59
在图6.3中,6.6 通过时间反向传播 - 图60通过6.6 通过时间反向传播 - 图61依赖这些模型参数。
依据链式法则,我们有

6.6 通过时间反向传播 - 图62%20%0A%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20%5Cboldsymbol%7Bx%7D_t%5E%5Ctop%2C%5C%5C%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bhh%7D%7D%20%0A%26%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Ctext%7Bprod%7D%5Cleft(%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bhh%7D%7D%5Cright)%20%0A%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20%5Cboldsymbol%7Bh%7D%7Bt-1%7D%5E%5Ctop.%0A%5Cend%7Baligned%7D%0A#card=math&code=%5Cbegin%7Baligned%7D%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bhx%7D%7D%20%0A%26%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Ctext%7Bprod%7D%5Cleft%28%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7Dt%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bhx%7D%7D%5Cright%29%20%0A%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20%5Cboldsymbol%7Bx%7D_t%5E%5Ctop%2C%5C%5C%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bhh%7D%7D%20%0A%26%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Ctext%7Bprod%7D%5Cleft%28%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%2C%20%5Cfrac%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%7B%5Cpartial%20%5Cboldsymbol%7BW%7D%7Bhh%7D%7D%5Cright%29%20%0A%3D%20%5Csum%7Bt%3D1%7D%5ET%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cboldsymbol%7Bh%7D_t%7D%20%5Cboldsymbol%7Bh%7D%7Bt-1%7D%5E%5Ctop.%0A%5Cend%7Baligned%7D%0A)

我们已在3.14节里解释过,每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度6.6 通过时间反向传播 - 图63被计算和存储,之后的模型参数梯度6.6 通过时间反向传播 - 图646.6 通过时间反向传播 - 图65的计算可以直接读取6.6 通过时间反向传播 - 图66的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。
举例来说,参数梯度6.6 通过时间反向传播 - 图67的计算需要依赖隐藏状态在时间步6.6 通过时间反向传播 - 图68的当前值6.6 通过时间反向传播 - 图696.6 通过时间反向传播 - 图70是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。

小结

  • 通过时间反向传播是反向传播在循环神经网络中的具体应用。
  • 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。

注:本节与原书基本相同,原书传送门