PyTorch
PyTorch作为一款流行深度学习框架其热度大有超越TensorFlow的感觉。根据此前的统计,目前TensorFlow虽然仍然占据着工业界,但PyTorch在视觉和NLP领域的顶级会议上已呈一统之势。
这里聚焦于PyTorch的自定义数据读取pipeline模板和相关trciks以及如何优化数据读取的pipeline等。从PyTorch的数据对象类Dataset开始。Dataset在PyTorch中的模块位于utils.data下。

  1. from torch.utils.data import Dataset

将围绕Dataset对象分别从原始模板、torchvisiontransforms模块、使用pandas来辅助读取、torch内置数据划分功能和DataLoader来展开阐述。

Dataset原始模板

PyTorch官方提供了自定义数据读取的标准化代码代码模块,作为一个读取框架,这里称之为原始模板。其代码结构如下:

  1. from torch.utils.data import Dataset
  2. class CustomDataset(Dataset):
  3. def __init__(self, ...):
  4. # stuff
  5. def __getitem__(self, index):
  6. # stuff
  7. return (img, label)
  8. def __len__(self):
  9. # return examples size
  10. return count

根据这个标准化的代码模板,只需要根据自己的数据读取任务,分别往__init__()__getitem__()__len__()三个方法里添加读取逻辑即可。作为PyTorch范式下的数据读取以及为了后续的data loader,三个方法缺一不可。其中:

  • __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。
  • __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。
  • __len__()函数则用于返回样本数量。

现在往这个框架里填几行代码来形成一个简单的数字案例。创建一个从1到100的数字例子:

  1. from torch.utils.data import Dataset
  2. class CustomDataset(Dataset):
  3. def __init__(self):
  4. self.samples = list(range(1, 101))
  5. def __len__(self):
  6. return len(self.samples)
  7. def __getitem__(self, idx):
  8. return self.samples[idx]
  9. if __name__ == '__main__':
  10. dataset = CustomDataset()
  11. print(len(dataset))
  12. print(dataset[50])
  13. print(dataset[1:100])

PyTorch数据Pipeline标准化代码模板 - 图1

添加torchvision.transforms

然后来看如何从内存中读取数据以及如何在读取过程中嵌入torchvision中的transforms功能。torchvision是一个独立于torch的关于数据、模型和一些图像增强操作的辅助库。主要包括datasets默认数据集模块、models经典模型模块、transforms图像增强模块以及utils模块等。在使用torch读取数据的时候,一般会搭配上transforms模块对数据进行一些处理和增强工作。
添加了tranforms之后的读取模块可以改写为:

  1. from torch.utils.data import Dataset
  2. from torchvision import transforms as T
  3. class CustomDataset(Dataset):
  4. def __init__(self, ...):
  5. # stuff
  6. ...
  7. # compose the transforms methods
  8. self.transform = T.Compose([T.CenterCrop(100),
  9. T.ToTensor()])
  10. def __getitem__(self, index):
  11. # stuff
  12. ...
  13. data = # Some data read from a file or image
  14. # execute the transform
  15. data = self.transform(data)
  16. return (img, label)
  17. def __len__(self):
  18. # return examples size
  19. return count
  20. if __name__ == '__main__':
  21. # Call the dataset
  22. custom_dataset = CustomDataset(...)

可以看到,使用了Compose方法来把各种数据处理方法聚合到一起进行定义数据转换方法。通常作为初始化方法放在__init__()函数下。以猫狗图像数据为例进行说明。
PyTorch数据Pipeline标准化代码模板 - 图2

定义数据读取方法如下:

  1. class DogCat(Dataset):
  2. def __init__(self, root, transforms=None, train=True, val=False):
  3. """
  4. get images and execute transforms.
  5. """
  6. self.val = val
  7. imgs = [os.path.join(root, img) for img in os.listdir(root)]
  8. # train: Cats_Dogs/trainset/cat.1.jpg
  9. # val: Cats_Dogs/valset/cat.10004.jpg
  10. imgs = sorted(imgs, key=lambda x: x.split('.')[-2])
  11. self.imgs = imgs
  12. if transforms is None:
  13. # normalize
  14. normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
  15. std = [0.229, 0.224, 0.225])
  16. # trainset and valset have different data transform
  17. # trainset need data augmentation but valset don't.
  18. # valset
  19. if self.val:
  20. self.transforms = T.Compose([
  21. T.Resize(224),
  22. T.CenterCrop(224),
  23. T.ToTensor(),
  24. normalize
  25. ])
  26. # trainset
  27. else:
  28. self.transforms = T.Compose([
  29. T.Resize(256),
  30. T.RandomResizedCrop(224),
  31. T.RandomHorizontalFlip(),
  32. T.ToTensor(),
  33. normalize
  34. ])
  35. def __getitem__(self, index):
  36. """
  37. return data and label
  38. """
  39. img_path = self.imgs[index]
  40. label = 1 if 'dog' in img_path.split('/')[-1] else 0
  41. data = Image.open(img_path)
  42. data = self.transforms(data)
  43. return data, label
  44. def __len__(self):
  45. """
  46. return images size.
  47. """
  48. return len(self.imgs)
  49. if __name__ == "__main__":
  50. train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)
  51. print(len(train_dataset))
  52. print(train_dataset[0])

