本文主要介绍如何使用 TensorFlow 2.0 提供的 tf.data API 从文件夹中加载图像数据集并对数据集进行预处理。

获取图像路径

  1. import pathlib
  2. import tensorflow as tf
  3. dataset_dir = pathlib.Path('./dataset')
  4. # 使用 glob 搜索文件夹下相关文件,获得所有文件路径列表
  5. all_image_paths = list(dataset_dir.glob('*.jpg'))
  6. all_image_paths = [str(path) for path in all_image_paths]
  7. train_data_total = len(all_image_paths)

图像读取及预处理

  1. def load_and_preprocess_image(path):
  2. """
  3. 读取图像并对图像进行预处理
  4. """
  5. image = tf.io.read_file(path) # 读取图像
  6. image = tf.image.decode_jpeg(image, channels=3) # 对图像进行编码
  7. image = tf.image.resize(image, [192, 192]) # 调整图像尺寸
  8. image /= 255.0 # 调整像素值到合适的范围
  9. return image

构建数据集

请在文件头添加以下常量:

  1. AUTOTUNE = tf.data.experimental.AUTOTUNE

使用 from_tensor_slices() 方法,根据第一步中指定的所有文件的路径列表 all_image_paths 构建 tf.data.Dataset 数据集。

  1. images_path_tensor = tf.data.Dataset.from_tensor_slices(all_image_paths)
  2. train_images = images_path_tensor.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

可以使用如下方法查看数据集中的图像:

  1. plt.figure(figsize=(8,8))
  2. for n, image in enumerate(train_images.take(4)):
  3. plt.subplot(2,2,n+1)
  4. plt.imshow(image)
  5. plt.show()

打乱与分割