CS224W:图机器学习4

GNN图神经网络

前面几节内容主要介绍了节点嵌入的概念,也就是说我们可以将一个图中的节点映射到一个的d维向量上,而这种映射方式使得相似的节点对应的向量更接近,但主要的问题还是,我们如何学习出一个映射函数f,而图嵌入的两个核心组件是编码器和相似度函数,之前也介绍了比较naive的shallow编码器,而图神经网络GNN提供了一种基于深度学习的节点嵌入方法。

CS224W:图机器学习04 - 图1

  • 相比于图片和文本,网络结构要更复杂并且缺少规律性,而且往往是动态的

深度学习基础

  • 介绍了监督学习、损失函数、机器学习中的优化、梯度向量等基本概念,不过这些内容已经反复接触过了,因此也就不再专门整理记录
  • 随机梯度下降:
    • 一般的梯度下降需要求出整个损失函数的梯度,并结合设置好的学习率进行参数的更新,但是这样做的计算复杂度太高了
    • 随机梯度下降就是每一次选择一个小批量(minibatch)的数据集B对其求梯度并进行参数的更新
  • 深度学习中的函数往往更为复杂,并且在反向传播进行参数更新的过程中需要用链式法则来求梯度,但总的来讲,不管是传统的监督机器学习还是深度学习,我们的优化目标都是:

CS224W:图机器学习04 - 图2)%0A#card=math&code=%5Cmin%20%5Cmathcal%20L%28y%2Cf%28x%29%29%0A&id=uCECy)

图深度学习

问题的定义

  • 我们把要研究的图记为G,图中所有节点构成集合V,A是图的接邻矩阵
  • 用X表示图节点的特征矩阵,每个节点具有m维的特征,节点的特征可以根据图的实际情况来选取
  • 一种很naive的方法是将图G的接邻矩阵和特征矩阵进行合并,然后放到一个深度神经网络中进行学习,但是这样一来就会有CS224W:图机器学习04 - 图3#card=math&code=O%28%7CV%7C%29&id=Rt7E0)数量级的参数,并且对于不同大小的图需要重新设计网络结构,这样的处理方式也使得结果对节点的顺序非常敏感

CS224W:图机器学习04 - 图4

图卷积神经网络的结构

联想到CNN通过卷积的方式来提取和融合邻近像素点的特征,图结构中也可以聚合节点的邻近节点的特征,但是相比于图像,图结构往往不具有一个固定的子结构或者滑动窗口可以用来定义图中的卷积,并且是permutationinvariant的,我们可以使用局部的邻近节点的特征来生成图节点的嵌入向量

  • 每个节点生成了一张计算图,来表示其特征融合的过程

CS224W:图机器学习04 - 图5

  • 这种神经网络可以是任意深度,每一层有节点的嵌入:
    • 第0层只用输入的特征表示节点的嵌入向量
    • 第k层使用从k个away之前得到的信息来进行节点嵌入

CS224W:图机器学习04 - 图6

  • 上述结构只是一种最基本的框架,最主要的区别在于图中的box对应的内容,也就是信息聚合的方式和之后的一系列处理
    • 常见的方法:对得到的信息求平均并放到神经网络中

