import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import torch_geometric.transforms as T
官方例子
import torch
from torch_geometric.data import Data

# Path graph 0 - 1 - 2: each connection appears twice in edge_index,
# once per direction.
edge_index = torch.tensor(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    dtype=torch.long,
)
# One scalar feature per node.
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)

# Basic statistics of the graph, one value per output line.
print(
    data.num_nodes,
    data.num_edges,
    data.num_node_features,
    data.is_directed(),
    sep="\n",
)
>>>
3
4
1
False
# Convert the PyG `Data` object into networkx format.
from torch_geometric.utils import to_networkx
# to_undirected=True asks for an undirected networkx graph.
data_networkx = to_networkx(data, to_undirected=True)
使用plt绘图
import matplotlib.pyplot as plt
import networkx as nx

# Compute one spring layout and share it between the node and edge drawings.
pos = nx.spring_layout(data_networkx)
nx.draw_networkx_nodes(data_networkx, pos=pos)
nx.draw_networkx_edges(
    data_networkx,
    pos=pos,
    width=1,
    edge_color="black",
)
plt.show()
如果去掉 data_networkx = to_networkx(data, to_undirected=True) 中的 to_undirected=True(此时得到的是有向图),结果如下
使用hvplot绘图
# Draw `data_networkx` through hvplot's networkx interface, using the
# node positions held in `pos`.
import matplotlib.pyplot as plt
import networkx as nx
import hvplot.networkx as hvnx
node = hvnx.draw_networkx_nodes(data_networkx, pos)
edge = hvnx.draw_networkx_edges(data_networkx, pos, arrowstyle='->', edge_width=2, colorbar=True)
# Combine the node and edge plots; presumably meant as the final expression
# of a notebook cell so that the combined plot is rendered.
node * edge
封装绘图工具
def draw_with_nx(data_nx):
    """Draw a networkx graph with matplotlib and show the figure.

    Node positions are computed with a spring layout. Nodes are drawn first,
    then edges as black lines of width 1.

    Args:
        data_nx: The networkx graph to draw.

    Returns:
        None. The figure is displayed through ``plt.show()``.
    """
    # Imported locally so the helper can be called from any cell.
    import matplotlib.pyplot as plt
    import networkx as nx

    pos = nx.layout.spring_layout(data_nx)
    nx.draw_networkx_nodes(data_nx, pos)
    nx.draw_networkx_edges(data_nx, pos, width=1, edge_color="black")
    plt.show()
def draw_with_hvplot(data_nx):
    """Build an hvplot drawing of a networkx graph.

    Node positions are computed with a spring layout and shared by the node
    plot and the edge plot.

    Args:
        data_nx: The networkx graph to draw.

    Returns:
        The product ``node * edge`` of the hvplot node and edge plots. Nothing
        is shown by this function itself; the caller must display the result.
    """
    # Imported locally so the helper can be called from any cell.
    import networkx as nx
    import hvplot.networkx as hvnx

    pos = nx.layout.spring_layout(data_nx)
    node = hvnx.draw_networkx_nodes(data_nx, pos)
    edge = hvnx.draw_networkx_edges(
        data_nx, pos, arrowstyle='->', edge_width=2, colorbar=True
    )
    return node * edge
Networkx 内置的Graph Data Example
# Barabasi-Albert graph generated by networkx with n=100 and m=3.
G_bar = nx.barabasi_albert_graph(n=100, m=3)
draw_with_nx(G_bar)
# Alternative layout: nx.draw_kamada_kawai(G_bar)
draw_with_hvplot(G_bar)
from torch_geometric import utils
# Convert the graph built above to PyG format. The original passed `G`,
# a name that is never defined in this file.
data = utils.from_networkx(G_bar)
没有节点/边属性的networkx图转化成PyG格式后只包含edge_index(如果图带有节点属性,例如下文Karate Club的club,属性也会一并转换)
# Notebook-style echo of the attribute names stored on `data`.
data.keys
>>>
['edge_index']
数据集
Networkx Karate Club空手道俱乐部
import networkx as nx
# Load the karate club graph that ships with networkx.
G_karate = nx.karate_club_graph()
# Convert it to PyG format and print the resulting object's summary.
data_karate = utils.from_networkx(G_karate)
print(data_karate)
>>>
Data(club=[34], edge_index=[2, 156])
# Distinct values of the `club` attribute (notebook-style echo).
set(data_karate.club)
>>>
{'Mr. Hi', 'Officer'}
node_color = ["blue", "red"]
node_label = np.array(list(G_karate.nodes))
import functools
import operator

import matplotlib.pyplot as plt
import networkx as nx
import hvplot.networkx as hvnx

pos = nx.layout.spring_layout(G_karate)
plt_nodes = []
# Sort the club names so each club gets the same colour on every run;
# the iteration order of a set of strings can change between runs.
for i, c in enumerate(sorted(set(data_karate.club))):
    nodelist = [n for n in G_karate.nodes if G_karate.nodes[n]['club'] == c]
    # Bound to `club_nodes`, not `plt`, so matplotlib.pyplot is not shadowed.
    club_nodes = hvnx.draw_networkx_nodes(
        G_karate, pos, nodelist=nodelist, node_color=node_color[i]
    )
    plt_nodes.append(club_nodes)
plt_edges = hvnx.draw_networkx_edges(G_karate, pos, arrowstyle='->', edge_width=2, colorbar=True)
# Combine the edge plot with every per-club node plot into one figure.
plt_edges * functools.reduce(operator.mul, plt_nodes)
Cora Dataset
from torch_geometric.datasets import Planetoid

name_data = 'Cora'
dataset_cora = Planetoid(root='/proj/data/', name=name_data)

