本文主要介绍如何使用 TensorFlow 2.0 提供的 tf.data
API 从文件夹中加载图像数据集并对数据集进行预处理。
获取图像路径
import pathlib
import tensorflow as tf
dataset_dir = pathlib.Path('./dataset')
# 使用 glob 搜索文件夹下相关文件,获得所有文件路径列表
all_image_paths = list(dataset_dir.glob('*.jpg'))
all_image_paths = [str(path) for path in all_image_paths]
train_data_total = len(all_image_paths)
图像读取及预处理
def load_and_preprocess_image(path):
"""
读取图像并对图像进行预处理
"""
image = tf.io.read_file(path) # 读取图像
image = tf.image.decode_jpeg(image, channels=3) # 对图像进行编码
image = tf.image.resize(image, [192, 192]) # 调整图像尺寸
image /= 255.0 # 调整像素值到合适的范围
return image
构建数据集
请在文件头添加以下常量:
AUTOTUNE = tf.data.experimental.AUTOTUNE
使用 from_tensor_slices()
方法,根据第一步中指定的所有文件的路径列表 all_image_paths
构建 tf.data.Dataset
数据集。
images_path_tensor = tf.data.Dataset.from_tensor_slices(all_image_paths)
train_images = images_path_tensor.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
可以使用如下方法查看数据集中的图像:
plt.figure(figsize=(8,8))
for n, image in enumerate(train_images.take(4)):
plt.subplot(2,2,n+1)
plt.imshow(image)
plt.show()