CS224W:图机器学习04 - 图7%7D%3D%5Csigma%20(Wl%5Csum%7Bu%5Cin%20N(v)%7D%5Cfrac%7Bhu%5E%7B(l)%7D%7D%7B%7CN(v)%7C%7D%2BB_lh_v%5E%7B(l)%7D)%2C%5Cquad%20h_v%5E0%3DX_v%0A#card=math&code=h_v%5E%7B%28l%2B1%29%7D%3D%5Csigma%20%28W_l%5Csum%7Bu%5Cin%20N%28v%29%7D%5Cfrac%7Bh_u%5E%7B%28l%29%7D%7D%7B%7CN%28v%29%7C%7D%2BB_lh_v%5E%7B%28l%29%7D%29%2C%5Cquad%20h_v%5E0%3DX_v%0A&id=DYYNp)

  • 这里的CS224W:图机器学习04 - 图8是一个非线性的激活函数
  • 在得到了嵌入向量之后可以使用损失函数和随机梯度下降的方式来训练权重参数W和B
    • 我们可以将上面的表达式进行向量化来提高计算的效率,用对角矩阵D来表示CS224W:图机器学习04 - 图9%7C#card=math&code=D%7Bv%2Cv%7D%3D%7CN%28v%29%7C&id=mcV36),这样一来就有![](https://g.yuque.com/gr/latex?D%5E%7B-1%7D%7Bv%2Cv%7D%3D%5Cfrac%201%20%7BN(v)%7D#card=math&code=D%5E%7B-1%7D%7Bv%2Cv%7D%3D%5Cfrac%201%20%7BN%28v%29%7D&id=tbJPk)并且令CS224W:图机器学习04 - 图10%7D%3D%5Bh_1%5E%7B(l)%7D%2C%5Cdots%2Ch%7B%7CV%7C%7D%5E%7B(l)%7D%5D%5ET#card=math&code=H%5E%7B%28l%29%7D%3D%5Bh1%5E%7B%28l%29%7D%2C%5Cdots%2Ch%7B%7CV%7C%7D%5E%7B%28l%29%7D%5D%5ET&id=Qz580),这样一来上面的式子就可以表示为:CS224W:图机器学习04 - 图11%7D%3D%5Csigma(D%5E%7B-1%7DA%20H%5E%7B(l)%7DW_l%5ET%2BH%5E%7B(l)%7DB_l%5ET)%0A#card=math&code=H%5E%7B%28l%2B1%29%7D%3D%5Csigma%28D%5E%7B-1%7DA%20H%5E%7B%28l%29%7DW_l%5ET%2BH%5E%7B%28l%29%7DB_l%5ET%29%0A&id=CahwN)
  • selftranslation+neighborhoodaggregation两个部分组成

如何训练GNN

  • 监督学习的训练方式:定义一个loss函数并进行优化,相似的节点会有相近的嵌入向量,因此我们可以定义:CS224W:图机器学习04 - 图12)%0A#card=math&code=%5Cmathcal%20L%3D%5Csum%7Bz_u%2Cz_v%7D%5Cmathrm%7BCE%7D%28y%7Bu%2Cv%7D%2C%5Cmathrm%20%7BDEC%7D%28z_u%2Cz_v%29%29%0A&id=BcVMg)
    • CE表示交叉熵,y是表示两个节点相似的标签,DEC是解码器,比如可以使用内积来定义
    • 相似度的定义可以用前面提到的任何方法,包括随机游走,矩阵分解等等
  • 无监督学习的训练方式:没有节点的标签,可以使用图的结构作为监督
  • 对于一个节点分类的问题,可以用监督学习的方式直接训练模型,使用交叉熵loss表示如下:CS224W:图机器学习04 - 图13)%2B(1-yv)%5Clog%20(1-%5Csigma(z_v%5ET%5Ctheta))%0A#card=math&code=%5Cmathcal%20L%3D%5Csum%7Bv%5Cin%20V%7Dy_v%5Clog%20%28%5Csigma%28z_v%5ET%5Ctheta%29%29%2B%281-y_v%29%5Clog%20%281-%5Csigma%28z_v%5ET%5Ctheta%29%29%0A&id=muB4c)
    • 这里的y表示节点的标签,而z表示嵌入向量,CS224W:图机器学习04 - 图14表示分类的权重
  • 但是最终,我们学习到的一系列参数和模型是需要在新的图和新的节点上测试效果的,因此需要使模型拥有更好的泛化能力

图神经网络模型的Pipeline

  • 定义一个邻居节点聚合的方式
  • 定义一个基于嵌入向量的损失函数
  • 使用一系列节点进行训练
  • 生成节点的嵌入,用于实际的任务中

GNN概览

GNN由一系列GNN层线性组合构成,而GNN层包含了message和aggregation等多层次的信息,当我们训练一个GNN的时候可以从监督学习的视角出发进行训练,也可以从非监督学习的视角出发进行训练

CS224W:图机器学习04 - 图15

GNN的层结构

