This article is based on PyTorch 1.7.0. Reference: https://github.com/pytorch/pytorch/tree/v1.7.0

PyTorch provides two types of Dataset: map-style datasets and iterable-style datasets.

Map-style datasets

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.
For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.

Dataset

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataset.py#L15]
class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

To define a new dataset, we subclass Dataset and override __getitem__; overriding __len__ is optional (though the default samplers used by DataLoader expect it). The official tutorial example below shows a typical implementation, and a usage sketch follows it.

[https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files]
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        sample = {"image": image, "label": label}
        return sample
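
To make the map-style access pattern concrete, here is a hypothetical usage sketch; the file names labels.csv and images/ are made-up placeholders, not part of the tutorial.

# Hypothetical usage sketch: "labels.csv" and "images/" are placeholder paths.
from torch.utils.data import DataLoader

dataset = CustomImageDataset(annotations_file="labels.csv", img_dir="images/")

sample = dataset[0]            # map-style access calls __getitem__(0)
print(sample["image"].shape, sample["label"])

# Batching with the default collate_fn assumes every image has the same shape;
# otherwise pass a resizing transform when constructing the dataset.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape, batch["label"])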

ConcatDataset

ConcatDataset concatenates multiple map-style datasets into one; a usage sketch follows the source below.

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataset.py#L176]
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
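
A minimal usage sketch with two TensorDataset instances (chosen only for illustration; any map-style datasets work):

import torch
from torch.utils.data import TensorDataset, ConcatDataset

ds_a = TensorDataset(torch.arange(0, 5))   # 5 samples: 0..4
ds_b = TensorDataset(torch.arange(5, 8))   # 3 samples: 5..7

combined = ConcatDataset([ds_a, ds_b])     # equivalent to ds_a + ds_b via Dataset.__add__
print(len(combined))                       # 8, i.e. cumulative_sizes[-1]
print(combined[6])                         # (tensor(6),): bisect_right maps index 6 to ds_b[1]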

Iterable-style datasets

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() method. This style suits cases where random reads are expensive or impractical, and where the batch size depends on the data as it is fetched.
For example, calling iter(dataset) could return a stream of data read from a database, from a remote server, or generated in real time.

IterableDataset

All datasets representing an iterable stream of samples should subclass IterableDataset; this type is particularly suited to streaming data.
Every subclass should override __iter__, which returns an iterator over the samples.
When such a dataset is used with a DataLoader, each element is yielded by the DataLoader's iterator. With num_workers > 0, each worker process receives its own copy of the dataset object, i.e., there are num_workers dataset copies, so we must take care to avoid yielding duplicate data.

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataset.py#L42]
class IterableDataset(Dataset[T_co]):

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

There are two ways to avoid duplicate data: split the workload inside __iter__(), or configure each worker's copy of the dataset via the DataLoader's worker_init_fn.

1 __iter__()

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

2 worker_init_fn

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

ChainDataset

ChainDataset chains multiple iterable-style datasets together; a usage sketch follows the source below.

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataset.py#L227]
class ChainDataset(IterableDataset):
    r"""Dataset for chainning multiple :class:`IterableDataset` s.
    This class is useful to assemble different existing dataset streams. The
    chainning operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.
    Arguments:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            # Cannot verify that all self.datasets are Sized
            total += len(d)  # type: ignore
        return total
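
A minimal usage sketch, reusing the MyIterableDataset class defined in the examples above:

ds1 = MyIterableDataset(start=3, end=7)      # yields 3, 4, 5, 6
ds2 = MyIterableDataset(start=7, end=10)     # yields 7, 8, 9
chained = torch.utils.data.ChainDataset([ds1, ds2])
print(list(chained))                         # [3, 4, 5, 6, 7, 8, 9]; the streams are chained lazily
# Note: len(chained) would only work if every chained dataset also defines __len__.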

torchvision datasets

torchvision implements many commonly used datasets; see https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision-datasets
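
A minimal sketch loading one of the built-in datasets (MNIST here, chosen only as an example; it downloads to a local ./data directory):

from torchvision import datasets, transforms

train_ds = datasets.MNIST(root="./data", train=True, download=True,
                          transform=transforms.ToTensor())
print(len(train_ds))            # 60000 training samples
image, label = train_ds[0]      # torchvision datasets are map-style: indexable via __getitem__
print(image.shape, label)       # torch.Size([1, 28, 28]) and an integer class label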