数据集:快速了解

tf.data 模块包含了一组类,这些类可以让你轻松的加载数据、操作数据,并将数据传送到您的模型中。文档通过如下这两个简单的例子来介绍该 API:

  • 从 numpy 数组读取内存数据。
  • 逐行读取 csv 文件。

基本输入

学习如何获取数组的片段,是开始学习 tf.data 最简单的方式。

Premade Estimators 一节在文件 iris_data.py 中定义了 train_input_fn,它可以将数据传输到 Estimator:

  1. def train_input_fn(features, labels, batch_size):
  2. """一个用来训练的输入函数"""
  3. # 将输入值转化为数据集。
  4. dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
  5. # 混排、重复、批处理样本。
  6. dataset = dataset.shuffle(1000).repeat().batch(batch_size)
  7. # 返回数据集
  8. return dataset

下面我们来对这个函数做更仔细的分析。

参数

这个函数一共需要三个参数。如果一个参数的期望类型是 “array”(数组),那么它将可以接受几乎所有可以用 numpy.array 来转化为数组的值。我们可以看到只有一个例外:tuple,它对 Datasets 有特殊的含义。

  • features:一个形如 {'feature_name':array} 的数据字典(或者是 DataFrame),它包含了原始的输入特征。
  • labels:一个包含每个样本的 label 的数组。
  • batch_size:一个指示所需批量大小的整数。

premade_estimator.py 中,我们使用 iris_data.load_data() 函数来检索虹膜数据。 你可以运行该函数,并按如下方式解压结果:

  1. import iris_data
  2. # 获取数据
  3. train, test = iris_data.load_data()
  4. features, labels = train

然后用像下面这样的一行代码,将数据传递给 input 函数:

  1. batch_size=100
  2. iris_data.train_input_fn(features, labels, batch_size)

让我们来具体看看 train_input_fn() 函数。

(数组)片段

函数首先使用 tf.data.Dataset.from_tensor_slices 函数来创建一个 tf.data.Dataset,表示数组的切片。数组在第一维度被切片。例如,包含 MNIST 的数组的形状为 (60000, 28, 28)。它将传递给 from_tensor_slices,然后返回一个 Dataset 对象,对象中包含 60000 个切片,每一个都是一个 28x28 的图像。

返回这个 Dataset 的代码如下所示:

  1. train, test = tf.keras.datasets.mnist.load_data()
  2. mnist_x, mnist_y = train
  3. mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
  4. print(mnist_ds)

这将打印下一行,显示 dataset 中的项 shapestypes。注意,Dataset 不知道它自己包含的项数。

  1. <TensorSliceDataset shapes: (28,28), types: tf.uint8>

上述的 Dataset 表示数组的简单集合,但数据集比这更复杂。Dataset 可以透明地处理任何嵌套的字典或元组组合(或者 namedtuple)。

例如,将 irls 的 features 转换为标准 python 字典之后,你可以将数组字典转换为字典的 Dataset,如下所示:

  1. dataset = tf.data.Dataset.from_tensor_slices(dict(features))
  2. print(dataset)
  1. <TensorSliceDataset
  2. shapes: {
  3. SepalLength: (), PetalWidth: (),
  4. PetalLength: (), SepalWidth: ()},
  5. types: {
  6. SepalLength: tf.float64, PetalWidth: tf.float64,
  7. PetalLength: tf.float64, SepalWidth: tf.float64}
  8. >

这里我们可以发现,当 Dataset 包含了结构化的元素时,Datasetshapestypes 就会采用相同结构。这个数据集包含了 scalars 字典,并且都是 tf.float64 类型。

iris 的第一行 train_input_fn 使用相同的功能,但是增加了一层结构。它创建了一个包含 (features_dict, label) 数据对的数据集。

以下代码表明,标签是类型为 int64 的标量:

  1. # 将输入转化为数据集。
  2. dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
  3. print(dataset)
  1. <TensorSliceDataset
  2. shapes: (
  3. {
  4. SepalLength: (), PetalWidth: (),
  5. PetalLength: (), SepalWidth: ()},
  6. ()),
  7. types: (
  8. {
  9. SepalLength: tf.float64, PetalWidth: tf.float64,
  10. PetalLength: tf.float64, SepalWidth: tf.float64},
  11. tf.int64)>

操作

目前,Dataset 会按照固定顺序遍历数据一次,且一次只能生成一个元素。在可以用于训练之前,它需要进一步的处理。幸运的是,tf.data.Dataset 类提供了方法来让数据为训练作出更好的准备。train_input_fn 的下一行代码就利用了几个这样的方法:

  1. # 样本的混排、重复、批处理。
  2. dataset = dataset.shuffle(1000).repeat().batch(batch_size)

tf.data.Dataset.shuffle 方法在传递时使用固定大小的缓冲区对其进行清洗。在这种情况下,buffer_size 大于 Dataset 中的示例数,确保数据被完全清洗(Iris 数据集只包含 150 个示例)。

tf.data.Dataset.repeat 方法在 Dataset 结束的时将它重启。如果要限制重复的次数,设置 count 参数。

