本文主要介绍如何使用 TensorFlow 2.0 提供的 tf.data API 从文件夹中加载图像数据集并对数据集进行预处理。
获取图像路径
import pathlibimport tensorflow as tfdataset_dir = pathlib.Path('./dataset')# 使用 glob 搜索文件夹下相关文件,获得所有文件路径列表all_image_paths = list(dataset_dir.glob('*.jpg'))all_image_paths = [str(path) for path in all_image_paths]train_data_total = len(all_image_paths)
图像读取及预处理
def load_and_preprocess_image(path):"""读取图像并对图像进行预处理"""image = tf.io.read_file(path) # 读取图像image = tf.image.decode_jpeg(image, channels=3) # 对图像进行编码image = tf.image.resize(image, [192, 192]) # 调整图像尺寸image /= 255.0 # 调整像素值到合适的范围return image
构建数据集
请在文件头添加以下常量:
AUTOTUNE = tf.data.experimental.AUTOTUNE
使用 from_tensor_slices() 方法,根据第一步中指定的所有文件的路径列表 all_image_paths 构建 tf.data.Dataset 数据集。
images_path_tensor = tf.data.Dataset.from_tensor_slices(all_image_paths)train_images = images_path_tensor.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
可以使用如下方法查看数据集中的图像:
plt.figure(figsize=(8,8))for n, image in enumerate(train_images.take(4)):plt.subplot(2,2,n+1)plt.imshow(image)plt.show()
