1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch_geometric.nn import GATConv
  6. import torch_geometric.transforms as T

image.png

官方例子

  # Official PyG example: build a tiny graph as a torch_geometric Data object.
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

# Three nodes 0-1-2; every connection is listed in both directions
# (0->1, 1->0, 1->2, 2->1), which is why is_directed() reports False below.
edge_index = torch.tensor(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    dtype=torch.long,
)
# One scalar feature per node.
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)

for value in (data.num_nodes, data.num_edges, data.num_node_features, data.is_directed()):
    print(value)
# Output:
# 3
# 4
# 1
# False

# Convert to a networkx graph (undirected) for plotting.
data_networkx = to_networkx(data, to_undirected=True)

使用plt绘图

  # Draw the converted graph with matplotlib.
import matplotlib.pyplot as plt
import networkx as nx

# Spring layout (node -> (x, y)); no seed is passed, so positions vary between
# runs. `pos` is reused by the hvplot cell that follows.
pos = nx.layout.spring_layout(data_networkx)
# Optionally call plt.figure(figsize=(16, 12)) first for a larger canvas.
nx.draw_networkx_nodes(data_networkx, pos)
nx.draw_networkx_edges(
    data_networkx,
    pos,
    width=1,
    edge_color="black",
)
plt.show()

image.png
如果去掉 `data_networkx = to_networkx(data, to_undirected=True)` 中的 `to_undirected=True`，得到的将是有向图：edge_index 中每条无向边都存成了两条方向相反的有向边，因此每对相邻节点之间会有两条方向相反的边。结果如下：
image.png

使用hvplot绘图

  # Interactive version of the same plot using hvplot.
import matplotlib.pyplot as plt
import networkx as nx
import hvplot.networkx as hvnx

# Reuses `pos` from the matplotlib cell, so both plots share one layout.
node = hvnx.draw_networkx_nodes(data_networkx, pos)
edge = hvnx.draw_networkx_edges(
    data_networkx,
    pos,
    arrowstyle='->',
    edge_width=2,
    colorbar=True,
)
# `*` combines the node and edge plots into one overlay; as the last
# expression of a notebook cell it is what gets displayed.
node * edge

image.png（hvplot 生成的图是可交互的）

封装绘图工具

  # Plotting helper: static matplotlib rendering of a networkx graph.