因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:
PyTorch数据Pipeline标准化代码模板 - 图3

与pandas一起使用

很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:
PyTorch数据Pipeline标准化代码模板 - 图4
此时在数据读取的pipeline中需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:

  1. class CustomDatasetFromCSV(Dataset):
  2. def __init__(self, csv_path):
  3. """
  4. Args:
  5. csv_path (string): path to csv file
  6. transform: pytorch transforms for transforms and tensor conversion
  7. """
  8. # Transforms
  9. self.to_tensor = transforms.ToTensor()
  10. # Read the csv file
  11. self.data_info = pd.read_csv(csv_path, header=None)
  12. # First column contains the image paths
  13. self.image_arr = np.asarray(self.data_info.iloc[:, 0])
  14. # Second column is the labels
  15. self.label_arr = np.asarray(self.data_info.iloc[:, 1])
  16. # Calculate len
  17. self.data_len = len(self.data_info.index)
  18. def __getitem__(self, index):
  19. # Get image name from the pandas df
  20. single_image_name = self.image_arr[index]
  21. # Open image
  22. img_as_img = Image.open(single_image_name)
  23. # Transform image to tensor
  24. img_as_tensor = self.to_tensor(img_as_img)
  25. # Get label of the image based on the cropped pandas column
  26. single_image_label = self.label_arr[index]
  27. return (img_as_tensor, single_image_label)
  28. def __len__(self):
  29. return self.data_len
  30. if __name__ == "__main__":
  31. # Call dataset
  32. dataset = CustomDatasetFromCSV('./labels.csv')

以mnist_label.csv文件为示例:

  1. from torch.utils.data import Dataset
  2. from torch.utils.data import DataLoader
  3. from torchvision import transforms as T
  4. from PIL import Image
  5. import os
  6. import numpy as np
  7. import pandas as pd
  8. class CustomDatasetFromCSV(Dataset):
  9. def __init__(self, csv_path):
  10. """
  11. Args:
  12. csv_path (string): path to csv file
  13. transform: pytorch transforms for transforms and tensor conversion
  14. """
  15. # Transforms
  16. self.to_tensor = T.ToTensor()
  17. # Read the csv file
  18. self.data_info = pd.read_csv(csv_path, header=None)
  19. # First column contains the image paths
  20. self.image_arr = np.asarray(self.data_info.iloc[:, 0])
  21. # Second column is the labels
  22. self.label_arr = np.asarray(self.data_info.iloc[:, 1])
  23. # Third column is for an operation indicator
  24. self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
  25. # Calculate len
  26. self.data_len = len(self.data_info.index)
  27. def __getitem__(self, index):
  28. # Get image name from the pandas df
  29. single_image_name = self.image_arr[index]
  30. # Open image
  31. img_as_img = Image.open(single_image_name)
  32. # Check if there is an operation
  33. some_operation = self.operation_arr[index]
  34. # If there is an operation
  35. if some_operation:
  36. # Do some operation on image
  37. # ...
  38. # ...
  39. pass
  40. # Transform image to tensor
  41. img_as_tensor = self.to_tensor(img_as_img)
  42. # Get label of the image based on the cropped pandas column
  43. single_image_label = self.label_arr[index]
  44. return (img_as_tensor, single_image_label)
  45. def __len__(self):
  46. return self.data_len
  47. if __name__ == "__main__":
  48. transform = T.Compose([T.ToTensor()])
  49. dataset = CustomDatasetFromCSV('./mnist_labels.csv')
  50. print(len(dataset))
  51. print(dataset[5])

运行示例如下:
PyTorch数据Pipeline标准化代码模板 - 图5

训练集验证集划分

一般来说,为了模型训练的稳定,需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。
以kaggle的花朵数据为例:

  1. from torch.utils.data import DataLoader
  2. from torchvision.datasets import ImageFolder
  3. from torchvision import transforms as T
  4. from torch.utils.data import random_split
  5. transform = T.Compose([
  6. T.Resize((224, 224)),
  7. T.RandomHorizontalFlip(),
  8. T.ToTensor()
  9. ])
  10. dataset = ImageFolder('./flowers_photos', transform=transform)
  11. print(dataset.class_to_idx)
  12. trainset, valset = random_split(dataset,
  13. [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])
  14. trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
  15. for i, (img, label) in enumerate(trainloader):
  16. img, label = img.numpy(), label.numpy()
  17. print(img, label)
  18. valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
  19. for i, (img, label) in enumerate(trainloader):
  20. img, label = img.numpy(), label.numpy()
  21. print(img.shape, label)

这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下:
PyTorch数据Pipeline标准化代码模板 - 图6

使用DataLoader

dataset方法写好之后,还需要使用DataLoader将其逐个喂给模型。上一节的数据划分已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:

  1. from torch.utils.data import DataLoader
  2. from torchvision import transforms as T
  3. if __name__ == "__main__":
  4. # Define transforms
  5. transformations = T.Compose([T.ToTensor()])
  6. # Define custom dataset
  7. dataset = CustomDatasetFromCSV('./labels.csv')
  8. # Define data loader
  9. data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
  10. for images, labels in data_loader:
  11. # Feed the data to the model

以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。