定义网络
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
    """Single GCN layer implemented on top of ``MessagePassing``.

    The forward pass (1) adds self-loops to the edge list, (2) linearly
    projects the node features, (3) computes the per-edge normalisation
    ``deg^-1/2[row] * deg^-1/2[col]``, (4) scales each neighbour message by
    that coefficient and (5) sum-aggregates the messages.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        verbose: If true (default), ``forward`` prints the intermediate
            shapes and tensors, which is useful when following the
            computation step by step. Pass ``False`` to silence the layer.
    """

    def __init__(self, in_channels, out_channels, verbose=True):
        # "Add" aggregation (Step 5).
        # flow='source_to_target' means messages travel from the source
        # node of each edge to its target node.
        super().__init__(aggr='add', flow='source_to_target')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.verbose = verbose

    def forward(self, x, edge_index):
        """Run one round of normalised neighbourhood aggregation.

        Args:
            x: Node feature matrix of shape [N, in_channels].
            edge_index: Edge list of shape [2, E].

        Returns:
            Whatever ``propagate`` returns for the projected features and
            the per-edge normalisation coefficients.
        """
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform the node feature matrix
        # (in the accompanying experiment: 1433 -> 16 -> 7).
        if self.verbose:
            print(f'线性变化之前的Xshape:{x.shape}')
        x = self.lin(x)
        if self.verbose:
            print(f'变化后的x.shape:{x.shape}')

        # Step 3: Compute the normalisation coefficients.
        row, col = edge_index
        # x.size(0) is the number of nodes (2708 for Cora). ``degree`` counts
        # how often every node index occurs in ``col``; for a symmetric
        # adjacency matrix counting over ``row`` gives the same result.
        deg = degree(col, x.size(0), dtype=x.dtype)
        if self.verbose:
            print(f'row:{row}')
            print(f'col:{col}')
            print(f'x.size(0):{x.size(0)}')
        deg_inv_sqrt = deg.pow(-0.5)
        if self.verbose:
            print(f' deg_inv_sqrt:{ deg_inv_sqrt}')
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        if self.verbose:
            print(f'norm:{norm}')

        # Steps 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """Scale every per-edge neighbour feature by its coefficient.

        Args:
            x_j: Per-edge source-node features, shape [E, out_channels].
            norm: Per-edge normalisation coefficients computed in ``forward``.
        """
        # Step 4: Normalize node features. ``view(-1, 1)`` turns ``norm``
        # into a column so it broadcasts across the feature dimension.
        return norm.view(-1, 1) * x_j
初始化
数据集
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
# Load the Planetoid dataset named by ``name_data`` (Cora here), storing its
# files under ./data/.
name_data = 'Cora'
dataset = Planetoid(name=name_data, root='./data/')
class Net(torch.nn.Module):
    """Two-layer GCN that outputs per-node log-probabilities.

    ``torch.nn.Module`` is the base class of all neural-network modules.

    Args:
        in_channels: Input feature size. Defaults to
            ``dataset.num_node_features`` of the module-level ``dataset``.
        hidden_channels: Output size of the first GCN layer (default 16).
        num_classes: Output size of the second GCN layer. Defaults to
            ``dataset.num_classes`` of the module-level ``dataset``.
    """

    def __init__(self, in_channels=None, hidden_channels=16, num_classes=None):
        # Run nn.Module's own initialisation first so that sub-modules
        # assigned below are registered.
        super().__init__()
        # Fall back to the globally loaded dataset so ``Net()`` keeps working
        # exactly as before when no sizes are given.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        """Apply conv -> ReLU -> dropout -> conv -> log-softmax.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The log-softmax over dimension 1 of the second layer's output.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
# Build the model and run a single forward pass on the first graph of the
# dataset; the result is kept in ``out``.
net = Net()
out = net(dataset[0])
输出
由于定义了2个GCNConv层,所以输出中可以看到两次变化的信息。x.size(0)=2708,也就是结点的数量num_nodes。
第一个:特征由1433→16,这是由于forward里的x = self.lin(x)(这里的self.lin = torch.nn.Linear(in_channels, out_channels))。
第二个:特征由16→7
row.shape:torch.Size([13264])
row和col的形状都是[13264](Cora原有的10556条边加上2708个自环)。
deg.shape:torch.Size([2708])
deg里面是每个结点的度。