References:
    https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_adj
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/dropout.html

```python
dropout_adj(
    edge_index,
    edge_attr=None,
    p=0.5,
    force_undirected=False,
    num_nodes=None,
    training=True
)
```

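Purpose:
Randomly drops edges from the adjacency matrix (edge_index, edge_attr) with probability p, using samples from a Bernoulli distribution. The function returns the filtered (edge_index, edge_attr) pair.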
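To make the sampling step concrete, here is a minimal pure-Python sketch of the same idea. It uses plain lists instead of tensors so it runs without PyTorch. Each edge is kept independently with probability 1 - p, which is what `torch.bernoulli` does with a tensor filled with `1 - p`, so on average (1 - p)·|E| edges survive. The name `bernoulli_edge_dropout` and the list-based edge representation are assumptions for illustration, not part of the PyG API.

```python
import random
from typing import List, Optional, Sequence, Tuple


def bernoulli_edge_dropout(
    row: Sequence[int],
    col: Sequence[int],
    edge_attr: Optional[Sequence[float]] = None,
    p: float = 0.5,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int], Optional[list]]:
    """Hypothetical helper: keep each edge (row[i], col[i]) with probability 1 - p."""
    rng = rng or random.Random()
    # One Bernoulli(1 - p) draw per edge: True means "keep this edge".
    # random() returns a value in [0, 1), so p == 0 keeps all edges and p == 1 drops all.
    mask = [rng.random() < 1 - p for _ in range(len(row))]
    # The same mask is applied to both endpoints and to the edge attributes,
    # so attributes stay aligned with their edges.
    new_row = [r for r, keep in zip(row, mask) if keep]
    new_col = [c for c, keep in zip(col, mask) if keep]
    new_attr = None if edge_attr is None else [a for a, keep in zip(edge_attr, mask) if keep]
    return new_row, new_col, new_attr


# Usage: a 4-edge graph with one weight per edge.
row, col = [0, 1, 1, 2], [1, 0, 2, 1]
attr = [0.1, 0.2, 0.3, 0.4]
print(bernoulli_edge_dropout(row, col, attr, p=0.5, rng=random.Random(0)))
```

Note that this directed version treats (0, 1) and (1, 0) as independent edges, which matches dropout_adj with force_undirected=False.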
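Parameters:

• **edge_index** (LongTensor): The edge indices, a tensor of shape [2, num_edges]. 
• **edge_attr** (Tensor, optional): Edge weights or multi-dimensional edge features. (default: None)
• **p** (float, optional): Dropout probability, i.e. the probability that each edge is removed; must lie in [0, 1]. (default: 0.5)
• **force_undirected** (bool, optional): If set to True, both directions (u, v) and (v, u) of an undirected edge are either dropped together or kept together. (default: False)
• **num_nodes** (int, optional): The number of nodes, i.e. max_val + 1 of edge_index; inferred from edge_index if not given. (default: None)
• **training** (bool, optional): If set to False, the operation is a no-op and the inputs are returned unchanged. (default: True)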
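The p, training and num_nodes parameters mainly control what happens before any sampling. The sketch below restates that logic in plain Python, assuming edges are given as two lists: it rejects p outside [0, 1], reports a no-op when training=False or p == 0, and infers the node count as max_val + 1 when num_nodes is not given. The names `check_dropout_prob`, `should_skip_dropout` and `infer_num_nodes` are hypothetical; in PyG the inference is done by `maybe_num_nodes`, and returning 0 for an empty edge list is an assumption of this sketch.

```python
from typing import Optional, Sequence


def check_dropout_prob(p: float) -> None:
    # Same range check as dropout_adj: p must lie in the closed interval [0, 1].
    if p < 0.0 or p > 1.0:
        raise ValueError(f"Dropout probability has to be between 0 and 1, but got {p}")


def should_skip_dropout(p: float, training: bool) -> bool:
    # Evaluation mode (training=False) or p == 0 turns the call into a no-op.
    return not training or p == 0.0


def infer_num_nodes(row: Sequence[int], col: Sequence[int],
                    num_nodes: Optional[int] = None) -> int:
    # An explicit num_nodes always wins; otherwise use the largest node index + 1.
    if num_nodes is not None:
        return num_nodes
    indices = list(row) + list(col)
    # Assumption of this sketch: an empty edge list means zero nodes.
    return max(indices) + 1 if indices else 0


check_dropout_prob(0.3)
print(should_skip_dropout(0.3, training=False))
print(infer_num_nodes([0, 3], [1, 2]))
```

In the source code the range check runs before the no-op check, so an invalid p raises even when training=False. Passing num_nodes explicitly matters when the highest-numbered nodes are isolated, since max_val + 1 would then undercount; in dropout_adj the node count is only used by coalesce in the force_undirected branch.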
Source code:

```python
# Imports this code relies on (PyG 1.x era; module paths may differ in other versions).
import torch
from torch_sparse import coalesce
from torch_geometric.utils.num_nodes import maybe_num_nodes


def filter_adj(row, col, edge_attr, mask):
    # Helper defined in the same module: keep only the entries where mask is True.
    return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]


def dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False,
                num_nodes=None, training=True):
    r"""Randomly drops edges from the adjacency matrix
    :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from
    a Bernoulli distribution.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        p (float, optional): Dropout probability. (default: :obj:`0.5`)
        force_undirected (bool, optional): If set to :obj:`True`, will either
            drop or keep both edges of an undirected edge.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        training (bool, optional): If set to :obj:`False`, this operation is a
            no-op. (default: :obj:`True`)
    """
    if p < 0. or p > 1.:
        raise ValueError('Dropout probability has to be between 0 and 1, '
                         'but got {}'.format(p))

    # Evaluation mode or p == 0: return the input unchanged.
    if not training or p == 0.0:
        return edge_index, edge_attr

    N = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    # Undirected mode: keep a single copy (row < col) of each undirected edge.
    if force_undirected:
        row, col, edge_attr = filter_adj(row, col, edge_attr, row < col)

    # Keep each remaining edge with probability 1 - p.
    mask = edge_index.new_full((row.size(0), ), 1 - p, dtype=torch.float)
    mask = torch.bernoulli(mask).to(torch.bool)
    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

    if force_undirected:
        # Restore both directions of every surviving edge, then sort and merge duplicates.
        edge_index = torch.stack(
            [torch.cat([row, col], dim=0),
             torch.cat([col, row], dim=0)], dim=0)
        if edge_attr is not None:
            edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
        edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
    else:
        edge_index = torch.stack([row, col], dim=0)

    return edge_index, edge_attr
```
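The force_undirected=True path is the least obvious part of the source above. The following pure-Python sketch isolates just that path, under the assumptions that edges are plain lists, edge_attr is omitted, and a sorted set stands in for `coalesce` (which sorts the indices and merges duplicate entries). `dropout_undirected` is a hypothetical name used only for illustration.

```python
import random
from typing import List, Optional, Tuple


def dropout_undirected(row: List[int], col: List[int], p: float = 0.5,
                       rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Sketch of dropout_adj's force_undirected=True branch, without edge_attr."""
    rng = rng or random.Random()
    # 1. Keep one copy of each undirected edge: only pairs with row < col.
    #    Self-loops (row == col) and edges stored only as row > col are discarded here.
    pairs = [(r, c) for r, c in zip(row, col) if r < c]
    # 2. One Bernoulli(1 - p) draw per undirected edge.
    kept = [pair for pair in pairs if rng.random() < 1 - p]
    # 3. Mirror every surviving edge so (u, v) and (v, u) are kept together.
    both = kept + [(c, r) for r, c in kept]
    # 4. Sort in row-major order and drop duplicate index pairs (stand-in for coalesce).
    both = sorted(set(both))
    return [r for r, _ in both], [c for _, c in both]


# Usage: an undirected triangle 0-1-2 stored in both directions, plus a self-loop on node 2.
row = [0, 1, 1, 2, 0, 2, 2]
col = [1, 0, 2, 1, 2, 0, 2]
print(dropout_undirected(row, col, p=0.5, rng=random.Random(0)))
```

Because step 1 keeps only row < col entries, self-loops are always removed in this mode, and an edge stored in only one direction with row > col is silently lost. The input graph is therefore expected to list both directions of every edge.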