安装
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
python -c "import torch; print(torch.__version__)"
python -c "import torch; print(torch.version.cuda)"
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
pip install torch-geometric
pip install torch-cluster -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
pip install torch-spline-conv -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
# 注意:加上最后两条命令(torch-cluster、torch-spline-conv)之后会报 glibc 相关的错误
如果需要外网访问下载数据集,可以先在mu01上面使用proxychains命令下载,然后再使用其他节点运算。
通过例子来简单介绍
数据
单个图使用torch_geometric.data.Data(这个对象用来表示同构图,即节点和边的类型只有一种)来表示。
创建一个图:
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])
print(data.keys)
>>> ['x', 'edge_index']
print(data['x'])
>>> tensor([[-1.0],
[0.0],
[1.0]])
for key, item in data:
    print(f'{key} found in data')
>>> x found in data
>>> edge_index found in data
'edge_attr' in data
>>> False
data.num_nodes
>>> 3
data.num_edges
>>> 4
data.num_node_features
>>> 1
data.has_isolated_nodes()
>>> False
data.has_self_loops()
>>> False
data.is_directed()
>>> False
# Transfer data object to GPU.
device = torch.device('cuda')
data = data.to(device)
异构数据
数据集
一个图分类的数据集网站:https://chrsmrrs.github.io/datasets/
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
>>> ENZYMES(600)
len(dataset)
>>> 600
dataset.num_classes
>>> 6
dataset.num_node_features
>>> 3
上面的代码下载并初始化数据集
下面是一个半监督学习的数据集:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()
len(dataset)
>>> 1
dataset.num_classes
>>> 7
dataset.num_node_features
>>> 1433
关于测试集,验证集的设置:
data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
data.is_undirected()
>>> True
data.train_mask.sum().item()
>>> 140
data.val_mask.sum().item()
>>> 500
data.test_mask.sum().item()
>>> 1000
Mini-Batch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    batch
>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
#上面1082的意思是这个批次中的所有图的节点加起来共有1082个节点,节点的特征是21维。y的标签是32个
batch.num_graphs
>>> 32
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for data in loader:
    data
>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
data.num_graphs
>>> 32
x = scatter_mean(data.x, data.batch, dim=0)
x.size()
>>> torch.Size([32, 21])
#对一个批次中的每个图分别求其所有节点特征的平均值(按 data.batch 分组),得到每个图一个21维的向量,共32个
数据转换
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
dataset[0]
>>> Data(pos=[2518, 3], y=[2518])
pos=[2518,3]的意思是有2518个节点,3表示每个节点用3个维度来表示此节点的位置。
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                   pre_transform=T.KNNGraph(k=6))
dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
pre_transform 定义的是在数据集保存到磁盘之前(即首次处理数据时)应用的转换,只会执行一次;这里使用 KNNGraph(k=6) 根据节点位置为点云构建最近邻图,因此结果中多出了 edge_index。
