参考来源:
CSDN:PyG-使用 networkx 对 Graph 进行可视化
知乎:在 PyTorch 框架下使用 PyG 和 networkx 对 Graph 进行可视化
NetworkX 文档:Drawing

方法一

根据 networkx 的文档:networkx.drawing.nx_pylab.draw_networkx
我们可以写出来一个非常简单的例子,如下:

  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. G = nx.Graph()
  4. edge_index = [(1, 2), (1, 3), (2, 3), (3, 4)]
  5. G.add_edges_from(edge_index)
  6. nx.draw(G)
  7. plt.show()

运行程序之后,可以得到下面的图,(偷了一个懒,没有加 label 之类的信息)
image.png
这个例子给我们的启发就是,我们可以将 PyG 得到的 edge_index 转成 numpy 的格式,然后传给 nx,下面是根据这个写的一个函数:

PyG 中,边的表示放在了 edge_index 中,由一个二维的矩阵构成,edge_index[0] 表示节点,edge_index[1] 表示另一个节点。

  1. def draw(edge_index, name=None):
  2. G = nx.Graph(node_size=15, font_size=8)
  3. src = edge_index[0].cpu().numpy()
  4. dst = edge_index[1].cpu().numpy()
  5. edgelist = zip(src, dst)
  6. for i, j in edgelist:
  7. G.add_edge(i, j)
  8. plt.figure(figsize=(20, 14)) # 设置画布的大小
  9. nx.draw_networkx(G)
  10. plt.savefig('{}.png'.format(name if name else 'path'))

注:该方法可以用于模型中的 forward 函数,用于分析 covpool 等操作。

下面是与上面思想一致可以直接运行的一个例子

  1. from torch_geometric.datasets import KarateClub
  2. import networkx as nx
  3. import matplotlib.pyplot as plt
  4. dataset = KarateClub()
  5. edge, x, y = dataset[0]
  6. # edge, x, y 每个维度都为2,其中第一维度是name,第二个维度是data
  7. # x表示的是结点,y表示的标签,edge表示的连边, 由两个维度的tensor构成
  8. x_np = x[1].numpy()
  9. y_np = y[1].numpy()
  10. g = nx.Graph()
  11. name, edgeinfo = edge
  12. src = edgeinfo[0].numpy()
  13. dst = edgeinfo[1].numpy()
  14. edgelist = zip(src, dst)
  15. for i, j in edgelist:
  16. g.add_edge(i, j)
  17. nx.draw(g)
  18. plt.savefig('test.png')
  19. plt.show()

方法二

其实,torch_geometric.utils 中已经带有 to_networkx 的函数可以直接将格式为 torch_geometric.data.Data 的数据转换为 networkx.DiGraph 的格式,该格式可以直接 networkx 处理,但是我们提前要得到 torch_geometric.data.Data 的数据格式。

  1. import networkx as nx
  2. from torch_geometric.utils.convert import to_networkx
  3. def draw(Data):
  4. G = to_networkx(Data)
  5. nx.draw(G)
  6. plt.savefig("path.png")
  7. plt.show()

这个一般可以用于在 model 加载数据之前数据的分析,比如下面的例子:

  1. for i, data in enumerate(train_loader):
  2. draw(data)
  3. data = data.to(args.device)
  4. out = model(data)
  5. loss = F.nll_loss(out, data.y)
  6. print("Training loss:{}".format(loss.item()))
  7. loss.backward()
  8. optimizer.step()
  9. optimizer.zero_grad()

上面的函数是在 graph classification 进行分析的一段代码,可以把 batch size 的设置为 1,那么 for 循环中得到就是一个 graph 的数据,在把数据 feed 给模型之前,我们可以通过该方法分析一下原始的数据是什么样子的。