编写自定义数据集,数据加载器和转换

原文: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

注意

单击此处的下载完整的示例代码

作者Sasank Chilamkurthy

解决任何机器学习问题都需要花费大量精力来准备数据。 PyTorch 提供了许多工具来简化数据加载过程,并有望使代码更具可读性。 在本教程中,我们将了解如何从非空的数据集中加载和预处理/增强数据。

要运行本教程,请确保已安装以下软件包:

  • scikit-image:用于图像 io 和变换
  • pandas:用于更轻松的 csv 解析
  1. from __future__ import print_function, division
  2. import os
  3. import torch
  4. import pandas as pd
  5. from skimage import io, transform
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from torch.utils.data import Dataset, DataLoader
  9. from torchvision import transforms, utils
  10. # Ignore warnings
  11. import warnings
  12. warnings.filterwarnings("ignore")
  13. plt.ion() # interactive mode

我们要处理的数据集是面部姿势数据集。 这意味着将对面部进行如下注释:

../_images/landmarked_face2.png

总体上,每个面孔都标注了 68 个不同的界标点。

Note

此处下载数据集,将图像存放于名为“ data / faces /”的目录中。 该数据集实际上是通过对来自标记为“面部”的 imagenet 上的一些图像应用出色的 dlib 姿态估计生成的。

数据集带有一个带注释的 csv 文件,如下所示:

  1. image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
  2. 0805personali01.jpg,27,83,27,98, ... 84,134
  3. 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们快速阅读 CSV 并获取(N,2)数组中的注释,其中 N 是地标数。

  1. landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')
  2. n = 65
  3. img_name = landmarks_frame.iloc[n, 0]
  4. landmarks = landmarks_frame.iloc[n, 1:]
  5. landmarks = np.asarray(landmarks)
  6. landmarks = landmarks.astype('float').reshape(-1, 2)
  7. print('Image name: {}'.format(img_name))
  8. print('Landmarks shape: {}'.format(landmarks.shape))
  9. print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出:

  1. Image name: person-7.jpg
  2. Landmarks shape: (68, 2)
  3. First 4 Landmarks: [[32. 65.]
  4. [33. 76.]
  5. [34. 86.]
  6. [34. 97.]]

让我们编写一个简单的辅助函数来显示图像及其地标,并使用它来显示示例。

  1. def show_landmarks(image, landmarks):
  2. """Show image with landmarks"""
  3. plt.imshow(image)
  4. plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
  5. plt.pause(0.001) # pause a bit so that plots are updated
  6. plt.figure()
  7. show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
  8. landmarks)
  9. plt.show()

../_images/sphx_glr_data_loading_tutorial_001.png

数据集类

torch.utils.data.Dataset是代表数据集的抽象类。 您的自定义数据集应继承Dataset并覆盖以下方法:

  • __len__,以便len(dataset)返回数据集的大小。
  • __getitem__支持索引,以便可以使用dataset[i]获取第编写自定义数据集,数据加载器和转换 - 图3个样本

让我们为面部轮廓数据集创建一个数据集类。 我们将在__init__中读取 csv,但将图像读取留给__getitem__。 由于所有图像不会立即存储在内存中,而是根据需要读取,因此可以提高存储效率。

我们的数据集样本将是 dict {'image': image, 'landmarks': landmarks}。 我们的数据集将使用可选参数transform,以便可以将任何所需的处理应用于样本。 我们将在下一部分中看到transform的有用性。

  1. class FaceLandmarksDataset(Dataset):
  2. """Face Landmarks dataset."""
  3. def __init__(self, csv_file, root_dir, transform=None):
  4. """
  5. Args:
  6. csv_file (string): Path to the csv file with annotations.
  7. root_dir (string): Directory with all the images.
  8. transform (callable, optional): Optional transform to be applied
  9. on a sample.
  10. """
  11. self.landmarks_frame = pd.read_csv(csv_file)
  12. self.root_dir = root_dir
  13. self.transform = transform
  14. def __len__(self):
  15. return len(self.landmarks_frame)
  16. def __getitem__(self, idx):
  17. if torch.is_tensor(idx):
  18. idx = idx.tolist()
  19. img_name = os.path.join(self.root_dir,
  20. self.landmarks_frame.iloc[idx, 0])
  21. image = io.imread(img_name)
  22. landmarks = self.landmarks_frame.iloc[idx, 1:]
  23. landmarks = np.array([landmarks])
  24. landmarks = landmarks.astype('float').reshape(-1, 2)
  25. sample = {'image': image, 'landmarks': landmarks}
  26. if self.transform:
  27. sample = self.transform(sample)
  28. return sample