GNN单独的层结构将多个向量组合成一个单独的向量,需要经过两个步骤的处理,分别是message和aggregation,将输入的节点嵌入向量转化成输出的节点嵌入向量

  • message步骤:每个节点都会创建一个消息发送给附近的一系列节点,这样一来就有CS224W:图机器学习04 - 图16%7D%3D%5Cmathrm%7BMSG%7D%5E%7B(l)%7D(h_u%5E%7B(l-1)%7D)%0A#card=math&code=m_u%5E%7B%28l%29%7D%3D%5Cmathrm%7BMSG%7D%5E%7B%28l%29%7D%28h_u%5E%7B%28l-1%29%7D%29%0A&id=ZTPZ6)
  • aggregation步骤:每个节点会收集从其他节点发来的消息并进行一定的处理,比如求和,求均值或者最大值等等CS224W:图机器学习04 - 图17%7D%3D%5Cmathrm%7BACG%7D%5E%7B(l)%7D(m_u%5E%7B(l)%7D%2Cu%5Cin%20N(v))%0A#card=math&code=h_v%5E%7B%28l%29%7D%3D%5Cmathrm%7BACG%7D%5E%7B%28l%29%7D%28m_u%5E%7B%28l%29%7D%2Cu%5Cin%20N%28v%29%29%0A&id=GwNWj)
  • 但是只有aggregation的话每一层的节点输出的嵌入向量都不包含其自身从上一层携带回来的特征,因此可以将message步骤和aggregation步骤得到的结果进行合并作为最终的输出结果
    • 并且可以在两个过程中使用一些非线性的激活函数

CS224W:图机器学习04 - 图18%7D%3D%5Cmathrm%7BCONCAT%7D(%5Cmathrm%20%7BACG%7D%5E%7B(l)%7D(m_u%5E%7B(l)%7D%2Cu%5Cin%20N(v))%2C%20m_v%5E%7B(l)%7D)%0A#card=math&code=h_v%5E%7B%28l%29%7D%3D%5Cmathrm%7BCONCAT%7D%28%5Cmathrm%20%7BACG%7D%5E%7B%28l%29%7D%28m_u%5E%7B%28l%29%7D%2Cu%5Cin%20N%28v%29%29%2C%20m_v%5E%7B%28l%29%7D%29%0A&id=T2urI)

GCN

对于图卷积网络GCN的层结构,上面已经介绍过了它每一层的嵌入向量更新方式

CS224W:图机器学习04 - 图19%7D%3D%5Csigma%20(%5Csum%7Bu%5Cin%20N(v)%7DW_l%5Cfrac%7Bh_u%5E%7B(l)%7D%7D%7B%7CN(v)%7C%7D)%0A#card=math&code=h_v%5E%7B%28l%2B1%29%7D%3D%5Csigma%20%28%5Csum%7Bu%5Cin%20N%28v%29%7DW_l%5Cfrac%7Bh_u%5E%7B%28l%29%7D%7D%7B%7CN%28v%29%7C%7D%29%0A&id=pyR1m)

  • 这个公式里面其实分成message和aggregation两个部分,其中求和就是aggregation,而CS224W:图机器学习04 - 图20%7D%7D%7B%7CN(v)%7C%7D#card=math&code=W_l%5Cfrac%7Bh_u%5E%7B%28l%29%7D%7D%7B%7CN%28v%29%7C%7D&id=wEga1)就是每个节点产生并广播到所有邻居节点的message

GraphSAGE

GraphSAGE是另一种架构的图神经网络,其核心思想是使用一个可微分的函数将邻居节点传递过来的信息应设成一个单独的vector每一层的更新方式是:

CS224W:图机器学习04 - 图21%7D%3D%5Csigma%20(W_l%5Ctimes%5Cmathrm%7BCONCAT%7D(B_lh_V%5E%7B(l-1)%7D%2C%20%5Cmathrm%7BACG%7D(h_u%5E%7B(l-1)%7D%2Cu%20%5Cin%20N(v))))%0A#card=math&code=h_v%5E%7B%28l%29%7D%3D%5Csigma%20%28W_l%5Ctimes%5Cmathrm%7BCONCAT%7D%28B_lh_V%5E%7B%28l-1%29%7D%2C%20%5Cmathrm%7BACG%7D%28h_u%5E%7B%28l-1%29%7D%2Cu%20%5Cin%20N%28v%29%29%29%29%0A&id=Ihqys)

  • 这种架构的Message是每个邻居节点发过来的嵌入向量,通过ACG函数进行aggregation,之后再和当前节点本身的
  • 其中聚合函数ACG有多种不同的选择方式,比如使用均值函数,或者进行池化(最大池化,最小池化等等),也可以使用LSTM来进行邻近节点的reshuffled
  • 可以在每一层都可以进行一个l2标准化,即CS224W:图机器学习04 - 图22%7D%3D%5Cfrac%20%7Bh_v%5E%7B(l)%7D%7D%7B%7C%7Ch_v%5E%7B(l)%7D%7C%7C_2%7D#card=math&code=h_v%5E%7B%28l%29%7D%3D%5Cfrac%20%7Bh_v%5E%7B%28l%29%7D%7D%7B%7C%7Ch_v%5E%7B%28l%29%7D%7C%7C_2%7D&id=VX7kG),使用了标准化操作之后就会使得所有嵌入向量的L2范数统一,可以给性能带来比较大的提升。

