长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

RNN的训练有别于普通前馈神经网络,反向传播中会产生权重矩阵的连乘,使得长时间步下的微弱偏移得到放大,产生梯度消失和爆炸,正交初始化和激活函数的选择,目的都是让参数矩阵的特征值尽量维持在1附近。

LSTM是一种“曲线救国”的方式,它并不直接改变RNN权重的特征值,而是将神经元参数化,通过生成线性自循环的路径,取消了原本RNN中的权重参数W。
image.png

1、LSTM原理

RNN中对于不同时间步上的数据序列使用了同样的网络结构,只是将某个或者几个层存储起来用作下一时间步的输入,这个存储的结构就是记忆单元:
image.png
RNN结构中,依次输入序列化数据,前一个时间步的隐层会将得到的数据存放到memory单元,然后再将memory单元乘以矩阵W进入到下一个时间步的隐层。在上图中,下一个时间步的memory单元的存储的信息(Ct)就是:

LSTM - 图3

取消权重参数W,并对这个memory单元做重参数化:

  • 由一个门(gate)来参数化存储的步骤,当这个门打开时,才会将信息存放到memory单元,这个门叫做输入门(Input Gate)。

  • memory单元要不要将信息流入到下一个时间步,也由一个门控制,只有当这个门打开时,我们才会将信息输入到下一个时间步,这个门叫做输出门(Output Gate)。

memory单元并非是一个实体,完全可以嵌入到隐藏层本身,在这个操作之上又增添了两个门,用来控制输入和输出,就得到了一个基本结构:
image.png
其中,输入和输出由sigmoid函数控制,sigmoid函数的输出在[0,1],可以很好的刻画门的开启或者关闭状态,值的大小就可以表示门被开启的程度。将输入门的结果用一个函数Fi来表示,输出门的结果用Fo来表示,为了保证输入门的开启和关闭状态对输入的影响,将其直接相乘再一起进入memory,memory的状态我们用C来表示:

LSTM - 图5

这样存入到memory的值就受到了输入门的调节,当其完全关闭时,就代表着信息没有流入。接下来,在输出的时候,再次相乘输出门的结果,但是,如果希望输入门和输出门的尽可能独立一些,因为直接相乘必然会导致当输入门很小时,输出门即便很大,也不会产生多少输出,所以,使用一个函数g作用在memory的结果上,再进行输出门的控制处理:

LSTM - 图6

需要特别注意的是,如果使用sigmoid函数作为激活函数,那么当遗忘门为1时,就代表着将前一步的信息原封不动的存入到当前,这与它的名字恰好相反,也就是说,当遗忘门关闭时,它会忘记,当遗忘门打开时,它才会回忆。

整个流程就是,将当前时间步的数据乘以输入门的结果,同时前一步的记忆单元乘以遗忘门的结果,两者相加,一起乘以输出门的结果,得到下一层的输出,同时此时的记忆单元参与到下一时间步的运算。

此时得到了一个较为复杂的神经元,输入门控制了信息的流入,输出门控制了信息的流出,那么看起来的memory单元是不必要的,但是在RNN中,必然采用权重W来控制流通的信息,在LSTM中,我们并没有使用权重,而只是采用简单的相加:

LSTM - 图7

随着序列越来越长,时间步越来越大,前一步的memory会流入到下一步的memory,会使得后面的memory单元存储的数值越来越大。此时,有两种可能的后果:

  • 如果函数g也是一个带有挤压性质的激活函数,那么过大的值将会使得这个激活函数永远处于激活状态,失去了学习能力。

  • 如果函数g是ReLU类型的函数,值变得非常巨大时,会使得输出门失效,因为输出门的值再小,当它乘以一个庞大的值时,也会变的非常大。

无论是哪种情况,都在表明需要在memory单元中丢弃一些信息,LSTM的解决办法是在原本的单元中加入一个遗忘门(forget gate),它的作用是重参数化记忆单元,将记忆单元输入的信息乘以遗忘门的结果Ff,存入到记忆单元中作为信息,所以当前的信息就变为了:
image.png

可以写出公式如下:

LSTM - 图9

需要特别注意的是,如果我们使用sigmoid函数作为激活函数,那么当遗忘门为1时,就代表着将前一步的信息原封不动的存入到当前,这与它的名字恰好相反,也就是说,当遗忘门关闭时,它会忘记,当遗忘门打开时,它才会回忆。