让我们实例化该类并遍历数据样本。 我们将打印前 4 个样本的大小并显示其地标。

  1. face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
  2. root_dir='data/faces/')
  3. fig = plt.figure()
  4. for i in range(len(face_dataset)):
  5. sample = face_dataset[i]
  6. print(i, sample['image'].shape, sample['landmarks'].shape)
  7. ax = plt.subplot(1, 4, i + 1)
  8. plt.tight_layout()
  9. ax.set_title('Sample #{}'.format(i))
  10. ax.axis('off')
  11. show_landmarks(**sample)
  12. if i == 3:
  13. plt.show()
  14. break

../_images/sphx_glr_data_loading_tutorial_002.png

输出:

  1. 0 (324, 215, 3) (68, 2)
  2. 1 (500, 333, 3) (68, 2)
  3. 2 (250, 258, 3) (68, 2)
  4. 3 (434, 290, 3) (68, 2)

Transforms 变换

从上面可以看到的一个问题是样本的大小不同。 大多数神经网络期望图像的大小固定。 因此,我们将需要编写一些预处理代码。 让我们创建三个转换:

  • Rescale:缩放图像
  • RandomCrop:从图像中随机裁剪。 这是数据增强。
  • ToTensor:将 numpy 图像转换为 torch 图像(我们需要交换轴)。

我们会将它们编写为可调用的类,而不是简单的函数,这样就不必每次调用转换时都传递其参数。 为此,我们只需要实现__call__方法,如果需要,还可以实现__init__方法。 然后我们可以使用这样的变换:

  1. tsfm = Transform(params)
  2. transformed_sample = tsfm(sample)

在下面观察如何将这些变换同时应用于图像和地标。

  1. class Rescale(object):
  2. """Rescale the image in a sample to a given size.
  3. Args:
  4. output_size (tuple or int): Desired output size. If tuple, output is
  5. matched to output_size. If int, smaller of image edges is matched
  6. to output_size keeping aspect ratio the same.
  7. """
  8. def __init__(self, output_size):
  9. assert isinstance(output_size, (int, tuple))
  10. self.output_size = output_size
  11. def __call__(self, sample):
  12. image, landmarks = sample['image'], sample['landmarks']
  13. h, w = image.shape[:2]
  14. if isinstance(self.output_size, int):
  15. if h > w:
  16. new_h, new_w = self.output_size * h / w, self.output_size
  17. else:
  18. new_h, new_w = self.output_size, self.output_size * w / h
  19. else:
  20. new_h, new_w = self.output_size
  21. new_h, new_w = int(new_h), int(new_w)
  22. img = transform.resize(image, (new_h, new_w))
  23. # h and w are swapped for landmarks because for images,
  24. # x and y axes are axis 1 and 0 respectively
  25. landmarks = landmarks * [new_w / w, new_h / h]
  26. return {'image': img, 'landmarks': landmarks}
  27. class RandomCrop(object):
  28. """Crop randomly the image in a sample.
  29. Args:
  30. output_size (tuple or int): Desired output size. If int, square crop
  31. is made.
  32. """
  33. def __init__(self, output_size):
  34. assert isinstance(output_size, (int, tuple))
  35. if isinstance(output_size, int):
  36. self.output_size = (output_size, output_size)
  37. else:
  38. assert len(output_size) == 2
  39. self.output_size = output_size
  40. def __call__(self, sample):
  41. image, landmarks = sample['image'], sample['landmarks']
  42. h, w = image.shape[:2]
  43. new_h, new_w = self.output_size
  44. top = np.random.randint(0, h - new_h)
  45. left = np.random.randint(0, w - new_w)
  46. image = image[top: top + new_h,
  47. left: left + new_w]
  48. landmarks = landmarks - [left, top]
  49. return {'image': image, 'landmarks': landmarks}
  50. class ToTensor(object):
  51. """Convert ndarrays in sample to Tensors."""
  52. def __call__(self, sample):
  53. image, landmarks = sample['image'], sample['landmarks']
  54. # swap color axis because
  55. # numpy image: H x W x C
  56. # torch image: C X H X W
  57. image = image.transpose((2, 0, 1))
  58. return {'image': torch.from_numpy(image),
  59. 'landmarks': torch.from_numpy(landmarks)}

撰写变换

现在,我们将转换应用于样本。

假设我们要将图片的较短边重新缩放为 256,然后从中随机裁剪一个尺寸为 224 的正方形。 也就是说,我们要组成RescaleRandomCrop转换。 torchvision.transforms.Compose是一个简单的可调用类,它使我们可以执行此操作。

  1. scale = Rescale(256)
  2. crop = RandomCrop(128)
  3. composed = transforms.Compose([Rescale(256),
  4. RandomCrop(224)])
  5. # Apply each of the above transforms on sample.
  6. fig = plt.figure()
  7. sample = face_dataset[65]
  8. for i, tsfrm in enumerate([scale, crop, composed]):
  9. transformed_sample = tsfrm(sample)
  10. ax = plt.subplot(1, 3, i + 1)
  11. plt.tight_layout()
  12. ax.set_title(type(tsfrm).__name__)
  13. show_landmarks(**transformed_sample)
  14. plt.show()

