PyTorch Geometric Temporal库简称GT库 GT库中的相关方法可以分为离散和连续的,本文专注于用于离散方法的数据集
- GT库已经包含了一些数据集,用于比较temporal graph neural networks算法的性能。
与这些数据集相关的机器学习任务是**节点级(node level)和图级(graph level)**的监督学习。
数据集的使用
- 加载数据集
- 加载数据集需要导入
torch_geometric_temporal.data.dataset
- 加载数据集需要导入
如Hungarian Chickenpox Dataset可以通过以下代码段加载
from torch_geometric_temporal.data.dataset import ChickenpoxDatasetLoader
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset() # 这里返回的是StaticGraphDiscreteSignal对象
- GT库包含的数据集不需要手动下载,其数据集加载函数会通过网络自动下载数据集。
- 如果因为网络问题,会导致下载失败,这时就需要修改GT库中的数据集加载函数,将其导向本地数据集
如修改用于加载Hungarian Chickenpox Dataset的方法
torch_geometric_temporal.data.dataset.ChickenpoxDatasetLoader
def _read_web_data(self):
# url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/discrete/chickenpox.json"
# self._dataset = json.loads(urllib.request.urlopen(url).read())
with open('F:\work\PyCharmProject\pytorch_geometric_temporal-master\dataset\discrete\chickenpox.json', 'r',
encoding='utf8')as fp:
self._dataset = json.load(fp)
- 训练集/测试集的分割
GT库提供了函数用于分割数据集,该函数可以接受StaticGraphDiscreteSignal
对象和DynamicGraphDiscreteSignal
对象
from torch_geometric_temporal.data.splitter import discrete_train_test_split
train_dataset, test_dataset = discrete_train_test_split(dataset, train_ratio=0.8)
train_ratio
参数指定训练集所占的比例,因此有test_ratio = 1 - train_ratio
GT库现有数据集
- 目前GT库提供了4个离散时间的数据集
- Hungarian Chickenpox Dataset.
- PedalMe London Dataset.
- Pems Bay Dataset.
- Metr LA Dataset.
Hungarian Chickenpox Dataset
- 数据集介绍
2004年至2014年匈牙利县一级水痘病例的数据集。底层的图是静态的,结点是县,边指示县的临街县。
样本 | 518 |
---|---|
边数(static) | 102 |
边权值 | 无 |
结点(static) | 20 |
结点特征(dynamic) | 4 |
- 结点
结点特征是每周水痘病例的间隔计数(Vertex features are lagged weekly counts of the chickenpox cases),
- 每张静态图有
20
个县,即**20**
个结点
- 每个结点拥有
4
个间隔(4lags)计数,即4
个特征**(我认为这里体现了时序的概念)**
- 边
根据观察,边是没有时序变化的
- 每张静态图共有
102
条边,每条边指示了县(结点)的邻接关系
- 边的权值统一为1,因为仅指示了结点的邻接关系
- 标签
标签是下一周的案例数(The target is the weekly number of cases for the upcoming week)
- 共有
518
个标签,对应518
周的水痘间隔计数
- 每个标签有
20
个值,对于一张静态图中的20
个结点(20
个县)
- 实践证明
- 可以看到,每次输入
- 结点特征:
x.shape = [20, 4]
; - 边:
edge_index.shape = [2, 102]
; - 标签:
y.shape = [20]
- 结点特征:
PedalMe London Dataset
- 数据集介绍
该数据集是PedalMe公司在2020至2021年期间,在伦敦交付的订单数量。数据集由基于邻近度的加权邻接矩阵和2020年和2021年每周需求的时间序列组成。有两个具体的相关任务:
- 地区级需求预测。
- 伦敦水平的需求预测。
结点是地区需求的时间序列;每张图是全连通图,每条边有权值,表示结点之间的距离
样本 | 30 |
---|---|
边数(static) | 225 |
边权值 | 有 |
结点(static) | 15 |
结点特征(dynamic) | 4 |
- 结点
**
- 边
- 边权值
**
- 标签
Pems Bay Dataset
数据集论文 —— Chickenpox Cases in Hungary: a Benchmark Dataset for Spatiotemporal Signal Processing with Graph Neural Networks
- 摘要
Recurrent graph convolutional neural networks是一种用于处理spatio-temporal signal的机器学习技术。新提出的graph neural network architectures在交通或天气预报等标准任务上进行了反复评估。本文提出Chickenpox Cases in Hungary dataset作为比较graph neural network architectures的新数据集。我们的时间序列分析和预测实验表明,Chickenpox Cases in Hungary dataset足以比较新型recurrent graph neural network architectures。
- 评估结果