References:
CSDN: GCN torchgeometric utils scatter method, source code walkthrough and examples
Source code and comments
Source location: torch_geometric\utils\scatter.py
This code is a thin wrapper built on top of the torch_scatter package. It was removed in versions after 1.4.3, so in practice it is enough to understand torch.scatter(…, reduce='add').
In the developer's own words:
"The scatter call got removed in one of the more recent versions, and you can now simply use torch.scatter(…, reduce='add') for the same effect."
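As a rough sketch of what that replacement can look like in plain PyTorch, the snippet below uses Tensor.scatter_add_ to sum node features along dim 0. The helper name scatter_add_rows is hypothetical, and the sketch assumes 2-D features; note that the built-in op needs index expanded to the shape of src, which the old wrapper did for you.

```python
import torch


def scatter_add_rows(src, index, dim_size):
    """Sum the rows of `src` into `dim_size` output rows chosen by `index`.

    Hypothetical helper for illustration; not part of torch or torch_geometric.
    """
    # The built-in scatter ops expect `index` to have the same shape as `src`,
    # so repeat each row's target index across the feature dimension.
    expanded = index.unsqueeze(-1).expand_as(src)
    out = torch.zeros(dim_size, src.size(1), dtype=src.dtype)
    return out.scatter_add_(0, expanded, src)


src = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
index = torch.tensor([1, 0, 1])  # rows 0 and 2 both land in output row 1
result = scatter_add_rows(src, index, dim_size=2)
print(result)
```

Rows 0 and 2 of src are summed into output row 1, while row 1 of src lands alone in output row 0.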
import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along the first dimension.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index. (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in the first dimension. If set to :attr:`None`, a
            minimal sized output tensor is returned. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    # Sanity check on the aggregation name
    assert name in ['add', 'mean', 'min', 'max']
    # Look up the matching torch_scatter function by name, e.g. scatter_add
    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    # The looked-up function can be called directly (see the examples below)
    out = op(src, index, dim, None, dim_size)
    # scatter_min / scatter_max return (values, argmin/argmax); keep the values
    out = out[0] if isinstance(out, tuple) else out
    # After min/max, reset extreme leftover entries to 0
    if name == 'max':
        out[out < -10000] = 0
    elif name == 'min':
        out[out > 10000] = 0
    return out
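The two clamps at the end are easier to read with a fill value in mind. The sketch below is a pure-Python, 1-D illustration that assumes, as the thresholds suggest, that output slots which receive no value are left at a very negative fill value by the underlying max op. NEG_FILL and scatter_max_1d are hypothetical names, not torch_scatter API.

```python
NEG_FILL = -1e38  # assumed stand-in for a "very negative" fill value


def scatter_max_1d(src, index, dim_size):
    """Pure-Python 1-D scatter-max that mimics the wrapper's clamp."""
    out = [NEG_FILL] * dim_size
    for value, target in zip(src, index):
        out[target] = max(out[target], value)
    # Same idea as `out[out < -10000] = 0`: slots that received nothing
    # still hold the fill value and are reset to 0.
    return [0.0 if v < -10000 else v for v in out]


result = scatter_max_1d([1.0, 5.0, 3.0], [0, 0, 2], dim_size=4)
print(result)  # [5.0, 0.0, 3.0, 0.0]
```

One consequence of this design is visible in the wrapper itself: a genuine maximum below -10000 (or a genuine minimum above 10000) would also be reset to 0.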
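The torch_scatter methods
With the official documentation and the links below you should be able to follow it (don't worry if it doesn't click at first; read it a few times. It took me several reads too, and I have to admire the design). I won't explain it further here.
Yuque: the torch.tensor.scatter() and torch.tensor.scatter_() functions
PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html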
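Walking through scatter_
Basic run
With the underlying package understood, let's look at what this method does: it aggregates node information using the aggregation passed in as name. It is the aggregation step used by MessagePassing, the base class for creating message-passing layers.
To make this concrete, here is a specific example:
Suppose we have a graph with four nodes, each with a 3-dimensional feature vector.
Their relationships are as follows (shown as a directed graph for clarity), matching the index tensor used below: node 0 → node 2, node 1 → node 3, node 2 → node 0, node 3 → node 1.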
import torch
import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Same wrapper and docstring as in the source listing above."""
    # Sanity check on the aggregation name
    assert name in ['add', 'mean', 'min', 'max']
    # Look up the matching torch_scatter function by name, e.g. scatter_add
    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    # The looked-up function can be called directly
    out = op(src, index, dim, None, dim_size)
    # Alternative left commented out: accumulate into a copy of src instead
    # out = src.clone()
    # out = op(src, index, dim, out, dim_size)
    out = out[0] if isinstance(out, tuple) else out
    # After min/max, reset extreme leftover entries to 0
    if name == 'max':
        out[out < -10000] = 0
    elif name == 'min':
        out[out > 10000] = 0
    return out


# src: node features, one row per node
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
# index: for each row of src, the node it is sent to
b = torch.tensor([2, 3, 0, 1])
# dim: listed for reference only; the call below passes dim=0 explicitly
c = 1
# dim_size: size of the output along dim
e = 4
print(scatter_('add', a, b, dim=0, dim_size=e))
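Result:
tensor([[3., 3., 3.],
        [4., 4., 4.],
        [1., 1., 1.],
        [2., 2., 2.]])
This is one round of message passing. Take node 1 (counting from 0) as an example: it receives the message from node 3, so its feature becomes [4., 4., 4.].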
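To check the result above by hand, here is a dependency-free sketch of what the 'add' aggregation computes along dim 0. scatter_add_reference is a hypothetical name used only for illustration; it works on nested lists rather than tensors.

```python
def scatter_add_reference(src, index, dim_size):
    """Add each row of `src` into the output row named by `index`."""
    width = len(src[0])
    out = [[0.0] * width for _ in range(dim_size)]
    for row, target in zip(src, index):
        for col, value in enumerate(row):
            out[target][col] += value
    return out


features = [[1.0] * 3, [2.0] * 3, [3.0] * 3, [4.0] * 3]
targets = [2, 3, 0, 1]  # node i sends its feature row to node targets[i]
result = scatter_add_reference(features, targets, dim_size=4)
for row in result:
    print(row)
```

Row 1 of the output is [4.0, 4.0, 4.0] because targets[3] == 1, i.e. node 3 is the only sender to node 1. Each column is accumulated on its own, so feature dimensions never mix.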
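Node feature dimensions
This example shows that when node information is aggregated, the feature dimensions are independent of one another.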
a = torch.tensor([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 4
print(scatter_('add', a, b, dim=0, dim_size=e))
"""output:
tensor([[3.1000, 3.2000, 3.3000],
[4.1000, 4.2000, 4.3000],
[1.1000, 1.2000, 1.3000],
[2.1000, 2.2000, 2.3000]])
"""
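The dim_size parameter (size of the returned output)
Changing only dim_size increases the number of rows returned; rows that no index points to are filled with zeros.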
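The rule for the output size can be sketched in a few lines of plain Python. This follows the docstring's description (a "minimal sized output" when dim_size is None); output_rows is a hypothetical helper, not a library function.

```python
def output_rows(index, dim_size=None):
    """Number of output rows: `dim_size` if given, else max(index) + 1."""
    if dim_size is not None:
        return dim_size
    # With no explicit size, the output is just large enough for the
    # largest index; an empty index gives an empty output.
    return max(index) + 1 if index else 0


index = [2, 3, 0, 1]
print(output_rows(index))      # 4
print(output_rows(index, 8))   # 8
print(output_rows([0, 0, 1]))  # 2
```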
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 8
print(scatter_('add', a, b, dim=0, dim_size=e))
"""output:
tensor([[3., 3., 3.],
[4., 4., 4.],
[1., 1., 1.],
[2., 2., 2.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
"""
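The index parameter
This amounts to changing where the nodes point. With index = [2, 0, 0, 1] the edges become node 0 → node 2, node 1 → node 0, node 2 → node 0, node 3 → node 1: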
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
print(scatter_('add', a, b, dim=0, dim_size=e))
"""output:
tensor([[5., 5., 5.],
[4., 4., 4.],
[1., 1., 1.],
[0., 0., 0.]])
"""
Because 'add' is used here, messages that arrive at the same node are summed.
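One way to read an index tensor is to invert it into "who sends to whom". The sketch below does that in plain Python; incoming_sources is a hypothetical helper name.

```python
from collections import defaultdict


def incoming_sources(index):
    """Map each destination node to the list of source nodes pointing at it."""
    incoming = defaultdict(list)
    for source, target in enumerate(index):
        incoming[target].append(source)
    return dict(incoming)


result = incoming_sources([2, 0, 0, 1])
print(result)  # {2: [0], 0: [1, 2], 1: [3]}
```

Node 0 has two senders (nodes 1 and 2), which is why its row is 2 + 3 = 5 under 'add'. Node 3 does not appear as a key, so its output row stays at zero.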
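The name parameter (aggregation method)
When the aggregation is switched to the mean, node 0 (counting from 0), which receives two values, gets their average, and everything else stays the same. Min and max work the same way.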
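As a sketch of how the aggregation choice changes the result, the pure-Python helper below groups values by target and applies an arbitrary reducer. It uses one feature value per node, since the feature dimensions are independent. scatter_reduce_1d and mean are hypothetical names for illustration, and empty groups are set to 0.0 to match the zero rows in this section's printed outputs.

```python
def scatter_reduce_1d(src, index, dim_size, reducer):
    """Group `src` values by target index, then reduce each group."""
    groups = [[] for _ in range(dim_size)]
    for value, target in zip(src, index):
        groups[target].append(value)
    # Targets that received nothing are reported as 0.0.
    return [reducer(group) if group else 0.0 for group in groups]


def mean(values):
    return sum(values) / len(values)


src = [1.0, 2.0, 3.0, 4.0]
index = [2, 0, 0, 1]
means = scatter_reduce_1d(src, index, 4, mean)
maxes = scatter_reduce_1d(src, index, 4, max)
mins = scatter_reduce_1d(src, index, 4, min)
print(means)  # [2.5, 4.0, 1.0, 0.0]
print(maxes)  # [3.0, 4.0, 1.0, 0.0]
print(mins)   # [2.0, 4.0, 1.0, 0.0]
```

Only slot 0 differs between the three reducers, because it is the only target with more than one incoming value.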
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
print(scatter_('mean', a, b, dim=0, dim_size=e))
"""output:
tensor([[2.5000, 2.5000, 2.5000],
[4.0000, 4.0000, 4.0000],
[1.0000, 1.0000, 1.0000],
[0.0000, 0.0000, 0.0000]])
"""
Versions after 1.4.3
Source location: Lib\site-packages\torch_scatter\scatter.py
from typing import Optional, Tuple

import torch

from .utils import broadcast


def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)


def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return scatter_sum(src, index, dim, out, dim_size)


def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)


def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1
    count = broadcast(count, out, dim)
    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.floor_divide_(count)
    return out


def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)


def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
        :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    if reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError
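The one-dimensional formula in the docstring, out_i = out_i + sum of src_j over all j with index_j = i, and the role of the out argument can be sketched without tensors. scatter_sum_1d below is a hypothetical pure-Python stand-in that mirrors the branching in scatter_sum for lists; it is not the library function.

```python
def scatter_sum_1d(src, index, out=None, dim_size=None):
    """1-D sketch of out_i = out_i + sum of src_j over j with index_j == i."""
    if out is None:
        if dim_size is None:
            # Minimal output size, as in the listing: max(index) + 1, or 0
            # when there is nothing to scatter.
            dim_size = max(index) + 1 if index else 0
        out = [0.0] * dim_size
    for value, target in zip(src, index):
        out[target] += value  # accumulates in place, like scatter_add_
    return out


fresh = scatter_sum_1d([1.0, 2.0, 3.0], [0, 2, 0])
print(fresh)  # [4.0, 0.0, 2.0]

buffer = [10.0, 10.0, 10.0]
scatter_sum_1d([1.0, 2.0, 3.0], [0, 2, 0], out=buffer)
print(buffer)  # [14.0, 10.0, 12.0]
```

When out is supplied, the sums are added on top of whatever it already holds, which is the out_i term on the right-hand side of the formula; dim_size only matters when out is not given.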