tf.data.Dataset.batch 方法将会收集一定数量的样本并入栈,以此创建一个批次。这个操作会为样本的 shape 增加一个维度,且新的维度将作为第一维。如下代码在 MNIST 数据集上相对早地应用了 batch 方法,导致 Dataset 包含了表示 (28,28) 图像的三维数组:

  1. print(mnist_ds.batch(100))
  1. <BatchDataset
  2. shapes: (?, 28, 28),
  3. types: tf.uint8>

注意,因为最后一个批次将会有比较少的元素,因此数据集的批量大小是不确定的。

train_input_fn 中,批处理之后,数据集 包含元素们的一维向量,这些一维向量的前面部分是:

  1. print(dataset)
  1. <TensorSliceDataset
  2. shapes: (
  3. {
  4. SepalLength: (?,), PetalWidth: (?,),
  5. PetalLength: (?,), SepalWidth: (?,)},
  6. (?,)),
  7. types: (
  8. {
  9. SepalLength: tf.float64, PetalWidth: tf.float64,
  10. PetalLength: tf.float64, SepalWidth: tf.float64},
  11. tf.int64)>

返回

此时,Dataset 包含 (features_dict, labels) 对。这是 trainevaluate 方法所期望的格式,因此 input_fn 将返回数据集。

在使用 predict 方法时,可以/应该省略 labels

读取 CSV 文件

现实中对 Dataset 类最常见的应用是从磁盘的文档中获取数据流。tf.data 模块包括了一系列的文件读取器。我们来看看如何使用 Dataset 从 csv 文件中分析虹膜数据集。

如下对 iris_data.maybe_download 函数的调用,将会在必要的时候下载数据,并返回结果文件的路径:

  1. import iris_data
  2. train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn 函数包括了一个用 Dataset 解析 csv 文件的替代方案。

让我们来看看如何构建一个兼容 Estimator 的、可以读取本地文件的输入函数。

建立 Dataset

我们从建立一个 tf.data.TextLineDataset 对象开始,这个对象一次只读取文件的一行。之后,调用 tf.data.Dataset.skip 方法,跳过文件的第一行——这是文件的头部,而不是样本:

  1. ds = tf.data.TextLineDataset(train_path).skip(1)

建立一个 csv 行解析器

我们从建立一个可以解析一行的函数开始。

如下的 iris_data.parse_line 函数完成了这个目标,它使用了 tf.decode_csv 方法以及一些简单的 python 代码:

为了生成必需的 (features, label) 数据对,我们必须解析数据集内的每一行。如下的 _parse_line 函数调用了 tf.decode_csv 来将单独一行解析为特征和标签。因为 Estimators 需要特征以字典的方式展现,我们就依靠 python 内建的 dictzip 函数来建立这个字典。特征的名字是字典的键值 key。然后,调用字典的 pop 方法来从特征字典中移除标签字段:

  1. # 描述文本列的元数据
  2. COLUMNS = ['SepalLength', 'SepalWidth',
  3. 'PetalLength', 'PetalWidth',
  4. 'label']
  5. FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
  6. def _parse_line(line):
  7. # 将行解码到 fields 中
  8. fields = tf.decode_csv(line, FIELD_DEFAULTS)
  9. # 将结果打包成字典
  10. features = dict(zip(COLUMNS,fields))
  11. # 将标签从特征中分离
  12. label = features.pop('label')
  13. return features, label

解析多行

当数据集将被传输到一个模型中时,它有很多操作数据的方法。其中,使用最多的是 tf.data.Dataset.map,它将转换应用到 Dataset 的每个元素中。

这个 map 方法接受一个 map_func 参数,这个参数描述了 Dataset 中的每一个元素应该如何被转化。

数据集:快速了解 - 图1

tf.data.Dataset.map 方法将会对 Dataset 中的每一个元素应用 map_func 来完成它们的转化。

因此,为了在多行数据被从 csv 文件中读取出来的时候解析它们,我们为 map 方法提供 _parse_line 函数:

  1. ds = ds.map(_parse_line)
  2. print(ds)
  1. <MapDataset
  2. shapes: (
  3. {SepalLength: (), PetalWidth: (), ...},
  4. ()),
  5. types: (
  6. {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
  7. tf.int32)>

现在,数据集中包含的是 (features, label) 数据对,而不是简单的字符串标量了。

iris_data.csv_input_fn 函数的余下部分和 Basic input 中介绍的 iris_data.train_input_fn 函数相同。

实践

这个函数可以作为 iris_data.train_input_fn 的替代。它可以像如下这样,来给 estimator 提供数据:

  1. train_path, test_path = iris_data.maybe_download()
  2. # 所有的输入都是数字
  3. feature_columns = [
  4. tf.feature_column.numeric_column(name)
  5. for name in iris_data.CSV_COLUMN_NAMES[:-1]]
  6. # 构建 estimator
  7. est = tf.estimator.LinearClassifier(feature_columns,
  8. n_classes=3)
  9. # 训练 estimator
  10. batch_size = 100
  11. est.train(
  12. steps=1000,
  13. input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))

Estimator 期望 input_fn 没有任何参数。要解除这个限制,我们使用 lambda 来捕获参数并提供预期的接口。

总结

为了从不同的数据源中便捷的读取数据,tf.data 模块提供了类和函数的集合。除此之外,tf.data 有简单并且强大的方法,来应用各种标准和自定义转换。

现在你已经基本了解了如何为 Estimator 高效的获取数据。(作为扩展)接下来可以思考如下的文档: