动机
在实践中,除了训练后可能出现新节点的可能性之外,节点也可能频繁出现和消失。
现有的方法使用GNN作为特征提取器,使用RNN从提取的节点特征中学习动态。但是这些方法需要节点在整个时间跨度内的信息(包括训练集和测试集),这对节点集频繁变化的情况不太适用。
本文使用RNN来演化GNN参数,以便在演化的网络参数中捕获动态。特点是优点可以处理更灵活的动态数据,因为节点不需要一直存在。
核心思想
注意随着时间的变化,图的结构完全改变
- 每个离散时刻的图,使用一个GCN模型进行训练,得到各自GCN的权重。
- 由于图是有时序变化关系的,那么对应的每个GCN模型的权重,也是有关的。
如果把各个时刻的GCN中,相同层的参数当成一个序列,那么就可以用RNN来进行学习。因为RNN可以记录历史信息,是进行时间序列分析时最好的选择。
- 这种方式只关注模型本身,而不关心节点,因此,节点的改变不构成限制。
可以按上述过程理解,但实际过程并非如此,后面会详细介绍
- 方法的核心就是GCN模型中的权重如何去学习和演化,
具体来说就是每个时刻t的GCN模型的第l层参数W**如何进行更新。
**
具体方法
- t:时间index
- l:GCN层index
- n:所有图的结点数量(为了避免符号混乱,我们假设所有的图都有 n 个节点;尽管我们反复强调节点可能会随着时间而变化。)
- :表示时刻 t 的输入数据,A**t** 是图的加权邻接矩阵,X**t** 是输入节点的特征矩阵,即每个节点的特征是 d 维特征向量(静态场景图的输入格式?)
- W**:在时刻 t 的GCN的第 l 层参数
GCN部分
- 在t时刻,GCN中的第 l 层将邻接矩阵 A__ 和节点特征矩阵 H** 作为输入,使用GCN的权重矩阵 W**将节点特征矩阵更新为 H**并将其作为输出(H** = Xt**)
公式这一块就是
- 除了输出层,中间每一层还过一下ReLU
- 假设GCN有 L 层,那么 H 就包含了图节点的高级表示,或者是图节点分类的softmax
GCN Weight Evolution部分
- 论文方法的核心就是基于当前以及历史信息在时间t更新GCN的权重矩阵 W__**。
这个需求可以通过循环结构来实现,有两种选择。
**
- 第一种选择(版本H)
RNN使用gated recurrent unit (GRU),W**既是GCN的权重参数,也是GRU的隐藏层权重。
**
- 第二种选择(版本O)
RNN使用LSTM,虽然GCN部分必然需要结点特征矩阵 H 作为输入,但是在 W**的更新部分不需要 H作为输入。
- 两个版本的图示
- 蓝色表示输出,
- 红色表示输入,
- 两个等式的右侧表示evolving graph convolution unit(EGCU)
- 从evolving graph convolution unit的伪码对比两个版本的不同
- 伪码实际上相比于图片来说,更好地反应了模型的整个过程。
EvolveGCN-O实际运行过程
结合代码和过程图体现EvolveGCN-O实际的运行过程
- 运行过程图示
class EvolveGCNO(torch.nn.Module): …
def forward(self, X: torch.FloatTensor, edge_index: torch.LongTensor,
edge_weight: torch.FloatTensor=None) -> torch.FloatTensor:
"""
Making a forward pass.
Arg types:
* **X** *(PyTorch Float Tensor)* - Node embedding.
* **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
* **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
Return types:
* **X** *(PyTorch Float Tensor)* - Output matrix for all nodes.
"""
W = self.conv_layer.weight[None, :, :] # 取出上一时刻GCN的参数Wt-1
W, _ = self.recurrent_layer(W) # 得到当前时刻Wt:Wt = LSTM(Wt-1)
self.conv_layer.weight = torch.nn.Parameter(W.squeeze()) # 存储当前时刻的Wt
# 送入GCN:Ht(l+1) = GCONV(Ht(l), At, Wt(l))
X = self.conv_layer(X, edge_index, edge_weight)
return X
```