安装

  1. conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
  2. python -c "import torch; print(torch.__version__)"
  3. python -c "import torch; print(torch.version.cuda)"
  4. pip install torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
  5. pip install torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
  6. pip install torch-geometric
  7. pip install torch-cluster -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
  8. pip install torch-spline-conv -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
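  9. # 注意:执行最后两条命令(第 7、8 步,安装 torch-cluster 和 torch-spline-conv)之后会报 glibc 相关的错误。另外,上面 URL 中的 torch-1.11.0+cu102 必须与第 2、3 步打印出的 PyTorch 版本和 CUDA 版本一致(第 1 步没有固定 PyTorch 版本),否则需要换成对应版本的 URL。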

如果下载数据集需要访问外网,可以先在 mu01 节点上使用 proxychains 命令完成下载,然后再到其他节点上运行计算。

通过例子来简单介绍

数据

单个图使用torch_geometric.data.Data(这个对象用来表示同构图,即节点和边的类型只有一种)来表示。
image.png
创建一个图:

  1. import torch
  2. from torch_geometric.data import Data
  3. edge_index = torch.tensor([[0, 1, 1, 2],
  4. [1, 0, 2, 1]], dtype=torch.long)
  5. x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
  6. data = Data(x=x, edge_index=edge_index)
  7. >>> Data(edge_index=[2, 4], x=[3, 1])
print(data.keys)
>>> ['x', 'edge_index']

print(data['x'])
>>> tensor([[-1.0],
            [0.0],
            [1.0]])

for key, item in data:
    print(f'{key} found in data')
>>> x found in data
>>> edge_index found in data

'edge_attr' in data
>>> False

data.num_nodes
>>> 3

data.num_edges
>>> 4

data.num_node_features
>>> 1

data.has_isolated_nodes()
>>> False

data.has_self_loops()
>>> False

data.is_directed()
>>> False

# Transfer data object to GPU.
device = torch.device('cuda')
data = data.to(device)

异构数据

数据集

TUDataset 图分类数据集的网站:https://chrsmrrs.github.io/datasets/

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
>>> ENZYMES(600)

len(dataset)
>>> 600

dataset.num_classes
>>> 6

dataset.num_node_features
>>> 3

上面的代码下载并初始化了 ENZYMES 数据集(共 600 个图,6 个类别,每个节点有 3 维特征)。
下面是一个用于半监督节点分类的数据集 Cora:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()

len(dataset)
>>> 1

dataset.num_classes
>>> 7

dataset.num_node_features
>>> 1433

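关于训练集、验证集和测试集的划分:Cora 只有一个图,划分是通过节点级的布尔掩码 train_mask、val_mask、test_mask 表示的(每个掩码长度都等于节点数 2708),分别标记了 140、500、1000 个节点: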

data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
         train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

data.is_undirected()
>>> True

data.train_mask.sum().item()
>>> 140

data.val_mask.sum().item()
>>> 500

data.test_mask.sum().item()
>>> 1000

Mini-Batch

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    batch
    >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
    #上面 1082 的意思是这个批次中 32 个图的节点加起来共有 1082 个节点,每个节点的特征是 21 维;y 有 32 个标签,即每个图一个标签;batch 向量的长度也是 1082,记录每个节点属于批次中的哪一个图。

    batch.num_graphs
    >>> 32
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    data
    >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    data.num_graphs
    >>> 32

    x = scatter_mean(data.x, data.batch, dim=0)
    x.size()
    >>> torch.Size([32, 21])
 #按 data.batch 把节点分组,对批次中每个图的所有节点特征求平均,得到每个图一个 21 维的图级特征,所以结果形状是 [32, 21](注意:这不是在求平均节点个数)

数据转换

from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])

dataset[0]
>>> Data(pos=[2518, 3], y=[2518])

pos=[2518, 3] 的意思是有 2518 个节点(点云中的点),3 表示每个节点用 3 个维度的坐标来表示此节点的位置;y=[2518] 表示每个节点有一个标签。

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))

dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

上面通过 pre_transform 定义了一个在数据处理后保存到磁盘之前执行的转换:T.KNNGraph(k=6) 把每个点与它最近的 6 个邻居相连,从而把点云转换成图,所以边数为 2518 × 6 = 15108。