def draw_with_nx(data_nx, *, pos=None, seed=None):
    """Draw a networkx graph with matplotlib and show the figure.

    Args:
        data_nx: The networkx graph to draw.
        pos: Optional precomputed layout (dict node -> (x, y)). When omitted,
            a spring layout is computed.
        seed: Optional random seed for the spring layout so repeated calls
            produce the same picture. Ignored when ``pos`` is given.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    if pos is None:
        pos = nx.layout.spring_layout(data_nx, seed=seed)
    nx.draw_networkx_nodes(data_nx, pos)
    nx.draw_networkx_edges(data_nx, pos, width=1, edge_color="black")
    plt.show()
  # Plotting helper: interactive hvplot rendering of a networkx graph.
def draw_with_hvplot(data_nx, *, pos=None, seed=None):
    """Build an interactive hvplot rendering of a networkx graph.

    Args:
        data_nx: The networkx graph to draw.
        pos: Optional precomputed layout (dict node -> (x, y)). When omitted,
            a spring layout is computed.
        seed: Optional random seed for the spring layout so repeated calls
            produce the same picture. Ignored when ``pos`` is given.

    Returns:
        The combined node and edge plots (``nodes * edges``); display it in a
        notebook to interact with it.
    """
    import networkx as nx
    import hvplot.networkx as hvnx

    if pos is None:
        pos = nx.layout.spring_layout(data_nx, seed=seed)
    nodes = hvnx.draw_networkx_nodes(data_nx, pos)
    edges = hvnx.draw_networkx_edges(
        data_nx, pos, arrowstyle='->', edge_width=2, colorbar=True
    )
    return nodes * edges

Networkx 内置的Graph Data Example

  # networkx ships graph generators; draw a Barabasi-Albert random graph.
G_bar = nx.barabasi_albert_graph(n=100, m=3)  # 100 nodes, m=3 (see networkx docs)
draw_with_nx(G_bar)
# Another built-in drawing option: nx.draw_kamada_kawai(G_bar)

image.png

  # Same Barabasi-Albert graph, interactive hvplot rendering.
draw_with_hvplot(data_nx=G_bar)

image.png

  # Convert a networkx graph into PyG's Data format.
from torch_geometric import utils

# Use the Barabasi-Albert graph built above (`G` is not defined anywhere).
data = utils.from_networkx(G_bar)

networkx转化成PyG格式
image.png 本质上只有 edge_index 起作用：该 networkx 图没有节点或边属性，所以转换得到的 Data 只包含 edge_index。

  # Only edge_index survives the conversion: the graph carries no attributes.
data.keys
# Output: ['edge_index']
# NOTE(review): newer PyG releases turned `keys` into a method (`data.keys()`);
# confirm against the installed version.

数据集

Networkx Karate Club空手道俱乐部

  # Karate Club graph bundled with networkx, nodes colored by their "club" attribute.
import functools
import operator

import hvplot.networkx as hvnx
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric import utils

G_karate = nx.karate_club_graph()
data_karate = utils.from_networkx(G_karate)
print(data_karate)
# Output: Data(club=[34], edge_index=[2, 156])

print(set(data_karate.club))
# Output: {'Mr. Hi', 'Officer'}

node_color = ["blue", "red"]
node_label = np.array(list(G_karate.nodes))
pos = nx.layout.spring_layout(G_karate)

# Sort the club names: the iteration order of a set of strings can change
# between interpreter runs (hash randomization), which would swap the colors.
clubs = sorted(set(data_karate.club))

# One hvplot node layer per club. The per-club plot is NOT bound to `plt`,
# so matplotlib.pyplot imported above stays usable.
plt_nodes = []
for i, club in enumerate(clubs):
    # Read the "club" attribute directly instead of assuming nodes are 0..n-1.
    nodelist = [n for n, club_of_n in G_karate.nodes(data="club") if club_of_n == club]
    plt_nodes.append(
        hvnx.draw_networkx_nodes(G_karate, pos, nodelist=nodelist, node_color=node_color[i])
    )
plt_edges = hvnx.draw_networkx_edges(
    G_karate, pos, arrowstyle='->', edge_width=2, colorbar=True
)
# Overlay the edges with every per-club node layer (displayed as the cell result).
plt_edges * functools.reduce(operator.mul, plt_nodes)

image.png

Cora Dataset

  # Cora (Planetoid): one graph with features x, labels y and train/val/test masks.
from torch_geometric.datasets import Planetoid

name_data = 'Cora'
dataset_cora = Planetoid(root='/proj/data/', name=name_data)

for obj in (dataset_cora, dataset_cora[0], dataset_cora.data):
    print(type(obj))
# dataset[0] and dataset.data print the same because the dataset holds one graph.
print(dataset_cora[0])
print(dataset_cora.data)
print(dataset_cora.data.keys)
# Output:
# <class 'torch_geometric.datasets.planetoid.Planetoid'>
# <class 'torch_geometric.data.data.Data'>
# <class 'torch_geometric.data.data.Data'>
# Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
# Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
# ['x', 'edge_index', 'y', 'train_mask', 'val_mask', 'test_mask']

print(dataset_cora.data.y)  # one class label per node
# Output: tensor([3, 4, 4, ..., 3, 3, 3])

Enzymes Dataset

  # ENZYMES (TUDataset): 600 small graphs, one graph-level label y per graph.
try:
    # PyG >= 2.0 provides the graph DataLoader in torch_geometric.loader;
    # torch_geometric.data.DataLoader is deprecated there.
    from torch_geometric.loader import DataLoader
except ImportError:
    # Fallback for older PyG releases.
    from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

dataset_enz = TUDataset('./data', name='ENZYMES', use_node_attr=True)

print(type(dataset_enz))
print(type(dataset_enz[0]))
print(type(dataset_enz.data))
print(dataset_enz[0])      # a single graph
print(dataset_enz.data)    # all graphs collated into one storage object
print(dataset_enz.data.keys)
# Output:
# <class 'torch_geometric.datasets.tu_dataset.TUDataset'>
# <class 'torch_geometric.data.data.Data'>
# <class 'torch_geometric.data.data.Data'>
# Data(edge_index=[2, 168], x=[37, 21], y=[1])
# Data(edge_index=[2, 74564], x=[19580, 21], y=[600])
# ['x', 'edge_index', 'y']

画出前3个graph

  # Draw the first three ENZYMES graphs, one figure each.
from torch_geometric.utils import to_networkx

first_three = [dataset_enz[k] for k in range(3)]
for graph_data in first_three:
    data_nx_enz = to_networkx(graph_data)  # directed networkx graph (default)
    draw_with_nx(data_nx_enz)

image.png
image.png
image.png

  # Draw ENZYMES graph 0 on its own with matplotlib.
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

# Build the networkx graph this cell draws (it was referenced but never defined).
# NOTE(review): pass to_undirected=True if an undirected drawing was intended.
data_nx_enz_0 = to_networkx(dataset_enz[0])

pos = nx.layout.spring_layout(data_nx_enz_0)
# Optionally call plt.figure(figsize=(16, 12)) first for a larger canvas.
nx.draw_networkx_nodes(data_nx_enz_0, pos)
nx.draw_networkx_edges(data_nx_enz_0, pos, width=1, edge_color="black")
plt.show()

image.png

Batch Graph

image.png
将Enzymes数据集的前10个graph组成一个batch

  # Merge the first 10 ENZYMES graphs into a single Batch object.
from torch_geometric.data import Batch

b = Batch.from_data_list(dataset_enz[:10])
b
# Output: Batch(batch=[325], edge_index=[2, 1202], ptr=[11], x=[325, 21], y=[10])

# Node indices are offset per graph, so they run 0..324 across the whole batch.
b.edge_index[0]
# Output: tensor([  0,   0,   0,  ..., 324, 324, 324])

将整个数据集按 batch 展示。
可以观察到，batch 向量的长度与节点数（x 的行数）相同：batch 向量记录了每个节点（x 的每一行）属于该 batch 中的哪一个 graph。

  # Iterate over the whole dataset in mini-batches of 64 graphs.
loader = DataLoader(dataset_enz, batch_size=64, shuffle=True)
for batch in loader:
    print(batch)
# Output (shuffle=True, so node/edge counts differ between runs;
# 600 graphs = 9 full batches of 64 + a final batch of 24):
# Batch(batch=[1993], edge_index=[2, 7806], ptr=[65], x=[1993, 21], y=[64])
# Batch(batch=[2098], edge_index=[2, 8016], ptr=[65], x=[2098, 21], y=[64])
# Batch(batch=[2135], edge_index=[2, 7752], ptr=[65], x=[2135, 21], y=[64])
# Batch(batch=[1977], edge_index=[2, 7634], ptr=[65], x=[1977, 21], y=[64])
# Batch(batch=[1946], edge_index=[2, 7482], ptr=[65], x=[1946, 21], y=[64])
# Batch(batch=[2093], edge_index=[2, 7986], ptr=[65], x=[2093, 21], y=[64])
# Batch(batch=[2239], edge_index=[2, 8514], ptr=[65], x=[2239, 21], y=[64])
# Batch(batch=[2053], edge_index=[2, 8008], ptr=[65], x=[2053, 21], y=[64])
# Batch(batch=[2219], edge_index=[2, 8462], ptr=[65], x=[2219, 21], y=[64])
# Batch(batch=[827], edge_index=[2, 2904], ptr=[25], x=[827, 21], y=[24])

image.png

image.png

Two Graph Batch Example

  # Batch two copies of the tiny 3-node example graph to see how batching works.
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.utils import to_networkx

edge_index = torch.tensor(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    dtype=torch.long,
)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# Both graphs share the same x / edge_index tensors.
data1, data2 = (Data(x=x, edge_index=edge_index) for _ in range(2))

g_2 = Batch.from_data_list([data1, data2])
draw_with_nx(to_networkx(g_2, to_undirected=True))

image.png

  # Interactive view of the two-graph batch.
draw_with_hvplot(data_nx=to_networkx(g_2, to_undirected=True))

image.png

  # Inspect how the two graphs were merged into one Batch.
# `batch` maps every node to the index of the graph it came from.
g_2.batch
# Output: tensor([0, 0, 0, 1, 1, 1])

# The second graph's node indices are shifted by 3 (the size of the first graph).
g_2.edge_index
# Output:
# tensor([[0, 1, 1, 2, 3, 4, 4, 5],
#         [1, 0, 2, 1, 4, 3, 5, 4]])

g_2_nx = to_networkx(g_2, to_undirected=True)
g_2_nx.is_directed()
# Output: False

# No edge links the two copies, so the merged graph has two connected components.
print(nx.number_connected_components(g_2_nx))
# Output: 2