# Types of the dataset object, of an indexed item, and of its `.data` attribute.
print(
    type(dataset_cora),
    type(dataset_cora[0]),
    type(dataset_cora.data),
    sep="\n",
)
# The indexed item, the `.data` attribute, and the attribute names it holds.
print(
    dataset_cora[0],
    dataset_cora.data,
    dataset_cora.data.keys,
    sep="\n",
)
>>>
<class 'torch_geometric.datasets.planetoid.Planetoid'>
<class 'torch_geometric.data.data.Data'>
<class 'torch_geometric.data.data.Data'>
Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
['x', 'edge_index', 'y', 'train_mask', 'val_mask', 'test_mask']
# Print the tensor stored under `y` (presumably the node class labels; confirm).
print(dataset_cora.data.y)
>>>
tensor([3, 4, 4, ..., 3, 3, 3])
Enzymes Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset_enz = TUDataset('./data', name='ENZYMES', use_node_attr=True)

# Types of the dataset object, of an indexed item, and of its `.data` attribute.
print(
    type(dataset_enz),
    type(dataset_enz[0]),
    type(dataset_enz.data),
    sep="\n",
)
# The first graph, the `.data` attribute, and the attribute names it holds.
print(
    dataset_enz[0],
    dataset_enz.data,
    dataset_enz.data.keys,
    sep="\n",
)
>>>
<class 'torch_geometric.datasets.tu_dataset.TUDataset'>
<class 'torch_geometric.data.data.Data'>
<class 'torch_geometric.data.data.Data'>
Data(edge_index=[2, 168], x=[37, 21], y=[1])
Data(edge_index=[2, 74564], x=[19580, 21], y=[600])
['x', 'edge_index', 'y']
画出前3个graph
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx

# Draw the first three ENZYMES graphs with the shared helper.
for i in range(3):
    data_nx_enz = to_networkx(dataset_enz[i])
    draw_with_nx(data_nx_enz)

# Draw the first graph again with explicit networkx calls. The original used
# `data_nx_enz_0` without ever defining it, so it is built here.
data_nx_enz_0 = to_networkx(dataset_enz[0])
pos = nx.layout.spring_layout(data_nx_enz_0)
nx.draw_networkx_nodes(data_nx_enz_0, pos)
nx.draw_networkx_edges(data_nx_enz_0, pos, width=1, edge_color="black")
plt.show()
Batch Graph
将Enzymes数据集的前10个graph组成一个batch
from torch_geometric.data import Batch
# Collate the first 10 ENZYMES graphs into a single Batch object.
b = Batch.from_data_list(dataset_enz[:10])
# Notebook-style echo of the batch summary.
b
>>>
Batch(batch=[325], edge_index=[2, 1202], ptr=[11], x=[325, 21], y=[10])
# First row of the batched edge_index (notebook-style echo).
b.edge_index[0]
>>>
tensor([ 0, 0, 0, ..., 324, 324, 324])
将整个数据集以batch展示.
可以观察到,每个Batch中batch向量的长度与节点数(x的行数)是一样的。这里的batch向量说明每个节点(x的每一行)属于哪个graph
# Iterate over the whole ENZYMES dataset in shuffled mini-batches of up to
# 64 graphs and print each batch's summary.
loader = DataLoader(dataset_enz, batch_size=64, shuffle=True)
for batch in loader:
    print(batch)
>>>
Batch(batch=[1993], edge_index=[2, 7806], ptr=[65], x=[1993, 21], y=[64])
Batch(batch=[2098], edge_index=[2, 8016], ptr=[65], x=[2098, 21], y=[64])
Batch(batch=[2135], edge_index=[2, 7752], ptr=[65], x=[2135, 21], y=[64])
Batch(batch=[1977], edge_index=[2, 7634], ptr=[65], x=[1977, 21], y=[64])
Batch(batch=[1946], edge_index=[2, 7482], ptr=[65], x=[1946, 21], y=[64])
Batch(batch=[2093], edge_index=[2, 7986], ptr=[65], x=[2093, 21], y=[64])
Batch(batch=[2239], edge_index=[2, 8514], ptr=[65], x=[2239, 21], y=[64])
Batch(batch=[2053], edge_index=[2, 8008], ptr=[65], x=[2053, 21], y=[64])
Batch(batch=[2219], edge_index=[2, 8462], ptr=[65], x=[2219, 21], y=[64])
Batch(batch=[827], edge_index=[2, 2904], ptr=[25], x=[827, 21], y=[24])
Two Graph Batch Example
import torch
from torch_geometric.data import Batch, Data

edge_index = torch.tensor(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    dtype=torch.long,
)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# Two Data objects built from the same feature and edge tensors.
data1 = Data(x=x, edge_index=edge_index)
data2 = Data(x=x, edge_index=edge_index)

# Collate both graphs into one Batch and draw it with each helper.
g_2 = Batch.from_data_list([data1, data2])
draw_with_nx(to_networkx(g_2, to_undirected=True))
draw_with_hvplot(to_networkx(g_2, to_undirected=True))
# Notebook-style echo of the `batch` attribute of the collated object.
g_2.batch
>>>
tensor([0, 0, 0, 1, 1, 1])
# Notebook-style echo of the edge_index of the two-graph batch.
g_2.edge_index
>>>
tensor([[0, 1, 1, 2, 3, 4, 4, 5],
[1, 0, 2, 1, 4, 3, 5, 4]])
# Convert the two-graph batch to networkx, requesting an undirected graph,
# then check whether the result is directed.
g_2_nx = to_networkx(g_2, to_undirected=True)
g_2_nx.is_directed()
>>>
False
# Print how many connected components the batched graph has.
print(nx.number_connected_components(g_2_nx))
>>>
2