GAT与注意力机制

GAT是GraphAttentionNetwork,即在图神经网络中加入了注意力机制,我们可以使用CS224W:图机器学习04 - 图23来表示节点v的邻居节点u的重要程度,这样的网络可以使网络在计算时将注意力集中到一些比较重要的节点上面去。

CS224W:图机器学习04 - 图24%7D%3D%5Csigma(%5Csum%7Bu%5Cin%20N(v)%7D%5Calpha%7Bvu%7DW%5E%7B(l)%7Dhu%5E%7B(l-1)%7D)%0A#card=math&code=h_v%5E%7B%28l%29%7D%3D%5Csigma%28%5Csum%7Bu%5Cin%20N%28v%29%7D%5Calpha_%7Bvu%7DW%5E%7B%28l%29%7Dh_u%5E%7B%28l-1%29%7D%29%0A&id=Ggxz0)

  • 一种CS224W:图机器学习04 - 图25的计算方式:
    • 首先定义一个计算重要性系数的函数a并计算两个相邻节点的重要程度CS224W:图机器学习04 - 图26%7Dhu%5E%7B(l-1)%7D%2CW%5E%7B(l)%7Dh_v%5E%7B(l-1)%7D)%0A#card=math&code=e%7Bvu%7D%3Da%28W%5E%7B%28l%29%7Dh_u%5E%7B%28l-1%29%7D%2CW%5E%7B%28l%29%7Dh_v%5E%7B%28l-1%29%7D%29%0A&id=iLh67)CS224W:图机器学习04 - 图27
    • 然后使用softmax来计算出权重CS224W:图机器学习04 - 图28CS224W:图机器学习04 - 图29%3D%5Cfrac%20%7B%5Cexp%20(e%7Bvu%7D)%7D%7B%5Csum%7Bu%5Cin%20N(v)%7D%20%5Cexp(%7Be%0A%7Bvu%7D%7D)%7D%0A#card=math&code=%5Calpha%7Bvu%7D%3D%5Cmathrm%20%7Bsoftmax%7D%28e%7Bvu%7D%29%3D%5Cfrac%20%7B%5Cexp%20%28e%7Bvu%7D%29%7D%7B%5Csum%7Bu%5Cin%20N%28v%29%7D%20%5Cexp%28%7Be%0A%7Bvu%7D%7D%29%7D%0A&id=dPq3s)CS224W:图机器学习04 - 图30
    • 这里的a函数往往是先进行嵌入向量的连接再通过一个线性函数进行运算然后输出结果
  • 多端注意力机制:创建多个注意力分数并进行连接

CS224W:图机器学习04 - 图31%7D%5Bi%5D%3D%5Csigma(%5Csum%7Bu%5Cin%20N(v)%7D%5Calpha%7Bvu%7D%5EiW%5E%7B(l)%7Dhu%5E%7B(l-1)%7D)%0A#card=math&code=h_v%5E%7B%28l%29%7D%5Bi%5D%3D%5Csigma%28%5Csum%7Bu%5Cin%20N%28v%29%7D%5Calpha_%7Bvu%7D%5EiW%5E%7B%28l%29%7Dh_u%5E%7B%28l-1%29%7D%29%0A&id=h2muB)

  • 注意力机制的优点:
    • 最主要的优点是允许对不同的邻居节点采取不同的权重来突出一些重要的邻居
    • 并且注意力权重的计算可以并行,提高了计算的效率
    • 提高了存储的效率

通用的GNN层设计

CS224W:图机器学习04 - 图32

  • 批标准化:可以是的神经网络的训练更加稳定,做法就是求出每一层输出结果的平均值和方差,然后进行标准化处理,可以防止过大或者过小的数据出现而导致梯度消失或者梯度爆炸等问题
  • Dropout:用来避免神经网络过你和,在训练的时候,以一定概率将一些神经元设置成0(即中途退出了计算),在测试的时候使用所有的神经元参与到计算中
    • 在GNN中一般在线性层使用dropout
  • 激活:使用一些非线性的激活函数,将特征非线性化,常见的非线性激活函数有sigmoid,ReLU等等