../_images/sphx_glr_data_loading_tutorial_003.png

遍历数据集

让我们将所有这些放在一起,以创建具有组合转换的数据集。 总而言之,每次采样此数据集时:

  • 从文件中即时读取图像
  • 转换应用于读取的图像
  • 由于其中一种转换是随机的,因此数据是在采样时进行增强

我们可以像以前一样使用for i in range循环遍历创建的数据集。

  1. transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
  2. root_dir='data/faces/',
  3. transform=transforms.Compose([
  4. Rescale(256),
  5. RandomCrop(224),
  6. ToTensor()
  7. ]))
  8. for i in range(len(transformed_dataset)):
  9. sample = transformed_dataset[i]
  10. print(i, sample['image'].size(), sample['landmarks'].size())
  11. if i == 3:
  12. break

输出:

  1. 0 torch.Size([3, 224, 224]) torch.Size([68, 2])
  2. 1 torch.Size([3, 224, 224]) torch.Size([68, 2])
  3. 2 torch.Size([3, 224, 224]) torch.Size([68, 2])
  4. 3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,通过使用简单的for循环迭代数据,我们失去了很多功能。 特别是,我们错过了:

  • 批量处理数据
  • 打乱数据
  • 使用multiprocessing工作程序并行加载数据。

torch.utils.data.DataLoader是提供所有这些功能的迭代器。 下面使用的参数应该清楚。 感兴趣的一个参数是collate_fn。 您可以使用collate_fn指定需要如何精确地分批样品。 但是,默认精度在大多数情况下都可以正常工作。

  1. dataloader = DataLoader(transformed_dataset, batch_size=4,
  2. shuffle=True, num_workers=4)
  3. # Helper function to show a batch
  4. def show_landmarks_batch(sample_batched):
  5. """Show image with landmarks for a batch of samples."""
  6. images_batch, landmarks_batch = \
  7. sample_batched['image'], sample_batched['landmarks']
  8. batch_size = len(images_batch)
  9. im_size = images_batch.size(2)
  10. grid_border_size = 2
  11. grid = utils.make_grid(images_batch)
  12. plt.imshow(grid.numpy().transpose((1, 2, 0)))
  13. for i in range(batch_size):
  14. plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
  15. landmarks_batch[i, :, 1].numpy() + grid_border_size,
  16. s=10, marker='.', c='r')
  17. plt.title('Batch from dataloader')
  18. for i_batch, sample_batched in enumerate(dataloader):
  19. print(i_batch, sample_batched['image'].size(),
  20. sample_batched['landmarks'].size())
  21. # observe 4th batch and stop.
  22. if i_batch == 3:
  23. plt.figure()
  24. show_landmarks_batch(sample_batched)
  25. plt.axis('off')
  26. plt.ioff()
  27. plt.show()
  28. break

../_images/sphx_glr_data_loading_tutorial_004.png

输出:

  1. 0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
  2. 1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
  3. 2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
  4. 3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记:torchvision

在本教程中,我们已经看到了如何编写和使用数据集,转换和数据加载器。 torchvision包提供了一些常见的数据集和转换。 您甚至不必编写自定义类。 Torchvision 中可用的更通用的数据集之一是ImageFolder。 假定图像的组织方式如下:

  1. root/ants/xxx.png
  2. root/ants/xxy.jpeg
  3. root/ants/xxz.png
  4. .
  5. .
  6. .
  7. root/bees/123.jpg
  8. root/bees/nsdf3.png
  9. root/bees/asd932_.png

其中“蚂蚁”,“蜜蜂”等是类别标签。 同样也可以使用对PIL.ImageScalePIL.Image进行操作的通用转换。 您可以使用以下代码编写数据加载器,如下所示:

  1. import torch
  2. from torchvision import transforms, datasets
  3. data_transform = transforms.Compose([
  4. transforms.RandomSizedCrop(224),
  5. transforms.RandomHorizontalFlip(),
  6. transforms.ToTensor(),
  7. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  8. std=[0.229, 0.224, 0.225])
  9. ])
  10. hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
  11. transform=data_transform)
  12. dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
  13. batch_size=4, shuffle=True,
  14. num_workers=4)

有关训练代码的示例,请参见计算机视觉转换学习教程

脚本的总运行时间:(0 分钟 58.611 秒)

Download Python source code: data_loading_tutorial.py Download Jupyter notebook: data_loading_tutorial.ipynb

由狮身人面像画廊生成的画廊