当时Google图神经网络的时候,搜到的一份PDF,这里我将这个PDF中的内容运行一遍,并且记录一下其中的内容。
Dense v. Sparse
例子:存储图的边
- 以矩阵形式存储(
dense
)#card=math&code=%5CRightarrow%20O%28%7CV%7C%5E2%29) - 以索引形式存储(
sparse
)%5Cleqslant%20O(%7CV%7C%5E2)#card=math&code=%5CRightarrow%20O%28%7CE%7C%29%5Cleqslant%20O%28%7CV%7C%5E2%29)
import networkx as nx
import matplotlib.pyplot as plt
G = nx.barabasi_albert_graph(100, 3)
_, axes = plt.subplots(1, 2, figsize=(10, 4), gridspec_kw={'wspace': 0.5})
nx.draw_kamada_kawai(G, ax=axes[0], node_size=120)
axes[1].imshow(nx.to_numpy_matrix(G), aspect='auto', cmap='Blues')
axes[0].set_title('$G$')
axes[1].set_title('$\mathbf{A}$')
plt.show()
# print(G.nodes)
# print(G.edges)
上面的networkx
是一个用于研究网络的库,点击链接查看教程。其中重要的部分是barabasi_albert_graph(n,m,seed=None)
这个函数。
- 参数
n
:图中的节点个数 - 参数
m
:一个新的节点和现有的m
个节点进行连接
例如,我们现在只让新的节点和现有节点之间只有一条边的连接G = nx.barabasi_albert_graph(100, 3)
,可以看到效果如下:
Sparse Representations
存储一个大小为的图,其中
- 边:作为矩阵的索引
- 三角:作为矩阵的索引
- 属性:特征矩阵或者
PyTorch中的索引
import torch
mat = torch.arange(12).view(3, 4) # 将0-11拍成一个3行4列的二维矩阵
print(mat)
print(mat[0]) # 打印出第一行
print(mat[:, -1]) # 打印出最后一列
print(mat[:, 2:]) # 打印出从第3行开始的列
print(mat[:, ::3]) # 打印出第1列和第4列
print(mat[:, :3]) # 打印出第1列和到第4列之前的所有列
输出:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]) tensor([0, 1, 2, 3]) tensor([ 3, 7, 11]) tensor([[ 2, 3],
[ 6, 7],
[10, 11]]) tensor([[ 0, 3],
[ 4, 7],
[ 8, 11]]) tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]])
阈值分割
import torch
rnd = torch.rand(1, 3)
print(rnd)
mask = rnd >= 0.5
print(mask)
print(rnd[mask]) # 找到rnd>=0.5的数据
tensor([[0.4009, 0.2571, 0.8022]]) tensor([[False, False, True]]) tensor([0.8022])
索引选择
我们可以把这个部分和上面的部分结合起来,找到某个数的坐标索引
import torch
A = torch.randint(2, (5, 5))
print(A)
idx = A.nonzero().T # 找到不为0的数的位置索引
print(idx)
row, col = idx # row, col = idx[0], idx[1] 将不为0的数据列出
print(A[row, col])
tensor([[1, 1, 0, 0, 1],
[0, 1, 1, 0, 1],
[0, 0, 0, 1, 1],
[0, 1, 0, 0, 1],
[1, 0, 1, 1, 0]]) tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4],
[0, 1, 4, 1, 2, 4, 3, 4, 1, 4, 0, 2, 3]]) tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
数据排序
import torch
rnd = torch.randint(10, (2, 4))
print(rnd)
sort, perm = torch.sort(rnd, dim=-1)
rnd = torch.gather(input=rnd, dim=-1, index=perm)
print(rnd)
tensor([[1, 1, 3, 2],
[7, 8, 3, 8]]) tensor([[1, 1, 2, 3],
[3, 7, 8, 8]])
PyTorchGeometric框架介绍
PyTorch-Geometric的子模型:
nn
:包含了许多的GNN模型,池化层和归一化层data
:用于管理稀疏和密集数据的类datasets
:基本上每个图形,网格和点云的标准基准数据集都有不同类型的任务transform
:数据操作函数utils
andio
:接口函数
例子
下面引入TUDataset数据集中的第一幅图,将其画出来
from torch_geometric.datasets import TUDataset, ModelNet, ShapeNet
from torch_geometric import utils
import networkx as nx
import matplotlib.pyplot as plt
ds = TUDataset(root='./data/', name='PROTEINS')
G = utils.to_networkx(ds[0])
nx.draw_kamada_kawai(G)
plt.show()
制作自己的数据集
可以通过InMemoryDataset
来制作自己的数据集:
- 在
raw_file_names()
中确定需要的数据 - 在
processed_file_names()
确定需要产生的数据 - 使用
download()
和process()
函数 - 在
__init__()
中加载已经处理的数据
通过一个例子,来看一下如何制作数据集
from torch_geometric.data import InMemoryDataset, download_url
from rdkit import Chem
import pandas as pd
class COVID(InMemoryDataset):
url = 'https://github.com/yangkevin2/coronavirus_data/raw/master/data/mpro_xchem.csv'
def __init__(self, root, transform=None, pre_transform=None,pre_filter=None):
super(COVID, self).__init__(root, transform, pre_transform, pre_filter)
# Load processed data
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['mpro_xchem.csv']
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
download_url(self.url, self.raw_dir)
def process(self):
df = pd.read_csv(self.raw_paths[0])
data_list = []
for smiles, label in df.itertuples(False, None):
mol = Chem.MolFromSmiles(smiles) # Read the molecule info
adj = Chem.GetAdjacencyMatrix(mol) # Get molecule structure
# You should extract other features here!
data = Data(num_nodes=adj.shape[0],
edge_index=torch.Tensor(adj).nonzero().T, y=label)
data_list.append(data)
self.data, self.slices = self.collate(data_list)
torch.save((self.data, self.slices), self.processed_paths[0])
covid = COVID(root='./data/COVID/')
G = utils.to_networkx(covid[0])
nx.draw_kamada_kawai(G)
plt.show()