References:
https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_adj
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/dropout.html
dropout_adj(
    edge_index,
    edge_attr=None,
    p=0.5,
    force_undirected=False,
    num_nodes=None,
    training=True
)
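Purpose:
Randomly drops edges from the adjacency matrix (edge_index, edge_attr) using samples from a Bernoulli distribution: each edge is dropped independently with probability p (that is, kept with probability 1 - p).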
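The sampling step can be sketched in plain PyTorch. This is a minimal illustration, not PyG's code: `bernoulli_edge_mask` is a hypothetical helper and the toy `edge_index` is made up. It builds one keep-probability of 1 - p per edge, samples a Bernoulli mask, and keeps only the selected columns of `edge_index`.

```python
import torch


def bernoulli_edge_mask(num_edges, p, generator=None):
    """Hypothetical helper: sample a boolean keep-mask, one entry per edge.

    Each entry is True with probability 1 - p, so each edge is dropped
    independently with probability p.
    """
    keep_prob = torch.full((num_edges,), 1.0 - p)
    return torch.bernoulli(keep_prob, generator=generator).to(torch.bool)


# Toy graph in COO format: column k is the edge edge_index[0, k] -> edge_index[1, k].
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

gen = torch.Generator().manual_seed(0)  # fixed seed for reproducibility
mask = bernoulli_edge_mask(edge_index.size(1), p=0.5, generator=gen)
kept_edges = edge_index[:, mask]  # surviving columns only
```

With p = 0.5 about half of the edges survive on average; the exact number depends on the seed.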
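Parameters:
**edge_index** (LongTensor): The edge indices.
**edge_attr** (Tensor, optional): Edge weights or multi-dimensional edge features. (default: None)
**p** (float, optional): Dropout probability, i.e. the probability that each edge is dropped. (default: 0.5)
**force_undirected** (bool, optional): If set to True, the two directed edges that make up an undirected edge are either both dropped or both kept. (default: False)
**num_nodes** (int, optional): The number of nodes, i.e. max_val + 1 of edge_index. If None, it is inferred from edge_index. (default: None)
**training** (bool, optional): If set to False, this operation is a no-op. (default: True)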
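Three of the parameters above are easiest to understand in isolation: `num_nodes` falls back to `max_val + 1`, `training=False` (or `p == 0`) returns the inputs untouched, and `edge_attr` is filtered with the same mask as `edge_index` so that rows stay aligned with edges. The sketch below mimics this in plain PyTorch. `infer_num_nodes` and `dropout_edges_sketch` are hypothetical names, and the inference rule is an assumption based on the `max_val + 1` description, not PyG's `maybe_num_nodes` itself.

```python
import torch


def infer_num_nodes(edge_index, num_nodes=None):
    # Hypothetical stand-in for maybe_num_nodes: use the given value, else max index + 1.
    if num_nodes is not None:
        return num_nodes
    return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0


def dropout_edges_sketch(edge_index, edge_attr=None, p=0.5, training=True):
    # Same argument check and early exit as dropout_adj.
    if p < 0.0 or p > 1.0:
        raise ValueError(f"Dropout probability has to be between 0 and 1, but got {p}")
    if not training or p == 0.0:
        return edge_index, edge_attr  # no-op: inputs returned unchanged
    keep = torch.bernoulli(torch.full((edge_index.size(1),), 1.0 - p)).to(torch.bool)
    # edge_attr row k belongs to edge column k, so the same mask filters both.
    new_attr = None if edge_attr is None else edge_attr[keep]
    return edge_index[:, keep], new_attr


edge_index = torch.tensor([[0, 1, 3],
                           [1, 0, 2]])
edge_attr = torch.tensor([0.5, 0.5, 2.0])

print(infer_num_nodes(edge_index))  # 4 (largest index is 3)
out_index, out_attr = dropout_edges_sketch(edge_index, edge_attr, p=0.9, training=False)
print(out_index is edge_index, out_attr is edge_attr)  # True True
```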
import torch
from torch_sparse import coalesce

from torch_geometric.utils.num_nodes import maybe_num_nodes


def filter_adj(row, col, edge_attr, mask):
    # Helper used below: keep only the edges (and their attributes) selected by mask.
    return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]


def dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False,
                num_nodes=None, training=True):
    r"""Randomly drops edges from the adjacency matrix
    :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from
    a Bernoulli distribution.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        p (float, optional): Dropout probability. (default: :obj:`0.5`)
        force_undirected (bool, optional): If set to :obj:`True`, will either
            drop or keep both edges of an undirected edge.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        training (bool, optional): If set to :obj:`False`, this operation is a
            no-op. (default: :obj:`True`)
    """

    if p < 0. or p > 1.:
        raise ValueError('Dropout probability has to be between 0 and 1, '
                         'but got {}'.format(p))

    # Evaluation mode or p == 0: return the inputs unchanged.
    if not training or p == 0.0:
        return edge_index, edge_attr

    N = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    # Undirected mode: keep one copy (row < col) of each edge so that both
    # directions share a single Bernoulli sample. This also discards self-loops.
    if force_undirected:
        row, col, edge_attr = filter_adj(row, col, edge_attr, row < col)

    # Each remaining edge is kept with probability 1 - p.
    mask = edge_index.new_full((row.size(0), ), 1 - p, dtype=torch.float)
    mask = torch.bernoulli(mask).to(torch.bool)

    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

    if force_undirected:
        # Mirror the surviving edges, then sort and merge them with coalesce.
        edge_index = torch.stack(
            [torch.cat([row, col], dim=0),
             torch.cat([col, row], dim=0)], dim=0)
        if edge_attr is not None:
            edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
        edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
    else:
        edge_index = torch.stack([row, col], dim=0)

    return edge_index, edge_attr
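The `force_undirected` branch is the subtle part of this function: it samples once per undirected pair and then mirrors the result, so the output stays symmetric. Below is a minimal sketch of that idea without torch_sparse, assuming no `edge_attr`, so that `torch.unique(..., dim=1)` can stand in for `coalesce` (which would additionally merge attributes of duplicate edges). `symmetric_edge_dropout` is a hypothetical name, not a PyG function.

```python
import torch


def symmetric_edge_dropout(edge_index, p=0.5, generator=None):
    row, col = edge_index
    # Keep one representative per undirected pair (this also drops self-loops).
    upper = row < col
    row, col = row[upper], col[upper]
    # One Bernoulli sample per undirected pair: kept with probability 1 - p.
    keep_prob = torch.full((row.size(0),), 1.0 - p)
    keep = torch.bernoulli(keep_prob, generator=generator).to(torch.bool)
    row, col = row[keep], col[keep]
    # Mirror the survivors so that both directions appear.
    both = torch.stack([torch.cat([row, col]), torch.cat([col, row])], dim=0)
    if both.size(1) == 0:
        return both  # nothing survived
    # Sort columns lexicographically and drop duplicates (stand-in for coalesce).
    return torch.unique(both, dim=1)


edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
gen = torch.Generator().manual_seed(0)
out = symmetric_edge_dropout(edge_index, p=0.5, generator=gen)
pairs = set(map(tuple, out.t().tolist()))
print(all((j, i) in pairs for i, j in pairs))  # True: the result is symmetric
```

Because each undirected pair gets a single sample, the symmetry check holds for any seed; only which pairs survive changes.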