Reference:
CSDN: GCN torch_geometric utils scatter method, source-code analysis with examples

Source Code with Comments

Source location: torch_geometric\utils\scatter.py
This code is essentially a thin wrapper built on top of the torch_scatter package. It was removed in versions after 1.4.3, so nowadays you only need to understand torch.scatter(…, reduce='add').
In the developer's own words:
"The scatter call got removed in one of the more recent versions, and you can now simply use torch.scatter(…, reduce='add') for the same effect."
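As a quick sanity check on that claim, here is a minimal sketch (my own, not from the original post) that reproduces the row-wise scatter-add from the examples below using plain PyTorch. Note that torch.Tensor.scatter_ expects index to have the same number of dimensions as src, so the per-row index must first be expanded across the feature dimension:

import torch

src = torch.tensor([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]])
index = torch.tensor([2, 3, 0, 1])

# Expand the per-row index to src's shape, then add each row of src
# into the output row it names.
out = torch.zeros(4, 3).scatter_(0, index.view(-1, 1).expand_as(src), src, reduce='add')
# torch.Tensor.scatter_add_ is an equivalent, longer-standing alternative:
# out = torch.zeros(4, 3).scatter_add_(0, index.view(-1, 1).expand_as(src), src)
print(out)
# tensor([[3., 3., 3.],
#         [4., 4., 4.],
#         [1., 1., 1.],
#         [2., 2., 2.]])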

import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along the first dimension.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in the first dimension. If set to :attr:`None`, a
            minimal sized output tensor is returned. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    # Routine sanity check on the aggregation name
    assert name in ['add', 'mean', 'min', 'max']
    # Look up the matching torch_scatter function by name,
    # e.g. torch_scatter.scatter_add
    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    # The looked-up attribute is directly callable; see the concrete examples below
    out = op(src, index, dim, None, dim_size)
    # scatter_min / scatter_max return a (values, arg indices) tuple
    out = out[0] if isinstance(out, tuple) else out
    # Clamp away the extreme fill values left in untouched slots after min/max
    if name == 'max':
        out[out < -10000] = 0
    elif name == 'min':
        out[out > 10000] = 0
    return out

torch_scatter Methods

Between the official docs and the links below, you should be able to work it out (and if not, don't worry, read them a few more times; it took me several passes too, and the design is genuinely admirable), so I won't go into much more detail here.
Yuque: the torch.tensor.scatter() and torch.tensor.scatter_() functions
PyTorch docs: https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
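Still, to get a quick feel for torch.Tensor.scatter_ itself before moving on, here is one classic use, building one-hot encodings (a minimal sketch of my own, not taken from the linked pages):

import torch

# scatter_ writes the given value into self at the positions
# named by index along dim.
labels = torch.tensor([[2], [0], [1]])
one_hot = torch.zeros(3, 3).scatter_(1, labels, 1.0)
print(one_hot)
# tensor([[0., 0., 1.],
#         [1., 0., 0.],
#         [0., 1., 0.]])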

Analysis of the scatter_ Method

Basic Usage

Building on our understanding of the original package above, let's look at what this method actually does: it aggregates node information according to the given aggregation mode. It is the workhorse behind the base class used to create message-passing layers.
To make this easier to follow, here is a concrete example:
Suppose we have a graph with four nodes, and each node has a 3-dimensional feature vector.
Their relations are shown in the figure below (drawn as a directed graph for clarity):
[figure: directed graph with edges 0→2, 1→3, 2→0, 3→1]

import torch
import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along the first dimension.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in the first dimension. If set to :attr:`None`, a
            minimal sized output tensor is returned. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    # Routine sanity check on the aggregation name
    assert name in ['add', 'mean', 'min', 'max']
    # Look up the matching torch_scatter function by name,
    # e.g. torch_scatter.scatter_add
    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    # The looked-up attribute is directly callable; see the concrete examples below
    out = op(src, index, dim, None, dim_size)
    # Alternative: scatter into a clone of src instead of a fresh zero tensor
    # out = src.clone()
    # out = op(src, index, dim, out, dim_size)
    # scatter_min / scatter_max return a (values, arg indices) tuple
    out = out[0] if isinstance(out, tuple) else out
    # Clamp away the extreme fill values left in untouched slots after min/max
    if name == 'max':
        out[out < -10000] = 0
    elif name == 'min':
        out[out > 10000] = 0
    return out


# src: the node features
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
# index: the target node of each source node's message
b = torch.tensor([2, 3, 0, 1])
# dim: the dimension to scatter along (defined here, but dim=0 is passed directly below)
c = 1
# dim_size: size of the returned dimension
e = 4
print(scatter_('add', a, b, dim=0, dim_size=e))

Result:

tensor([[3., 3., 3.],
        [4., 4., 4.],
        [1., 1., 1.],
        [2., 2., 2.]])

We can see that one round of message passing has taken place. Take node 1, for example: it received the message from node 3 (since index[3] = 1), so its feature becomes [4., 4., 4.].
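To make the semantics fully explicit, the 'add' scatter above is equivalent to this naive loop (a sketch of my own, reusing the variables from the example):

out = torch.zeros(e, a.size(1))
for i in range(a.size(0)):
    # Row i of src is accumulated into the output row named by index[i]
    out[b[i]] += a[i]
print(out)  # same tensor as scatter_('add', a, b, dim=0, dim_size=e)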

Node Feature Dimensions

This example shows that when node information is aggregated, each feature dimension is handled independently.

a = torch.tensor([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 4
print(scatter_('add', a, b, dim=0, dim_size=e))
"""output:
tensor([[3.1000, 3.2000, 3.3000],
        [4.1000, 4.2000, 4.3000],
        [1.1000, 1.2000, 1.3000],
        [2.1000, 2.2000, 2.3000]])
"""

The dim_size Parameter (Size of the Output Dimension)

Changing only dim_size, we can see the returned tensor grows along that dimension, with the extra rows zero-filled.

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 8
print(scatter_('add', a, b, dim=0, dim_size=e))
"""output:
tensor([[3., 3., 3.],
        [4., 4., 4.],
        [1., 1., 1.],
        [2., 2., 2.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
"""
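Conversely, if dim_size is left as None, the docstring says a minimally sized output is returned, i.e. index.max() + 1 rows (a quick check of my own):

print(scatter_('add', a, b, dim=0, dim_size=None).size())
# torch.Size([4, 3]) -- b.max() + 1 = 4 rows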

The index Parameter

This effectively rewires where the nodes point. The modified graph:
[figure: directed graph with edges 0→2, 1→0, 2→0, 3→1]

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
print(scatter_('add', a, b, dim=0, dim_size=e))
"""output:
tensor([[5., 5., 5.],
        [4., 4., 4.],
        [1., 1., 1.],
        [0., 0., 0.]])
"""

Because the 'add' aggregation is used here, messages arriving at the same node are accumulated: nodes 1 and 2 both point at node 0, so its row becomes [2., 2., 2.] + [3., 3., 3.] = [5., 5., 5.], while node 3 receives nothing and stays at zero.

The name Parameter (Aggregation Mode)

As expected, when the aggregation is switched to the mean, node 0, the only node with two incoming values, has them averaged, and everything else stays the same; min and max work analogously.

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
print(scatter_('mean', a, b, dim=0, dim_size=e))
"""output:
tensor([[2.5000, 2.5000, 2.5000],
        [4.0000, 4.0000, 4.0000],
        [1.0000, 1.0000, 1.0000],
        [0.0000, 0.0000, 0.0000]])
"""
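For completeness, here is what 'max' should give on the same inputs (my own run-through, not from the original post): node 0 keeps the element-wise maximum of its two incoming messages, and the empty slot at node 3 ends up as zeros thanks to the wrapper's clamping step:

print(scatter_('max', a, b, dim=0, dim_size=e))
"""expected output:
tensor([[3., 3., 3.],
        [4., 4., 4.],
        [1., 1., 1.],
        [0., 0., 0.]])
"""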

Versions After 1.4.3

Source location: Lib\site-packages\torch_scatter\scatter.py

from typing import Optional, Tuple

import torch

from .utils import broadcast


def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    # Broadcast index so its shape matches src in every dim except `dim`
    index = broadcast(index, src, dim)
    if out is None:
        # Infer the output size along `dim`: use dim_size if given,
        # otherwise the minimal size index.max() + 1
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)


def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    # 'add' is just an alias for 'sum'
    return scatter_sum(src, index, dim, out, dim_size)


def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)


def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    # Mean = sum divided by the per-slot element count
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    # Count how many entries were scattered into each slot
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1  # avoid division by zero for empty slots
    count = broadcast(count, out, dim)
    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.floor_divide_(count)
    return out


def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)


def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
        :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    # Dispatch on the reduce string; min/max return (values, arg indices),
    # so only the values are kept
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    if reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError
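Tying this back to the earlier examples: the old scatter_('add', a, b, dim=0, dim_size=e) wrapper maps onto this newer API roughly as follows (a minimal sketch, using the scatter signature shown above):

import torch
from torch_scatter import scatter

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])

out = scatter(a, b, dim=0, dim_size=4, reduce='sum')
print(out)
# tensor([[3., 3., 3.],
#         [4., 4., 4.],
#         [1., 1., 1.],
#         [2., 2., 2.]])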