Network definition
rnn = nn.LSTM(10, 20, 2)
input feature size, hidden size, number of LSTM layers (defaults to 1)
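For clarity, the same definition can be written with keyword arguments (a minimal sketch assuming PyTorch's torch.nn; it builds exactly the network above):
import torch
import torch.nn as nn

# Equivalent definition with the positional arguments named explicitly.
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)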
Inputs: input, (h_0, c_0)
input = torch.randn(5, 3, 10)
(seq_len, batch, input_size): sequence length, batch size, input feature size
h0 = torch.randn(2, 3, 20)
(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)
(num_layers * num_directions, batch, hidden_size)
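A side note on the initial states (a short sketch; per PyTorch's nn.LSTM behavior, passing (h_0, c_0) is optional):
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
# If (h_0, c_0) is not supplied, PyTorch initializes both states to zeros.
output, (hn, cn) = rnn(input)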
Outputs: output, (h_n, c_n)
output, (hn, cn) = rnn(input, (h0, c0))
output of shape (seq_len, batch, num_directions * hidden_size)
For a sequence of length seq_len, the LSTM emits, for every token, the hidden-layer features of the last layer (size = num_directions * hidden_size); see the shape check after these output descriptions.
h_n of shape (num_layers * num_directions, batch, hidden_size)
After the whole sequence has passed through the LSTM, h_n is the state it carries forward to later steps (the hidden state).
GraphSAGE uses exactly this state to compute the influence of a neighbor sequence on the center node (a sketch of this appears after the Example below).
c_n of shape (num_layers * num_directions, batch, hidden_size)
After the whole sequence has passed through the LSTM, c_n is the state it carries forward to later steps (the cell state).
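To tie the shapes above together, a minimal shape-check sketch (same sizes as in this section; the last relation holds because this LSTM is unidirectional):
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)              # num_directions = 1 (unidirectional)
input = torch.randn(5, 3, 10)         # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)            # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)

output, (hn, cn) = rnn(input, (h0, c0))

assert output.shape == (5, 3, 20)     # (seq_len, batch, num_directions * hidden_size)
assert hn.shape == (2, 3, 20)         # (num_layers * num_directions, batch, hidden_size)
assert cn.shape == (2, 3, 20)
# The last time step of output is the final hidden state of the last layer.
assert torch.allclose(output[-1], hn[-1])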
Example
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
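Finally, a hedged sketch of the GraphSAGE use mentioned above: an LSTM-style neighbor aggregator that feeds a randomly permuted neighbor sequence through an LSTM and keeps h_n of the last layer as the neighborhood summary. The class name LSTMAggregator, the feature sizes, and the permutation step are illustrative assumptions, not the reference GraphSAGE code.
import torch
import torch.nn as nn

class LSTMAggregator(nn.Module):
    # Hypothetical GraphSAGE-style aggregator: summarize neighbor features
    # with an LSTM and keep the final hidden state h_n of the last layer.
    def __init__(self, feat_dim=10, hidden_dim=20, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(feat_dim, hidden_dim, num_layers)

    def forward(self, neighbor_feats):
        # neighbor_feats: (num_neighbors, batch, feat_dim)
        # The LSTM aggregator is order-sensitive, so the neighbors are fed
        # in a random order; permute along the sequence axis.
        perm = torch.randperm(neighbor_feats.size(0))
        _, (h_n, _) = self.lstm(neighbor_feats[perm])
        return h_n[-1]                 # (batch, hidden_dim) neighborhood summary

# Usage with made-up sizes: 4 neighbors, batch of 3 center nodes, 10-d features.
agg = LSTMAggregator()
summary = agg(torch.randn(4, 3, 10))   # -> shape (3, 20)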