如果需要训练的数据大小不大,例如不到1G,那么可以直接全部读入内存中进行训练,这样一般效率最高。 但如果需要训练的数据很大,例如超过10G,无法一次载入内存,那么通常需要在训练的过程中分批逐渐读入。 使用 tf.data API 可以构建数据输入管道,轻松处理大量的数据,不同的数据格式,以及不同的数据转换。

可以从 Numpy array, Pandas DataFrame, Python generator, csv文件, 文本文件, 文件路径, tfrecords文件等方式构建数据管道。

1、示例

  1. import tensorflow as tf
  2. import numpy as np
  3. from sklearn import datasets
  4. iris = datasets.load_iris()
  5. ds1 = tf.data.Dataset.from_tensor_slices((iris["data"],iris["target"]))
  6. for features,label in ds1.take(5):
  7. print(features,label)

2、