消息传递的流程图
我们知道,在图网络中,根据消息传递的过程,其流程图如下所示:
下面主要看一下在这个消息传递机制中的aggrate()
聚合函数,它是怎么样对message()函数归一化后的函数进行聚合的,这一部分主要是通过打断点调试PyT中的MessgePassing这个类和scatter这个类得到的结论。
scatter的聚合类型
官方函数
通过不同的参数选择不同的聚合方式,接下来拿出其中的一种聚合方式出来看一下它是怎么工作的:
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
index = broadcast(index, src, dim)
if out is None:
size = list(src.size())
if dim_size is not None:
size[dim] = dim_size
elif index.numel() == 0:
size[dim] = 0
else:
size[dim] = int(index.max()) + 1
out = torch.zeros(size, dtype=src.dtype, device=src.device)
return out.scatter_add_(dim, index, src)
else:
return out.scatter_add_(dim, index, src)
参考博客一:从坐标角度来看
在开始scatter_sum
之前,我们先来看一下scatter
函数是用来做什么的
这部分主要是参考这篇博客的工作,来看一下其数据的变化过程:
scatter简单来说就是通过一个张量src来修改另一个张量,哪个元素需要修改、用src中的哪个元素来修改是由dim和index来决定的,官方给出了3维张量的具体操作说明,如下所示:
self[index[i][j][k]] [j][k] = src[i][j][k] # if dim==0
self[i] index[[i][j][k]] [k] = src[i][j][k] # if dim==1
self[i][j] index[[i][j][k]] = src[i][j][k] # if dim==2
拿一个例子来看一下:
x = torch.rand(2, 5)
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
上面的torch.zeros(3, 5)就是我们需要修改的张量,x就是src,我们通过index来索引src,然后再将索引到的值还给我们需要修改的张量torch.zeros(3, 5)。
对应上面给出的索引,我们可以得到对2维张量来说,我们只需要索引行列即可。
首先,初始数据如下所示:
下面,针对目标数据进行填充,索引形式,而
可以看到,我们的目标数据中,只有10个被填充,剩下的仍然是原来的0。
参考博客二:从行/列索引,整个维度来看
不同于上面的参考博客一,这篇博客是根据dim=0
还是dim=1
来设置的行发散还是列发散:
情况1:dim=1
在dim=1
的情况下,使用的是行发散,如下情况所示:
情况2:dim=0
scatter_add
如果,我们出现了目标数据列数 < 原始数据列数的情况,此时可以scatter()
提供的不同聚合类型,下面使用scatter_add
来解释这种情况:
扩充到三维张量
我们上面的二维的现在理解清楚后,来看一下三维的情况
import torch
from torch_scatter.utils import broadcast
def scatter_sum(src=None, index=None, dim=1, out=None, dim_size=None):
index = broadcast(index, src, dim)
if out is None:
size = list(src.size())
if dim_size is not None:
size[dim] = dim_size
elif index.numel() == 0:
size[dim] = 0
else:
size[dim] = int(index.max()) + 1
out = torch.zeros(size, dtype=src.dtype, device=src.device)
return out.scatter_add_(dim, index, src)
else:
return out.scatter_add_(dim, index, src)
src=torch.randn(1,3,2)
print('输入数据的大小{0}'.format(src.size()))
print(src)
index = torch.tensor([0,0,1])
out = scatter_sum(src, index)
print('输出数据的大小{}'.format(out.size()))
print(out)
输入数据的大小torch.Size([1, 3, 2]) tensor([[[-0.5140, 1.4386],
[-0.8645, -0.8863],
[-1.0041, 0.9199]]])
输出数据的大小torch.Size([1, 2, 2])
tensor([[[-1.3785, 0.5523],
[-1.0041, 0.9199]]])
下面,我们来看一下,这个3维的张量输出是怎么形成的: