一 什么是消息传递方案

将卷积算子推广到不规则域通常表示为邻域聚合或消息传递方案。其中

  • $\mathbf{x}_i^{(k-1)}$ 表示节点 $i$ 在第 $(k-1)$ 层的节点特征
  • $\mathbf{e}_{j,i}$ 表示从节点 $j$ 到节点 $i$ 的边的特征

所以消息传递网络可以被描述为:
$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \ \square_{j \in \mathcal{N}(i)} \, \phi^{(k)} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i} \right) \right) \tag{1}$$
上述公式1中的符号解释:

  1. $\square$ 表示一个可微分的、置换不变(即改变输入的顺序不改变输出)的函数,例如:求和 $\sum$、取平均或者取最大值
  2. $\gamma$ 和 $\phi$ 表示可微分的函数,例如:MLP(多层感知机)

图1. 消息传递网络

二 消息传递基类

PyTorch Geometric(后面简称为PyT)提供了MessagePassing作为消息传递的基类,只要继承这个基类并实现好相应的函数,我们就可以自定义出自己的网络。由公式1来看,在使用这个基类的过程中需要关注的东西只有3个:

  • 函数 $\phi$:对应message()函数
  • 函数 $\gamma$:对应update()函数
  • 以及我们要使用到的聚合模式:aggr="add"、aggr="mean"或者aggr="max"

下面是官方文档中给出的关于这3个重点内容的说明:

  • MessagePassing(aggr="add", flow="source_to_target", node_dim=-2):这个就是PyT提供的MessagePassing接口需要输入的3个参数,
    1. 第一个就是聚合模式,即上述的3种聚合模式
    2. 第二个参数flow表示消息聚合的方向,上面给的是source_to_target,同理也有target_to_source
    3. 第三个参数node_dim表示沿着哪个轴传递消息,默认为-2
  • MessagePassing.propagate(edge_index, size=None, **kwargs):启动消息传递的初始调用函数。这个函数使用到的参数edge_index是边的索引,另外要想构建一个网络肯定还需要传入别的数据。propagate()并不局限于传递形状为 $[N, N]$ 的方形邻接矩阵,它也可以传递形状为 $[N, M]$ 的一般(稀疏)分配矩阵,这时需要传入参数size=(N, M);而默认的size=None就代表邻接矩阵是方阵。通过edge_index中的两行索引,我们就能找到每条边在邻接矩阵中对应的源节点和目标节点。

图2. 邻接矩阵

  • MessagePassing.message(...):在flow="source_to_target"的情况下,构造流向节点 $i$ 的消息;如果flow="target_to_source",则构造流向节点 $j$ 的消息。它可以接收最初传给propagate()的任何参数;此外,传给propagate()的节点特征张量可以通过在变量名后面加上_i或_j映射到对应的节点:带有_i的是中心节点(target),带有_j的是邻边节点(source)。对应的是公式1中的 $\phi$ 函数。

图3. 中心节点和邻边节点

  • MessagePassing.update(aggr_out, ...):对于每个节点 $i \in \mathcal{V}$,在聚合完成后更新节点的状态,第一个参数aggr_out是聚合的输出,对应的是公式1中的 $\gamma$ 函数。下面的示意代码把这几个函数串了起来。
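
把上面这几个接口对应起来,一个自定义消息传递层的基本骨架大致如下(这只是一个示意性的草图,类名MyConv是为演示假设的,完整的写法见后面的GCN和EdgeConv例子):

import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    def __init__(self):
        # 聚合模式aggr对应公式1中的□,flow指定消息的流向,node_dim指定节点所在的维度
        super(MyConv, self).__init__(aggr='add', flow='source_to_target', node_dim=-2)

    def forward(self, x, edge_index):
        # 调用propagate()启动消息传递,节点特征x等额外数据通过关键字参数传入
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # 对应公式1中的φ:x_i是中心节点(target)特征,x_j是邻边节点(source)特征
        return x_j

    def update(self, aggr_out):
        # 对应公式1中的γ:aggr_out是聚合后的输出,返回更新后的节点特征
        return aggr_out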

下面通过两个简单的GNN例子(GCN和EdgeConv)来看一下怎么使用。

三 GCN例子

GCN层在数学上定义为:
$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W} \cdot \mathbf{x}_j^{(k-1)} \right) \tag{2}$$
通过上面的式子2,我们可以看到一个GCN层做了3个工作:

  • 输入的节点特征首先被权重矩阵 $\mathbf{W}$ 进行线性变换
  • 然后通过它们的度进行归一化
  • 再将其相加

同样的,在写代码的时候,我们也可以将步骤分解,步骤如下:

  1. 在邻接矩阵中添加自循环(self-loops)
  2. 对节点的特征矩阵进行线性变换
  3. 计算归一化参数
  4. 在 $\phi$ 中归一化节点特征
  5. 将邻边节点特征加起来(对应的aggr="add")

其中的步骤1-3在消息传递开始之前计算,而步骤4-5通过MessagePassing基类也能够很简单地实现,整个网络层实现如下:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # 对应步骤5的aggr="add"聚合方式
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x的大小为[N, in_channels]
        # edge_index的大小为[2, E]

        # 步骤1:给邻接矩阵添加自循环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # 步骤2:对特征矩阵进行线性变换
        x = self.lin(x)

        # 步骤3:计算归一化系数
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # 步骤4-5:开始传递消息
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j的大小为[E, out_channels]
        # 步骤4:归一化节点特征
        return norm.view(-1, 1) * x_j

在上述代码中,数据进入网络和离开网络都是通过forward()函数进行的。代码中通过使用torch_geometric.utils.add_self_loops()函数实现步骤1,即给边索引添加自循环;通过使用torch.nn.Linear对节点特征做线性变换,实现了步骤2。

3.1 add_self_loops()函数

下面是add_self_loops()函数的源码:

def add_self_loops(edge_index, edge_weight: Optional[torch.Tensor] = None,
                   fill_value: float = 1., num_nodes: Optional[int] = None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted, self-loops will be added with edge weights
    denoted by :obj:`fill_value`.
    Args:
        edge_index (LongTensor): The edge indices.
        edge_weight (Tensor, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`,
            will add self-loops with edge weights of :obj:`fill_value` to the
            graph. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        loop_weight = edge_weight.new_full((N, ), fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)

    return edge_index, edge_weight

通过上述源码的解释,可以知道这个函数是对每个节点 $i \in \mathcal{V}$ 添加一条自环边 $(i, i)$;如果图带边权,自环的边权默认填充为1。
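
下面用一个3个节点的小图演示一下这个函数的效果(节点数和边是为演示假设的数据):

import torch
from torch_geometric.utils import add_self_loops

# 一个3节点的小图:边 0->1 和 1->2
edge_index = torch.tensor([[0, 1],
                           [1, 2]], dtype=torch.long)

edge_index, _ = add_self_loops(edge_index, num_nodes=3)
print(edge_index)
# tensor([[0, 1, 0, 1, 2],
#         [1, 2, 0, 1, 2]])  后面3列就是新添加的自环(0,0)、(1,1)、(2,2)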

3.2 torch.nn.Linear()

这个函数比较常见,在CNN网络中也经常用到,这里就不再解释了。

3.3 归一化

对于每条边 $(j, i) \in \mathcal{E}$,先求出两端节点的度 $\deg(i)$ 和 $\deg(j)$,然后计算归一化系数 $\frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}}$,即代码中的norm。
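
结合上面GCNConv里的代码,归一化系数的计算可以单独拿出来看一个小例子(边和节点数是为演示假设的数据):

import torch
from torch_geometric.utils import degree

# 3个节点、两条边(0->1, 1->2),并且已经添加了自环
edge_index = torch.tensor([[0, 1, 0, 1, 2],
                           [1, 2, 0, 1, 2]], dtype=torch.long)

row, col = edge_index                     # row是源节点j,col是目标节点i
deg = degree(col, 3, dtype=torch.float)   # 每个节点作为目标节点的度:tensor([1., 2., 2.])
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
print(norm)  # tensor([0.7071, 0.5000, 1.0000, 0.5000, 0.5000]),每条边(j, i)一个系数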

3.4 传递消息

之后就开始在节点 $i$ 和节点 $j$ 之间传递消息,在上述的例子中调用的是函数propagate()。这个函数会依次调用message()、aggregate()和update()函数,我们只需要把节点特征x和归一化系数norm传进去,就能完成整个消息的传递。
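
如果想直观地确认这个调用顺序,可以写一个只打印信息、不改变计算结果的小层来观察(一个示意性的草图,类名PrintConv是为演示假设的):

import torch
from torch_geometric.nn import MessagePassing

class PrintConv(MessagePassing):
    def __init__(self):
        super(PrintConv, self).__init__(aggr='add')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        print('message:   x_j', tuple(x_j.shape))      # 每条边一行,[E, F]
        return x_j

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        print('aggregate: inputs', tuple(inputs.shape))  # 聚合前仍是[E, F]
        return super(PrintConv, self).aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

    def update(self, inputs):
        print('update:    inputs', tuple(inputs.shape))  # 聚合后变为[N, F]
        return inputs

edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
x = torch.randn(3, 4)
PrintConv()(x, edge_index)  # 依次打印 message -> aggregate -> update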

3.5 使用网络

conv = GCNConv(16, 32)
x = conv(x, edge_index)
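
下面是一个更完整的小例子,用随机生成的节点特征和一个假设的4节点小图跑一遍上面定义的GCNConv:

import torch

# 4个节点,每个节点16维特征;边构成一个环 0->1->2->3->0
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]], dtype=torch.long)

conv = GCNConv(16, 32)       # GCNConv为上文定义的类
out = conv(x, edge_index)
print(out.shape)             # torch.Size([4, 32])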

四 EdgeConv例子

edge卷积层是用来处理点云数据的,数学上定义为:
$$\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \ \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right)$$

其中的 $h_{\mathbf{\Theta}}$ 表示的是一个多层感知机(MLP)。类比于上面的GCN,我们同样可以使用MessagePassing基类实现这个网络,这次使用的聚合方式是aggr="max"。

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__(aggr="max")  # 聚合方式为max
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x的大小[N, in_channels]
        # edge_index的大小[2, E]
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i的大小[E, in_channels]
        # x_j的大小[E, in_channels]
        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp的大小[E, 2*in_channels]
        return self.mlp(tmp)

在这个message()函数中,对每条边 $(j, i) \in \mathcal{E}$,self.mlp同时处理了目标节点的特征x_i和相对的邻居特征x_j - x_i。

边缘卷积实际上是一种动态卷积(dynamic convolution),它在每一层都基于特征空间中的最近邻重新计算图。PyT提供了带GPU加速的分批k-NN图生成函数torch_geometric.nn.pool.knn_graph()。

在动态版本的forward()中,先用knn_graph()计算最近邻图,再调用EdgeConv的forward()完成消息传递,如下面的示意代码所示。
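
下面是一个DynamicEdgeConv的实现草图(按照官方教程的思路,假设它继承自上面定义的EdgeConv):

from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super(DynamicEdgeConv, self).__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        # 每次前向传播都在特征空间中重新构建k-NN图
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super(DynamicEdgeConv, self).forward(x, edge_index)

这样每次前向传播时图的连接关系都会随特征变化而更新,这也是动态(dynamic)卷积名字的由来。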

使用上述定义的网络:

conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)
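
这里的x是点的坐标(或特征),batch向量标明每个点属于哪一个点云。下面是一个用随机数据构造输入的小例子(数值仅为演示假设):

import torch

# 两个点云,各有100个三维点
x = torch.rand(200, 3)
batch = torch.cat([torch.zeros(100, dtype=torch.long),
                   torch.ones(100, dtype=torch.long)])

conv = DynamicEdgeConv(3, 128, k=6)   # DynamicEdgeConv为上文定义的类
out = conv(x, batch)
print(out.shape)                      # torch.Size([200, 128])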

五 MessagePassing()源码

最后将官方的源码放上,可以自己调试一下这个MessagePassing()接口,看看到底做了什么.

import os
import re
import inspect
import os.path as osp
from uuid import uuid1
from itertools import chain
from inspect import Parameter
from typing import List, Optional, Set

from torch_geometric.typing import Adj, Size

import torch
from torch import Tensor
from jinja2 import Template
from torch_sparse import SparseTensor
from torch_scatter import gather_csr, scatter, segment_csr

from .utils.helpers import expand_left
from .utils.jit import class_from_module_repr
from .utils.typing import (sanitize, split_types_repr, parse_types,
                           resolve_types)
from .utils.inspector import Inspector, func_header_repr, func_body_repr


class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers of the form

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),

    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_gnn.html>`__ for the accompanying tutorial.

    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"` or :obj:`None`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`-2`)
    """

    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):

        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)

        # Support for "fused" message passing.
        self.fuse = self.inspector.implements('message_and_aggregate')

        # Support for GNNExplainer.
        self.__explain__ = False
        self.__edge_mask__ = None

    def __check_input__(self, edge_index, size):
        the_size: List[Optional[int]] = [None, None]

        if isinstance(edge_index, Tensor):
            assert edge_index.dtype == torch.long
            assert edge_index.dim() == 2
            assert edge_index.size(0) == 2
            if size is not None:
                the_size[0] = size[0]
                the_size[1] = size[1]
            return the_size

        elif isinstance(edge_index, SparseTensor):
            if self.flow == 'target_to_source':
                raise ValueError(
                    ('Flow direction "target_to_source" is invalid for '
                     'message propagation via `torch_sparse.SparseTensor`. If '
                     'you really want to make use of a reverse message '
                     'passing flow, pass in the transposed sparse tensor to '
                     'the message passing module, e.g., `adj_t.t()`.'))
            the_size[0] = edge_index.sparse_size(1)
            the_size[1] = edge_index.sparse_size(0)
            return the_size

        raise ValueError(
            ('`MessagePassing.propagate` only supports `torch.LongTensor` of '
             'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for '
             'argument `edge_index`.'))

    def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
        the_size = size[dim]
        if the_size is None:
            size[dim] = src.size(self.node_dim)
        elif the_size != src.size(self.node_dim):
            raise ValueError(
                (f'Encountered tensor with size {src.size(self.node_dim)} in '
                 f'dimension {self.node_dim}, but expected size {the_size}.'))

    def __lift__(self, src, edge_index, dim):
        if isinstance(edge_index, Tensor):
            index = edge_index[dim]
            return src.index_select(self.node_dim, index)
        elif isinstance(edge_index, SparseTensor):
            if dim == 1:
                rowptr = edge_index.storage.rowptr()
                rowptr = expand_left(rowptr, dim=self.node_dim, dims=src.dim())
                return gather_csr(src, rowptr)
            elif dim == 0:
                col = edge_index.storage.col()
                return src.index_select(self.node_dim, col)
        raise ValueError

    def __collect__(self, args, edge_index, size, kwargs):
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

        out = {}
        for arg in args:
            if arg[-2:] not in ['_i', '_j']:
                out[arg] = kwargs.get(arg, Parameter.empty)
            else:
                dim = 0 if arg[-2:] == '_j' else 1
                data = kwargs.get(arg[:-2], Parameter.empty)

                if isinstance(data, (tuple, list)):
                    assert len(data) == 2
                    if isinstance(data[1 - dim], Tensor):
                        self.__set_size__(size, 1 - dim, data[1 - dim])
                    data = data[dim]

                if isinstance(data, Tensor):
                    self.__set_size__(size, dim, data)
                    data = self.__lift__(data, edge_index,
                                         j if arg[-2:] == '_j' else i)

                out[arg] = data

        if isinstance(edge_index, Tensor):
            out['adj_t'] = None
            out['edge_index'] = edge_index
            out['edge_index_i'] = edge_index[i]
            out['edge_index_j'] = edge_index[j]
            out['ptr'] = None
        elif isinstance(edge_index, SparseTensor):
            out['adj_t'] = edge_index
            out['edge_index'] = None
            out['edge_index_i'] = edge_index.storage.row()
            out['edge_index_j'] = edge_index.storage.col()
            out['ptr'] = edge_index.storage.rowptr()
            out['edge_weight'] = edge_index.storage.value()
            out['edge_attr'] = edge_index.storage.value()
            out['edge_type'] = edge_index.storage.value()

        out['index'] = out['edge_index_i']
        out['size'] = size
        out['size_i'] = size[1] or size[0]
        out['size_j'] = size[0] or size[1]
        out['dim_size'] = out['size_i']

        return out

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        r"""The initial call to start propagating messages.

        Args:
            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
                :obj:`torch_sparse.SparseTensor` that defines the underlying
                graph connectivity/message passing flow.
                :obj:`edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
                If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its
                shape must be defined as :obj:`[2, num_messages]`, where
                messages from nodes in :obj:`edge_index[0]` are sent to
                nodes in :obj:`edge_index[1]`
                (in case :obj:`flow="source_to_target"`).
                If :obj:`edge_index` is of type
                :obj:`torch_sparse.SparseTensor`, its sparse indices
                :obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
                and :obj:`col = edge_index[0]`.
                The major difference between both formats is that we need to
                input the *transposed* sparse adjacency matrix into
                :func:`propagate`.
            size (tuple, optional): The size :obj:`(N, M)` of the assignment
                matrix in case :obj:`edge_index` is a :obj:`LongTensor`.
                If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        size = self.__check_input__(edge_index, size)

        # Run "fused" message and aggregation (if applicable).
        if (isinstance(edge_index, SparseTensor) and self.fuse
                and not self.__explain__):
            coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
                                         size, kwargs)

            msg_aggr_kwargs = self.inspector.distribute(
                'message_and_aggregate', coll_dict)
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

        # Otherwise, run both functions in separation.
        elif isinstance(edge_index, Tensor) or not self.fuse:
            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                         kwargs)

            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)

            # For `GNNExplainer`, we require a separate message and aggregate
            # procedure since this allows us to inject the `edge_mask` into the
            # message passing computation scheme.
            if self.__explain__:
                edge_mask = self.__edge_mask__.sigmoid()
                # Some ops add self-loops to `edge_index`. We need to do the
                # same for `edge_mask` (but do not train those).
                if out.size(self.node_dim) != edge_mask.size(0):
                    loop = edge_mask.new_ones(size[0])
                    edge_mask = torch.cat([edge_mask, loop], dim=0)
                assert out.size(self.node_dim) == edge_mask.size(0)
                out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

    def message(self, x_j: Tensor) -> Tensor:
        r"""Constructs messages from node :math:`j` to node :math:`i`
        in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :obj:`edge_index`.
        This function can take any argument as input which was initially
        passed to :meth:`propagate`.
        Furthermore, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.
        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.
        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if ptr is not None:
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=self.aggr)
        else:
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        r"""Fuses computations of :func:`message` and :func:`aggregate` into a
        single function.
        If applicable, this saves both time and memory since messages do not
        explicitly need to be materialized.
        This function will only gets called in case it is implemented and
        propagation takes place based on a :obj:`torch_sparse.SparseTensor`.
        """
        raise NotImplementedError

    def update(self, inputs: Tensor) -> Tensor:
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`.
        """
        return inputs

    @torch.jit.unused
    def jittable(self, typing: Optional[str] = None):
        r"""Analyzes the :class:`MessagePassing` instance and produces a new
        jittable module.

        Args:
            typing (string, optional): If given, will generate a concrete
                instance with :meth:`forward` types based on :obj:`typing`,
                *e.g.*: :obj:`"(Tensor, Optional[Tensor]) -> Tensor"`.
        """
        # Find and parse `propagate()` types to format `{arg1: type1, ...}`.
        if hasattr(self, 'propagate_type'):
            prop_types = {
                k: sanitize(str(v))
                for k, v in self.propagate_type.items()
            }
        else:
            source = inspect.getsource(self.__class__)
            match = re.search(r'#\s*propagate_type:\s*\((.*)\)', source)
            if match is None:
                raise TypeError(
                    'TorchScript support requires the definition of the types '
                    'passed to `propagate()`. Please specificy them via\n\n'
                    'propagate_type = {"arg1": type1, "arg2": type2, ... }\n\n'
                    'or via\n\n'
                    '# propagate_type: (arg1: type1, arg2: type2, ...)\n\n'
                    'inside the `MessagePassing` module.')
            prop_types = split_types_repr(match.group(1))
            prop_types = dict([re.split(r'\s*:\s*', t) for t in prop_types])

        # Parse `__collect__()` types to format `{arg:1, type1, ...}`.
        collect_types = self.inspector.types(
            ['message', 'aggregate', 'update'])

        # Collect `forward()` header, body and @overload types.
        forward_types = parse_types(self.forward)
        forward_types = [resolve_types(*types) for types in forward_types]
        forward_types = list(chain.from_iterable(forward_types))

        keep_annotation = len(forward_types) < 2
        forward_header = func_header_repr(self.forward, keep_annotation)
        forward_body = func_body_repr(self.forward, keep_annotation)

        if keep_annotation:
            forward_types = []
        elif typing is not None:
            forward_types = []
            forward_body = 8 * ' ' + f'# type: {typing}\n{forward_body}'

        root = os.path.dirname(osp.realpath(__file__))
        with open(osp.join(root, 'message_passing.jinja'), 'r') as f:
            template = Template(f.read())

        uid = uuid1().hex[:6]
        cls_name = f'{self.__class__.__name__}Jittable_{uid}'
        jit_module_repr = template.render(
            uid=uid,
            module=str(self.__class__.__module__),
            cls_name=cls_name,
            parent_cls_name=self.__class__.__name__,
            prop_types=prop_types,
            collect_types=collect_types,
            user_args=self.__user_args__,
            forward_header=forward_header,
            forward_types=forward_types,
            forward_body=forward_body,
            msg_args=self.inspector.keys(['message']),
            aggr_args=self.inspector.keys(['aggregate']),
            msg_and_aggr_args=self.inspector.keys(['message_and_aggregate']),
            update_args=self.inspector.keys(['update']),
            check_input=inspect.getsource(self.__check_input__)[:-1],
            lift=inspect.getsource(self.__lift__)[:-1],
        )

        # Instantiate a class from the rendered JIT module representation.
        cls = class_from_module_repr(cls_name, jit_module_repr)
        module = cls.__new__(cls)
        module.__dict__ = self.__dict__.copy()
        module.jittable = None
        return module