References:
https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_adj
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/dropout.html
dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False, num_nodes=None, training=True)
Function:
Randomly drops edges from the adjacency matrix (edge_index, edge_attr) with probability p, using samples from a Bernoulli distribution.
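
To make the description above concrete, here is a minimal sketch of Bernoulli edge dropping written in plain PyTorch. The helper name `drop_edges_sketch` is hypothetical and not part of PyTorch Geometric; it covers only the directed case and assumes `edge_index` has shape `[2, num_edges]`.

```python
import torch


def drop_edges_sketch(edge_index, edge_attr=None, p=0.5):
    """Hypothetical helper (not a PyG function): keep each edge with probability 1 - p."""
    if p < 0.0 or p > 1.0:
        raise ValueError(f"Dropout probability has to be between 0 and 1, but got {p}")
    num_edges = edge_index.size(1)
    # One Bernoulli(1 - p) draw per edge; True means the edge is kept.
    keep = torch.bernoulli(torch.full((num_edges,), 1.0 - p)).to(torch.bool)
    kept_index = edge_index[:, keep]
    kept_attr = None if edge_attr is None else edge_attr[keep]
    return kept_index, kept_attr


torch.manual_seed(0)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
edge_attr = torch.arange(6, dtype=torch.float)
kept_index, kept_attr = drop_edges_sketch(edge_index, edge_attr, p=0.5)
print(kept_index.size(1), "of", edge_index.size(1), "edges kept")
```

The same boolean mask indexes both the edges and their attributes, so the two stay aligned after dropping.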
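Parameters:
- **edge_index** (LongTensor): The edge indices.
- **edge_attr** (Tensor, optional): Edge weights or multi-dimensional edge features. (default: None)
- **p** (float, optional): Dropout probability. (default: 0.5)
- **force_undirected** (bool, optional): If set to True, both directions of an undirected edge are either dropped together or kept together. (default: False)
- **num_nodes** (int, optional): The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- **training** (bool, optional): If set to False, this operation is a no-op. (default: True)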
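
The effect of force_undirected can be illustrated with a small plain-PyTorch sketch. The helper `drop_undirected_sketch` is hypothetical and is not the library implementation: it assumes the input already contains both directions of every edge, it discards self-loops because it keeps only pairs with `row < col`, and it does not sort or merge the result.

```python
import torch


def drop_undirected_sketch(edge_index, p=0.5):
    """Hypothetical helper: drop or keep both directions of each undirected edge together."""
    row, col = edge_index
    # Keep one representative per undirected edge (the direction with row < col).
    upper = row < col
    row, col = row[upper], col[upper]
    # One Bernoulli(1 - p) draw per undirected edge.
    keep = torch.bernoulli(torch.full((row.size(0),), 1.0 - p)).to(torch.bool)
    row, col = row[keep], col[keep]
    # Mirror the survivors so that (u, v) and (v, u) are both present.
    return torch.stack([torch.cat([row, col]), torch.cat([col, row])], dim=0)


torch.manual_seed(0)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
undirected_out = drop_undirected_sketch(edge_index, p=0.5)
print(undirected_out)
```

Because the mask is drawn once per undirected edge, the output can never contain (u, v) without (v, u).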

Source code:

    def dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False,
                    num_nodes=None, training=True):
        r"""Randomly drops edges from the adjacency matrix
        :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from
        a Bernoulli distribution.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_attr (Tensor, optional): Edge weights or multi-dimensional
                edge features. (default: :obj:`None`)
            p (float, optional): Dropout probability. (default: :obj:`0.5`)
            force_undirected (bool, optional): If set to :obj:`True`, will either
                drop or keep both edges of an undirected edge.
                (default: :obj:`False`)
            num_nodes (int, optional): The number of nodes, *i.e.*
                :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
            training (bool, optional): If set to :obj:`False`, this operation is a
                no-op. (default: :obj:`True`)
        """

        if p < 0. or p > 1.:
            raise ValueError('Dropout probability has to be between 0 and 1, '
                             'but got {}'.format(p))

        if not training or p == 0.0:
            return edge_index, edge_attr

        N = maybe_num_nodes(edge_index, num_nodes)
        row, col = edge_index

        if force_undirected:
            row, col, edge_attr = filter_adj(row, col, edge_attr, row < col)

        mask = edge_index.new_full((row.size(0), ), 1 - p, dtype=torch.float)
        mask = torch.bernoulli(mask).to(torch.bool)

        row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

        if force_undirected:
            edge_index = torch.stack(
                [torch.cat([row, col], dim=0),
                 torch.cat([col, row], dim=0)], dim=0)
            if edge_attr is not None:
                edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
        else:
            edge_index = torch.stack([row, col], dim=0)

        return edge_index, edge_attr
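
Since the mask in the source is filled with `1 - p` before sampling, each edge survives with probability `1 - p`, so roughly a `(1 - p)` fraction of the edges should remain. The snippet below is a small empirical check of that reasoning in plain PyTorch; the values of `p` and `num_edges` are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
p = 0.3
num_edges = 10_000

# Same construction as the mask in the source: fill with 1 - p, then sample.
mask = torch.bernoulli(torch.full((num_edges,), 1.0 - p)).to(torch.bool)
kept_fraction = mask.float().mean().item()
print(f"kept fraction: {kept_fraction:.3f} (expected to be close to {1 - p:.1f})")
```

Note that, judging from the source, the surviving `edge_attr` values are returned unchanged; they are not rescaled by `1 / (1 - p)` the way activations are in standard dropout.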