GNN的层级架构

上面主要介绍的是GNN单层的设计方式和基本的idea以及一些常用的神经网络tricks,但是在设计好了若干层神经网络之后,还需要用一定的规则将其组合起来,在GNN中往往就是将多个层级进行线性的组合,有的时候也会用的skipconnection机制。

  • 过平滑问题:是指所有的节点嵌入向量最终收敛到了同一个值上,这是一个非常糟糕的现象,因为我们希望不同的节点的嵌入向量是不同的,否则无法区别出不同的节点
  • 接受域ReceptiveField:是决定一个节点的嵌入向量的节点集合,在一个K层的GNN中,每个节点的接受域包含了相距K-hop的邻居
    • 可以用接受域来解释过平滑问题出现的原因

CS224W:图机器学习04 - 图33

  • 随着GNN的层数增加,每个节点的接受域也在不断扩大,这就会导致两个节点的接受域的重合度越来越高,而接受域可以决定一个节点的嵌入向量,因此随着层数增加嵌入向量也会越来越相似,最终就会导致过平滑问题的出现
  • 如何解决过平滑的问题?
    • 在添加GNN层的时候慎重考虑,不像CNN之类的网络结构,GNN有的时候添加太多层可能不会起作用
    • 因此需要先分析解决问题必要的接受域的大小,然后再将GNN层数设定为一个稍大于所需接受域的值,不能设置的太大,否则就会出现上述过平滑的问题
    • 可以减少GNN的层数(即使用shallowGNN),但是在每一层的box中增加多个线性层,也可以增加一些不传递消息的层,比如预处理层和后处理层
      • 改变GNN层的连接方式:增加skipconnection
  • 有的时候位于更底层的GNN层产生的嵌入向量更能区分不同的节点,我们可以方法底层在最后的嵌入向量中的作用效果,方法就是在GNN中添加一些shortcuts,也叫做skipconnection

CS224W:图机器学习04 - 图34

  • 为什么skipconnection可以work?因为skip机制创建了一个混合模型,通过skip,比较底层的一些计算结果直接跳过了中间一些层的计算直接作用到了更高的层次中,这样一来就使得深的GNN和浅的GNN的计算结果进行了混合
  • 有skipconnection机制的GCN的更新方式是:

CS224W:图机器学习04 - 图35%7D%3D%5Csigma%20(Wl%5Csum%7Bu%5Cin%20N(v)%7D%5Cfrac%7Bhu%5E%7B(l)%7D%7D%7B%7CN(v)%7C%7D)%5Crightarrow%20h_v%5E%7B(l%2B1)%7D%3D%5Csigma%20(W_l%5Csum%7Bu%5Cin%20N(v)%7D%5Cfrac%7Bhu%5E%7B(l)%7D%7D%7B%7CN(v)%7C%7D%2Bh_v%5E%7B(l)%7D)%0A#card=math&code=h_v%5E%7B%28l%2B1%29%7D%3D%5Csigma%20%28W_l%5Csum%7Bu%5Cin%20N%28v%29%7D%5Cfrac%7Bhu%5E%7B%28l%29%7D%7D%7B%7CN%28v%29%7C%7D%29%5Crightarrow%20h_v%5E%7B%28l%2B1%29%7D%3D%5Csigma%20%28W_l%5Csum%7Bu%5Cin%20N%28v%29%7D%5Cfrac%7Bh_u%5E%7B%28l%29%7D%7D%7B%7CN%28v%29%7C%7D%2Bh_v%5E%7B%28l%29%7D%29%0A&id=BlFkW)

GNN的图处理

  • 输入的图不一定就是最优的计算图,因为输入的图可能缺少特征,太稀疏或者太稠密,也有可能因为规模太大而无法在GPU上进行计算,因此可以对输入的图进行一定的处理,比如
    • 图中的特征太少的时候可以增加特征,比如:
      • 给图中的节点增加唯一ID标识,使用one-hot向量来表示特征
    • 图太稀疏可以增加虚拟的节点或者边
    • 图太稠密的时候可以在消息传递的过程中进行邻居采样
    • 图的规模太大的时候可以分割成若干个子图分别进行嵌入

CS224W:图机器学习04 - 图36