整个流程就是,将当前时间步的数据乘以输入门的结果,同时前一步的记忆单元乘以遗忘门的结果,两者相加,一起乘以输出门的结果,得到下一层的输出,同时此时的记忆单元参与到下一时间步的运算。

2、LSTM结构

image.png

2.1、遗忘门

image.png
image.png
上图中红色框中的是 LSTM 遗忘门部分,用来判断 cell 状态 ct-1 中哪些信息应该删除。其中 σ 表示激活函数 sigmoid。输入的 ht-1 和 xt 经过 sigmoid 激活函数之后得到 ft,ft 中每一个值的范围都是 [0, 1]。ft 中的值越接近 1,表示 cell 状态 ct-1 中对应位置的值更应该记住;ft 中的值越接近 0,表示 cell 状态 ct-1 中对应位置的值更应该忘记。将 ft 与 ct-1 按位相乘 (ElementWise 相乘),即可以得到遗忘无用信息之后的 c’t-1。

2.2、输入门

image.png
image.png
上图中红色框中的是 LSTM 输入门部分,用来判断哪些新的信息应该加入到 cell 状态 c‘t-1 中。其中 σ 表示激活函数 sigmoid。输入的 ht-1 和 xt 经过 tanh 激活函数可以得到新的输入信息 (图中带波浪线的 Ct),但是这些新信息并不全是有用的,因此需要使用 ht-1 和 xt 经过 sigmoid 函数得到 it, it 表示哪些新信息是有用的。两向量相乘后的结果加到 c’t-1 中,即得到 t 时刻的 cell 状态 ct。

2.3、输出门

image.png
image.png
上图中红色框中的是 LSTM 输出门部分,用来判断应该输出哪些信息到 ht 中。cell 状态 ct 经过 tanh 函数得到可以输出的信息,然后 ht-1 和 xt 经过 sigmoid 函数得到一个向量 ot,ot 的每一维的范围都是 [0, 1],表示哪些位置的输出应该去掉,哪些应该保留。两向量相乘后的结果就是最终的 ht。

3、LSTM 缓解梯度消失、梯度爆炸

在上一节中我们知道,RNN 中出现梯度消失的原因主要是梯度函数中包含一个连乘项,如果能够把连乘项去掉就可以克服梯度消失问题。如何去掉连乘项呢?我们可以通过使连乘项约等于 0 或者约等于 1,从而去除连乘项。
image.png

LSTM 中通过门的作用,可以使连乘项约等于 0 或者 1。首先我们看一下 LSTM 中 ct 与 ht 的计算公式。
image.png

在公式中 ft 与 ot 都是通过 sigmoid 函数得到的,意味着它们的值要么接近 0,要么接近 1。因此在 LSTM 中的连乘项变成:
image.png

因此当门的梯度接近1时,连乘项能够保证梯度很好地在 LSTM 中传递,避免梯度消失的情况发生。

而当门的梯度接近 0 时,意味着上一时刻的信息对当前时刻并没有作用,此时没有必要把梯度回传。

这就是 LSTM 能够克服梯度消失、梯度爆炸的原因。

4、GRU

GRU 是 LSTM 的一种变种,结构比 LSTM 简单一点。LSTM有三个门 (遗忘门 forget,输入门 input,输出门output),而 GRU 只有两个门 (更新门 update,重置门 reset)。另外,GRU 没有 LSTM 中的 cell 状态 c。

image.png
image.png
图中的 zt 和 rt 分别表示更新门 (红色) 和重置门 (蓝色)。重置门 rt 控制着前一状态的信息 ht-1 传入候选状态 (图中带波浪线的ht) 的比例,重置门 rt 的值越小,则与 ht-1 的乘积越小,ht-1 的信息添加到候选状态越少。更新门用于控制前一状态的信息 ht-1 有多少保留到新状态 ht 中,当 (1-zt) 越大,保留的信息越多。

总结

循环神经网络适合用于序列数据,也是学习 NLP 过程中必学的模型,很多 NLP 的应用、算法都用到了循环神经网络。

传统的循环神经网络 RNN 容易出现梯度消失与梯度爆炸的问题,因此目前比较常用的一般是 LSTM 及其变种。

但也因为引入了很多内容,导致参数变多,也使得训练难度加大了很多。因此很多时候往往会使用效果和LSTM相当但参数更少的GRU来构建大训练量的模型。

LSTM - 图22