一、LSTM结构

参考链接🔗:pytorch_14_lstm原理 - 知乎 (zhihu.com)

1.1 LSTM的整体结构

一个LSTM的整体结构如下图所示

§ 深入浅出LSTM - 图1

看着比卷积神经网络复杂了许多,这是因为它里面添加了许多的门,用来调整不同时刻的输入,按照不同时刻的输入来看一下这个结构

1.2 与前一时刻有关的遗忘门

前一时刻信息单元经过遗忘门进行过滤,前一时刻的输出为§ 深入浅出LSTM - 图2,当前时刻的输入为§ 深入浅出LSTM - 图3,将这两个数据联合作为当前§ 深入浅出LSTM - 图4时刻的输入,计算出遗忘门的权值大小§ 深入浅出LSTM - 图5

§ 深入浅出LSTM - 图6%0A#card=math&code=ft%20%3D%20%5Csigma%28W_f%5Bh%7Bt-1%7D%2C%20x_t%5D%20%2B%20b_f%29%0A&id=yq2lU)

§ 深入浅出LSTM - 图7

其中的§ 深入浅出LSTM - 图8表示的就是一个激活函数,例如ReLU,所以§ 深入浅出LSTM - 图9,而1表示前一时刻的输出§ 深入浅出LSTM - 图10完全可用,0表示前一时刻的输出§ 深入浅出LSTM - 图11无法使用,直接丢弃。

§ 深入浅出LSTM - 图12是什么😧?别着急,后面往下走的时候你就会发现答案。

1.3 输入门

输入门和当前时刻有多少信息放入单元状态

§ 深入浅出LSTM - 图13

输入门:是一个sigmoid层,这个层用来决定我们将要更新的值,这个值是一个权重

§ 深入浅出LSTM - 图14%0A#card=math&code=it%20%3D%20%5Csigma%20%28W_i.%20%5Bh%7Bt-1%7D%2C%20x_t%5D%20%2B%20b_i%29%0A&id=dd86z)

一个tanh层创建一个新的候选值的向量,该向量可以添加到状态中

§ 深入浅出LSTM - 图15%0A#card=math&code=%5Ctilde%7BCt%7D%20%3D%20%5Ctext%7Btanh%7D%28W_c.%5Bh%7Bt-1%7D%2C%20x_t%5D%20%2B%20b_C%29%0A&id=ZLada)

计算当前单元状态:前一时刻单元经过遗忘门,当前时刻和前一时刻§ 深入浅出LSTM - 图16累加信息经过输入门,两个门过滤后的信息累加为单元状态

§ 深入浅出LSTM - 图17

§ 深入浅出LSTM - 图18

其中

  • § 深入浅出LSTM - 图19表示当前单元状态
  • § 深入浅出LSTM - 图20前一时刻单元状态
  • § 深入浅出LSTM - 图21当前输入§ 深入浅出LSTM - 图22和前一时刻§ 深入浅出LSTM - 图23的累加信息

1.4 输出门

通过1.3计算出了当前单元状态§ 深入浅出LSTM - 图24,接下来就要把这个状态给输出。同样地,这个输出又又又要根据§ 深入浅出LSTM - 图25§ 深入浅出LSTM - 图26计算权重

§ 深入浅出LSTM - 图27

通过一个sigmoid层来计算权值§ 深入浅出LSTM - 图28,这个权值表示当前单元状态有多少被输出

§ 深入浅出LSTM - 图29%0A#card=math&code=ot%20%3D%20%5Csigma%28W_o%5Bh%7Bt-1%7D%2Cx_t%5D%2Bb_o%29%0A&id=IBkug)

接下来就将它们当前单元状态输出,首先将当前的单元状态§ 深入浅出LSTM - 图30经过tanh函数从而使得输出值位于[-1, 1],之后与输出门的权重§ 深入浅出LSTM - 图31进行相乘

§ 深入浅出LSTM - 图32

§ 深入浅出LSTM - 图33%0A#card=math&code=h_t%20%3D%20o_t%2A%5Ctext%7Btanh%7D%28C_t%29%0A&id=GYVlb)

注意:之前有一个一直没有解释的§ 深入浅出LSTM - 图34,而现在我们知道了§ 深入浅出LSTM - 图35就是当前单元状态,而§ 深入浅出LSTM - 图36就是根据权重将当前状态输出的值

二、PyTorch实现

2.1 温故知新

进入PyTorch官网:LSTM — PyTorch 1.9.1 documentation,进去就能看到下面的公式

§ 深入浅出LSTM - 图37

头大了吗?很明显没有,之前我们都看过LSTM的处理流程,而这个公式正好和那个流程契合起来,下面就把这个公式拆分一下

§ 深入浅出LSTM - 图38 与前一时刻有关的遗忘门

§ 深入浅出LSTM - 图39

§ 深入浅出LSTM - 图40

§ 深入浅出LSTM - 图41 与当前时刻有关的输入门

§ 深入浅出LSTM - 图42

