消息传递的流程图

我们知道,在图网络中,根据消息传递的过程,其流程图如下所示:
下面主要看一下在这个消息传递机制中的aggrate()聚合函数,它是怎么样对message()函数归一化后的函数进行聚合的,这一部分主要是通过打断点调试PyT中的MessgePassing这个类和scatter这个类得到的结论。

scatter的聚合类型

♠ 聚合机制 - 图1

官方函数

通过不同的参数选择不同的聚合方式,接下来拿出其中的一种聚合方式出来看一下它是怎么工作的:

  1. def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
  2. out: Optional[torch.Tensor] = None,
  3. dim_size: Optional[int] = None) -> torch.Tensor:
  4. index = broadcast(index, src, dim)
  5. if out is None:
  6. size = list(src.size())
  7. if dim_size is not None:
  8. size[dim] = dim_size
  9. elif index.numel() == 0:
  10. size[dim] = 0
  11. else:
  12. size[dim] = int(index.max()) + 1
  13. out = torch.zeros(size, dtype=src.dtype, device=src.device)
  14. return out.scatter_add_(dim, index, src)
  15. else:
  16. return out.scatter_add_(dim, index, src)

参考博客一:从坐标角度来看

在开始scatter_sum之前,我们先来看一下scatter函数是用来做什么的

这部分主要是参考这篇博客的工作,来看一下其数据的变化过程:

scatter简单来说就是通过一个张量src来修改另一个张量,哪个元素需要修改、用src中的哪个元素来修改是由dim和index来决定的,官方给出了3维张量的具体操作说明,如下所示:

  1. self[index[i][j][k]] [j][k] = src[i][j][k] # if dim==0
  2. self[i] index[[i][j][k]] [k] = src[i][j][k] # if dim==1
  3. self[i][j] index[[i][j][k]] = src[i][j][k] # if dim==2

拿一个例子来看一下:

  1. x = torch.rand(2, 5)
  2. #tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
  3. # [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
  4. torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
  5. #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
  6. # [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
  7. # [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

上面的torch.zeros(3, 5)就是我们需要修改的张量,x就是src,我们通过index来索引src,然后再将索引到的值还给我们需要修改的张量torch.zeros(3, 5)。

对应上面给出的索引,我们可以得到对2维张量来说,我们只需要索引行列即可。

首先,初始数据如下所示:
image.png
image.png
image.png

下面,针对目标数据进行填充,索引形式♠ 聚合机制 - 图5,而
♠ 聚合机制 - 图6
♠ 聚合机制 - 图7
可以看到,我们的目标数据中,只有10个被填充,剩下的仍然是原来的0。

参考博客二:从行/列索引,整个维度来看

不同于上面的参考博客一,这篇博客是根据dim=0还是dim=1来设置的行发散还是列发散:

情况1:dim=1

dim=1的情况下,使用的是行发散,如下情况所示:

image.png

情况2:dim=0

dim=0的情况下,使用的是列发散,如下图所示:
image.png

scatter_add

如果,我们出现了目标数据列数 < 原始数据列数的情况,此时可以scatter()提供的不同聚合类型,下面使用scatter_add来解释这种情况:
image.png

扩充到三维张量

我们上面的二维的现在理解清楚后,来看一下三维的情况

  1. import torch
  2. from torch_scatter.utils import broadcast
  3. def scatter_sum(src=None, index=None, dim=1, out=None, dim_size=None):
  4. index = broadcast(index, src, dim)
  5. if out is None:
  6. size = list(src.size())
  7. if dim_size is not None:
  8. size[dim] = dim_size
  9. elif index.numel() == 0:
  10. size[dim] = 0
  11. else:
  12. size[dim] = int(index.max()) + 1
  13. out = torch.zeros(size, dtype=src.dtype, device=src.device)
  14. return out.scatter_add_(dim, index, src)
  15. else:
  16. return out.scatter_add_(dim, index, src)
  17. src=torch.randn(1,3,2)
  18. print('输入数据的大小{0}'.format(src.size()))
  19. print(src)
  20. index = torch.tensor([0,0,1])
  21. out = scatter_sum(src, index)
  22. print('输出数据的大小{}'.format(out.size()))
  23. print(out)

输入数据的大小torch.Size([1, 3, 2]) tensor([[[-0.5140, 1.4386],

  1. [-0.8645, -0.8863],
  2. [-1.0041, 0.9199]]])

输出数据的大小torch.Size([1, 2, 2])

tensor([[[-1.3785, 0.5523],

  1. [-1.0041, 0.9199]]])

下面,我们来看一下,这个3维的张量输出是怎么形成的:
image.png