TensorFlow Datasets 是一个开箱即用的数据集集合,包含数十种常用的机器学习数据集。通过简单的几行代码即可将数据以 tf.data.Dataset 的格式载入。关于 tf.data.Dataset 的使用可参考 tf.data。
该工具是一个独立的 Python 包,可以通过pip下载:
pip install tensorflow-datasets
pip install -i https://mirror.baidu.com/pypi/simple tensorflow-datasets
使用:
import tensorflow as tf
import tensorflow_datasets as tfds
提示:在使用 TensorFlow Datasets 时,可能需要设置代理。较为简易的方式是设置 HTTPS_PROXY 环境变量,即
export HTTPS_PROXY=http://代理服务器IP:端口
tfds.load 方法返回一个 tf.data.Dataset 对象。部分重要的参数如下:
as_supervised :若为 True,则根据数据集的特性,将数据集中的每行元素整理为有监督的二元组 (input, label) (即 “数据 + 标签”)形式,否则数据集中的每行元素为包含所有特征的字典。
split:指定返回数据集的特定部分。若不指定,则返回整个数据集。一般有 tfds.Split.TRAIN (训练集)和 tfds.Split.TEST (测试集)选项。
TensorFlow Datasets 当前支持的数据集可在 官方文档 查看,或者也可以使用 tfds.list_builders() 查看。
当得到了 tf.data.Dataset 类型的数据集后,我们即可使用 tf.data 对数据集进行各种预处理以及读取数据。例如:
# 使用 TessorFlow Datasets 载入“tf_flowers”数据集
dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
# 对 dataset 进行大小调整、打散和分批次操作
dataset = dataset.map(lambda img, label: (tf.image.resize(img, [224, 224]) / 255.0, label)) \
.shuffle(1024) \
.batch(32)
# 迭代数据
for images, labels in dataset:
# 对images和labels进行操作