Network definition

rnn = nn.LSTM(10, 20, 2)
Arguments: input feature size, hidden size, number of LSTM layers (default 1).
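For clarity, the same constructor written with keyword arguments (a minimal sketch; batch_first and bidirectional are shown at their default values):

import torch.nn as nn

rnn = nn.LSTM(
    input_size=10,        # size of each input feature vector
    hidden_size=20,       # size of the hidden state h
    num_layers=2,         # number of stacked LSTM layers, default 1
    batch_first=False,    # default: tensors are laid out as (seq_len, batch, feature)
    bidirectional=False,  # default: num_directions = 1
)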

Inputs: input, (h_0, c_0)

input = torch.randn(5, 3, 10)
(seq_len, batch, input_size): sequence length, batch size, input feature size

h0 = torch.randn(2, 3, 20)
(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)
(num_layers * num_directions, batch, hidden_size)
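As a sanity check (assuming the unidirectional, 2-layer rnn defined above), the initial-state shapes can be derived from the module's own attributes rather than hard-coded:

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)
num_directions = 2 if rnn.bidirectional else 1   # 1 for this unidirectional LSTM
batch = 3

h0 = torch.randn(rnn.num_layers * num_directions, batch, rnn.hidden_size)  # (2, 3, 20)
c0 = torch.randn(rnn.num_layers * num_directions, batch, rnn.hidden_size)  # (2, 3, 20)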

Outputs: output, (h_n, c_n)

output, (hn, cn) = rnn(input, (h0, c0))

output of shape (seq_len, batch, num_directions * hidden_size)
For a sequence of length seq_len, the network produces, for every token, its corresponding hidden-layer feature (of size num_directions * hidden_size).
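A small check that ties output to the hidden state (assuming the unidirectional setup above): the last time step of output equals the top layer's final hidden state h_n[-1].

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)               # initial states default to zeros

print(output.shape)                          # torch.Size([5, 3, 20]) = (seq_len, batch, num_directions * hidden_size)
print(torch.allclose(output[-1], hn[-1]))    # True for a unidirectional LSTM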

h_n of shape (num_layers * num_directions, batch, hidden_size)
The state left by the sequence after it has passed through the LSTM, i.e. its influence on what follows (the hidden state).
GraphSAGE uses exactly this to compute the influence of a neighbor sequence on the center node.
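A hypothetical GraphSAGE-style aggregator (the class name and interface are illustrative, not an official API): neighbor features are fed through the LSTM as a sequence, and the top layer's h_n summarizes the neighborhood for the center node.

import torch
import torch.nn as nn

class LSTMAggregator(nn.Module):
    # Hypothetical sketch: aggregate a neighbor sequence into one vector via h_n.
    def __init__(self, in_feats, hidden_feats):
        super().__init__()
        self.lstm = nn.LSTM(in_feats, hidden_feats)

    def forward(self, neighbor_feats):
        # neighbor_feats: (num_neighbors, batch, in_feats), neighbors treated as a sequence
        _, (h_n, _) = self.lstm(neighbor_feats)
        return h_n[-1]                       # (batch, hidden_feats): aggregated neighborhood vector

agg = LSTMAggregator(10, 20)
neighbors = torch.randn(8, 3, 10)            # 8 neighbors for each of 3 center nodes
print(agg(neighbors).shape)                  # torch.Size([3, 20])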

c_n of shape (num_layers * num_directions, batch, hidden_size)
The state left by the sequence after it has passed through the LSTM, i.e. its influence on what follows (the cell state).
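To make "influence on what follows" concrete, (h_n, c_n) can be passed as the initial state when processing the next chunk of the same sequence (a minimal sketch of stateful, chunked processing):

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)
chunk1 = torch.randn(5, 3, 10)
chunk2 = torch.randn(5, 3, 10)

out1, (hn, cn) = rnn(chunk1)                 # process the first chunk from a zero state
out2, (hn, cn) = rnn(chunk2, (hn, cn))       # the carried state conditions the second chunk

# Same result as processing the concatenated sequence in one call:
out_full, _ = rnn(torch.cat([chunk1, chunk2], dim=0))
print(torch.allclose(torch.cat([out1, out2], dim=0), out_full, atol=1e-6))  # True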

Example

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
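Continuing from the example above, the resulting shapes match the descriptions given earlier:

print(output.shape)   # torch.Size([5, 3, 20])  -> (seq_len, batch, num_directions * hidden_size)
print(hn.shape)       # torch.Size([2, 3, 20])  -> (num_layers * num_directions, batch, hidden_size)
print(cn.shape)       # torch.Size([2, 3, 20])  -> (num_layers * num_directions, batch, hidden_size)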