环境构建

CPU版本。 但是本人使用的时会报错,找不到torch_sparse
先更改colab的torch版本

  1. !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
  1. !pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
  2. !pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
  3. !pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
  4. !pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
  5. !pip install torch-geometric

因此,改用了GPU版本
下面的代码会根据torch和cuda版本下载对应的PYG版本。
代码来源https://gist.github.com/ameya98/b193856171d11d37ada46458f60e73e7

  1. # Add this in a Google Colab cell to install the correct version of Pytorch Geometric.
  2. import torch
  3. def format_pytorch_version(version):
  4. return version.split('+')[0]
  5. TORCH_version = torch.__version__
  6. TORCH = format_pytorch_version(TORCH_version)
  7. def format_cuda_version(version):
  8. return 'cu' + version.replace('.', '')
  9. CUDA_version = torch.version.cuda
  10. CUDA = format_cuda_version(CUDA_version)
  11. !pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
  12. !pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
  13. !pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
  14. !pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
  15. !pip install torch-geometric

测试

  1. from torch_geometric.data import Data

如果能正常import就说明安装成功了

PyG简单入门

  1. # 边的连接信息
  2. # 注意,无向图的边要定义两次
  3. edge_index = torch.tensor(
  4. [
  5. # 这里表示节点0和1有连接,因为是无向图 那么1和0也有连接
  6. # 上下对应着看
  7. [0, 1, 1, 2],
  8. [1, 0, 2, 1],
  9. ],
  10. # 指定数据类型
  11. dtype=torch.long
  12. )
  13. # 节点的属性信息
  14. x = torch.tensor(
  15. [
  16. # 三个节点
  17. # 每个节点的属性向量维度为1
  18. [-1],
  19. [0],
  20. [1],
  21. ]
  22. )
  23. # 实例化为一个图结构的数据
  24. data = Data(x=x, edge_index=edge_index)
  25. # 查看图数据
  26. print(data) #Data(x=[3, 1], edge_index=[2, 4])
  27. # 图数据中包含什么信息
  28. print(data.keys) #['edge_index', 'x']
  29. # 查看节点的属性信息
  30. print(data['x'])
  31. >>>
  32. tensor([[-1],
  33. [ 0],
  34. [ 1]])
  35. # 节点数
  36. print(data.num_nodes) #3
  37. # 边数
  38. print(data.num_edges) #4
  39. # 节点属性向量的维度
  40. print(data.num_node_features)# 1
  41. # 图中是否有孤立节点
  42. print(data.has_isolated_nodes()) #False
  43. # 图中是否有环
  44. print(data.has_self_loops()) #False
  45. # 是否是有向图
  46. print(data.is_directed()) #False

下载ENZYMES数据集

接下来拿ENZYMES数据集(包含600个图,每个图分为6个类别,图级别的分类)举例如何使用PyG的公共数据集

BRENDA酶数据库,起源于德国不伦瑞克在1987年建立的国家生物技术研究中心(GBF),目前由德国科隆大学生物化学研究所负责运营。BRENDA可以提供酶的分类、命名法、生化反应、专一性、结构、细胞定位、提取方法、文献、应用与改造及相关疾病的数据。人们不仅可以通过互联网(http://www.brenda-enzymes.org)免费获得,还也可以将其作为商业用户的内部数据库(需要请求其分销商Biobase)。

  1. from torch_geometric.datasets import TUDataset
  2. # 导入数据集
  3. dataset = TUDataset(
  4. # 指定数据集的存储位置
  5. # 如果指定位置没有相应的数据集
  6. # PyG会自动下载
  7. root='../data/ENZYMES',
  8. # 要使用的数据集
  9. name='ENZYMES',
  10. )
  11. # 数据集的长度
  12. print(len(dataset))
  13. # 数据集的类别数
  14. print(dataset.num_classes)
  15. # 数据集中节点属性向量的维度
  16. print(dataset.num_node_features)
  17. # 600个图,我们可以根据索引选择要使用哪个图
  18. data = dataset[0]
  19. print(data)
  20. # 随机打乱数据集
  21. dataset = dataset.shuffle()
  22. >>>
  23. 600
  24. 6
  25. 3
  26. Data(edge_index=[2, 168], x=[37, 3], y=[1])
  27. Done!

可以看到数据集下载到了colab自己的文件中了
image.png
Data(edge_index=[2, 168], x=[37, 3], y=[1])的说明:
第一张图有168条有向边,37个节点,每个节点3个label,整张图有一个类别

加载ENZYMES数据集

  1. from torch_geometric.loader import DataLoader
  2. # 数据集
  3. dataset = TUDataset(
  4. root='../data/ENZYMES',
  5. name='ENZYMES',
  6. use_node_attr=True,
  7. )
  8. # 建立数据集加载器
  9. # 每次加载32个数据到内存中
  10. loader = DataLoader(
  11. # 要加载的数据集
  12. dataset=dataset,
  13. # ENZYMES包含600个图
  14. # 每次加载32个
  15. batch_size=32,
  16. # 每次加入进来之后是否随机打乱数据(可以增加模型的泛化性)
  17. shuffle=True
  18. )
  19. for batch in loader:
  20. print(batch)
  21. print(batch.num_graphs)

image.png
y=[32]——32个图。
node_labels: (0,1,2); node_features: 21
dataset[0].num_features
输出是21。说明总的特征数是21
之前 data = dataset[0],输出是Data(edge_index=[2, 168], x=[37, 3], y=[1])
x只有3,而现在这里的x有21。个人推测是输出data的时候,数据集里没有使用use_node_attr=True,所以输出的3是结点的label数。
而现在使用了use_node_attr=True,输出的是节点的特征数
image.png

ptr可能是当前batch目前累计看到的图的节点数量。具体实验文章https://blog.csdn.net/qq_40206371/article/details/120615976