§ 深入浅出LSTM - 图43%0A#card=math&code=i%20%7Bt%7D%20%20%3D%20%20%5Csigma%20%28%20%20W%20%7Bii%7D%20%20%20x%20%7Bt%7D%20%20%2B%20%20b%20%7Bii%7D%20%20%2B%20%20W%20%7Bhi%7D%20%20h%20%7Bt-1%7D%20%20%2B%20%20b_%20%7Bhi%7D%20%20%29%0A&id=cncmY)

上述图中的§ 深入浅出LSTM - 图44就计算为

§ 深入浅出LSTM - 图45%0A#card=math&code=%5Ctilde%7BCt%7D%3Dg%7Bt%7D%20%3D%20%5Ctext%7Btanh%7D%28W%7Big%7Dx%7Bt%7D%2Bb%7Big%7D%2BW%7Bhg%7Dh%7Bt-1%7D%2Bb%7Bhg%7D%29%0A&id=JCztp)

所以计算出当前状态

§ 深入浅出LSTM - 图46

§ 深入浅出LSTM - 图47 与当前时刻有关的输出门

§ 深入浅出LSTM - 图48

§ 深入浅出LSTM - 图49%5C%5C%0Ah%20%7Bt%7D%20%20%26%3D%20%20o%20%7Bt%7D%5Codot%20%5Ctanh%20(%20c%20%7Bt%7D)%0A%5Cend%7Baligned%7D%0A#card=math&code=%5Cbegin%7Baligned%7D%0Ao%20%7Bt%7D%20%26%3D%20%5Csigma%20%28%20W%20%7Bio%7D%20x%20%7Bt%7D%20%20%2B%20%20b%20%7Bio%7D%20%20%2B%20%20W%20%7Bho%7D%20h%20%7Bt-1%7D%20%2B%20%20b%20%7Bho%7D%29%5C%5C%0Ah%20%7Bt%7D%20%20%26%3D%20%20o%20%7Bt%7D%5Codot%20%5Ctanh%20%28%20c_%20%7Bt%7D%29%0A%5Cend%7Baligned%7D%0A&id=nmIuP)

2.2 代码实现

§ 深入浅出LSTM - 图50网络初始化:在使用之前首先要初始化该网络,所以就需要知道初始化该网络的参数,如下所示

  • input_size:输入x的大小
  • hidden_size:隐藏层h的个数
  • num_layers:循环层的数量,如果设置num_layers=2就表示将两个LSTM网络联合起来,其中第2个LSTM网络以第1个LSTM网络的输出作为输入,默认为1
  • bias:是否使用偏置,默认为True
  • batch_first:如果设置为True,输入和输出的形状为(batch_size[1], channels[2], Hin[3]),否则的话就为(channels, batch_size, Hin)。而一般人应该都是用第一种吧,设置为True,因为默认为False
  • dropout:如果设置为非0的话,就在网络中的每一个LSTM层后添加一个dropout层;如果设置为0,只在最后一个LSTM层后添加dropout。默认为0
  • bidirectional:如果设置为True,则就变成一个双向的LSTM,默认为False
  • proj_size:如果大于0,则会用LSTM映射为相关的大小,默认为0.

§ 深入浅出LSTM - 图51网络输入:初始化完成后就需要往网络中添加数据训练,添加的数据如下

  • input:上面初始化的时候设置batch_first为True的话,则按照(batch_size, channels, Hin)的形式输入数据
  • h_0:初始化网络中的隐藏层,如果这个不输入的话网络就会自己初始化一个,大小为(D[4]*num_layers, batch_size, Hout)
  • c_0:最一开始的单元状态c_0,如果不输入网络自己初始化,形状为(D*num_layers, batch_size, Hcell[5])

§ 深入浅出LSTM - 图52 网络输出:整个LSTM的网络输出有3个,分别是

  • output:网络的输出,如果设置batch_first=True的话,则该输出为(batch_size, channels, D*Hout[6])
  • h_n:最后一个隐藏层的输出(D*num_layers, batch_size, Hout)
  • c_n:最后一个单元输出(D*num_layers, batch_size, Hcell)

📎 对上述的标注进行一下解释:

  1. batch_size:输入的批次大小
  2. channels:输入的数据通道大小
  3. Hin:输入的特征长度
  4. D=2 如果LSTM是双向的,否则为1
  5. Hcell:隐藏层大小
  6. Hout:输出数据的大小,如果设置参数proj_size>0的话则输出大小为proj_size;如果参数proj_size=0的话输出大小为Hcell

给一个官网的例子

  1. import torch.nn as nn
  2. import torch
  3. rnn = nn.LSTM(10, 20, 2, batch_first=True)
  4. input = torch.randn(3, 5, 10) # batch_size=3, channels=5, features=10
  5. h0 = torch.randn(2, 3, 20)
  6. c0 = torch.randn(2, 3, 20)
  7. output, (hn, cn) = rnn(input, (h0, c0))