一、LSTM结构
参考链接🔗:pytorch_14_lstm原理 - 知乎 (zhihu.com)
1.1 LSTM的整体结构
一个LSTM的整体结构如下图所示
看着比卷积神经网络复杂了许多,这是因为它里面添加了许多的门,用来调整不同时刻的输入,按照不同时刻的输入来看一下这个结构
1.2 与前一时刻有关的遗忘门
前一时刻信息单元经过遗忘门进行过滤,前一时刻的输出为,当前时刻的输入为,将这两个数据联合作为当前时刻的输入,计算出遗忘门的权值大小
%0A#card=math&code=ft%20%3D%20%5Csigma%28W_f%5Bh%7Bt-1%7D%2C%20x_t%5D%20%2B%20b_f%29%0A&id=yq2lU)
其中的表示的就是一个激活函数,例如ReLU,所以,而1表示前一时刻的输出完全可用,0表示前一时刻的输出无法使用,直接丢弃。
是什么😧?别着急,后面往下走的时候你就会发现答案。
1.3 输入门
输入门和当前时刻有多少信息放入单元状态
输入门:是一个sigmoid层,这个层用来决定我们将要更新的值,这个值是一个权重
%0A#card=math&code=it%20%3D%20%5Csigma%20%28W_i.%20%5Bh%7Bt-1%7D%2C%20x_t%5D%20%2B%20b_i%29%0A&id=dd86z)
一个tanh层创建一个新的候选值的向量,该向量可以添加到状态中
%0A#card=math&code=%5Ctilde%7BCt%7D%20%3D%20%5Ctext%7Btanh%7D%28W_c.%5Bh%7Bt-1%7D%2C%20x_t%5D%20%2B%20b_C%29%0A&id=ZLada)
计算当前单元状态:前一时刻单元经过遗忘门,当前时刻和前一时刻的累加信息经过输入门,两个门过滤后的信息累加为单元状态
其中
- 表示当前单元状态
- 前一时刻单元状态
- 当前输入和前一时刻的累加信息
1.4 输出门
通过1.3计算出了当前单元状态,接下来就要把这个状态给输出。同样地,这个输出又又又要根据和计算权重
通过一个sigmoid层来计算权值,这个权值表示当前单元状态有多少被输出
%0A#card=math&code=ot%20%3D%20%5Csigma%28W_o%5Bh%7Bt-1%7D%2Cx_t%5D%2Bb_o%29%0A&id=IBkug)
接下来就将它们当前单元状态输出,首先将当前的单元状态经过tanh函数从而使得输出值位于[-1, 1],之后与输出门的权重进行相乘
%0A#card=math&code=h_t%20%3D%20o_t%2A%5Ctext%7Btanh%7D%28C_t%29%0A&id=GYVlb)
注意:之前有一个一直没有解释的,而现在我们知道了就是当前单元状态,而就是根据权重将当前状态输出的值
二、PyTorch实现
2.1 温故知新
进入PyTorch官网:LSTM — PyTorch 1.9.1 documentation,进去就能看到下面的公式
头大了吗?很明显没有,之前我们都看过LSTM的处理流程,而这个公式正好和那个流程契合起来,下面就把这个公式拆分一下
与前一时刻有关的遗忘门
与当前时刻有关的输入门
%0A#card=math&code=i%20%7Bt%7D%20%20%3D%20%20%5Csigma%20%28%20%20W%20%7Bii%7D%20%20%20x%20%7Bt%7D%20%20%2B%20%20b%20%7Bii%7D%20%20%2B%20%20W%20%7Bhi%7D%20%20h%20%7Bt-1%7D%20%20%2B%20%20b_%20%7Bhi%7D%20%20%29%0A&id=cncmY)
上述图中的就计算为
%0A#card=math&code=%5Ctilde%7BCt%7D%3Dg%7Bt%7D%20%3D%20%5Ctext%7Btanh%7D%28W%7Big%7Dx%7Bt%7D%2Bb%7Big%7D%2BW%7Bhg%7Dh%7Bt-1%7D%2Bb%7Bhg%7D%29%0A&id=JCztp)
所以计算出当前状态
与当前时刻有关的输出门
%5C%5C%0Ah%20%7Bt%7D%20%20%26%3D%20%20o%20%7Bt%7D%5Codot%20%5Ctanh%20(%20c%20%7Bt%7D)%0A%5Cend%7Baligned%7D%0A#card=math&code=%5Cbegin%7Baligned%7D%0Ao%20%7Bt%7D%20%26%3D%20%5Csigma%20%28%20W%20%7Bio%7D%20x%20%7Bt%7D%20%20%2B%20%20b%20%7Bio%7D%20%20%2B%20%20W%20%7Bho%7D%20h%20%7Bt-1%7D%20%2B%20%20b%20%7Bho%7D%29%5C%5C%0Ah%20%7Bt%7D%20%20%26%3D%20%20o%20%7Bt%7D%5Codot%20%5Ctanh%20%28%20c_%20%7Bt%7D%29%0A%5Cend%7Baligned%7D%0A&id=nmIuP)
2.2 代码实现
网络初始化:在使用之前首先要初始化该网络,所以就需要知道初始化该网络的参数,如下所示
input_size
:输入x的大小hidden_size
:隐藏层h的个数num_layers
:循环层的数量,如果设置num_layers=2
就表示将两个LSTM网络联合起来,其中第2个LSTM网络以第1个LSTM网络的输出作为输入,默认为1bias
:是否使用偏置,默认为True
batch_first
:如果设置为True
,输入和输出的形状为(batch_size[1], channels[2], Hin[3]),否则的话就为(channels, batch_size, Hin)。而一般人应该都是用第一种吧,设置为True,因为默认为Falsedropout
:如果设置为非0的话,就在网络中的每一个LSTM层后添加一个dropout层;如果设置为0,只在最后一个LSTM层后添加dropout。默认为0bidirectional
:如果设置为True
,则就变成一个双向的LSTM,默认为False
。proj_size
:如果大于0,则会用LSTM映射为相关的大小,默认为0.
网络输入:初始化完成后就需要往网络中添加数据训练,添加的数据如下
input
:上面初始化的时候设置batch_first
为True的话,则按照(batch_size, channels, Hin)的形式输入数据h_0
:初始化网络中的隐藏层,如果这个不输入的话网络就会自己初始化一个,大小为(D[4]*num_layers, batch_size, Hout)c_0
:最一开始的单元状态c_0
,如果不输入网络自己初始化,形状为(D*num_layers, batch_size, Hcell[5])
网络输出:整个LSTM的网络输出有3个,分别是
output
:网络的输出,如果设置batch_first=True
的话,则该输出为(batch_size, channels, D*Hout[6])h_n
:最后一个隐藏层的输出(D*num_layers, batch_size, Hout)c_n
:最后一个单元输出(D*num_layers, batch_size, Hcell)
📎 对上述的标注进行一下解释:
- batch_size:输入的批次大小
- channels:输入的数据通道大小
- Hin:输入的特征长度
- D=2 如果LSTM是双向的,否则为1
- Hcell:隐藏层大小
- Hout:输出数据的大小,如果设置参数
proj_size>0
的话则输出大小为proj_size
;如果参数proj_size=0
的话输出大小为Hcell
给一个官网的例子
import torch.nn as nn
import torch
rnn = nn.LSTM(10, 20, 2, batch_first=True)
input = torch.randn(3, 5, 10) # batch_size=3, channels=5, features=10
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))