本文基于PyTorch1.7.0,https://github.com/pytorch/pytorch/tree/v1.7.0

Sampler决定数据加载的顺序,可能是顺序,也可能是乱序(如果设置了shuffle为True)。
Sampler只适用于map-style类型的数据集。

Sampler

  [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L9]
  class Sampler(Generic[T_co]):
      r"""Base class for all Samplers.

      Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
      way to iterate over indices of dataset elements, and a :meth:`__len__` method
      that returns the length of the returned iterators.

      .. note:: The :meth:`__len__` method isn't strictly required by
                :class:`~torch.utils.data.DataLoader`, but is expected in any
                calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
      """

      def __init__(self, data_source: Optional[Sized]) -> None:
          pass

      def __iter__(self) -> Iterator[T_co]:
          raise NotImplementedError

SequentialSampler

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L55]
class SequentialSampler(Sampler[int]):
    r"""Yields dataset indices one by one, always in ascending order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Lazily emit 0, 1, ..., len(data_source) - 1.
        yield from range(len(self.data_source))

    def __len__(self) -> int:
        return len(self.data_source)

RandomSampler

  • 如果replacement为False,则把所有样本打乱,然后从中读取
  • 如果replacement为True,则从所有样本中一共读取num_samples个样本,每次取32个样本,一共取 ⌊num_samples/32⌋ 次(最后一次再取 num_samples mod 32 个),不同次之间的样本可能会有重叠

    [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L73]
    class RandomSampler(Sampler[int]):
        r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
        If with replacement, then user can specify :attr:`num_samples` to draw.

        Arguments:
            data_source (Dataset): dataset to sample from
            replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
            num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
                is supposed to be specified only when `replacement` is ``True``.
            generator (Generator): Generator used in sampling.

        Raises:
            TypeError: if ``replacement`` is not a bool.
            ValueError: if ``num_samples`` is given with ``replacement=False``,
                or is not a positive integer.
        """
        data_source: Sized
        replacement: bool

        def __init__(self, data_source: Sized, replacement: bool = False,
                     num_samples: Optional[int] = None, generator=None) -> None:
            self.data_source = data_source
            self.replacement = replacement
            self._num_samples = num_samples
            self.generator = generator

            if not isinstance(self.replacement, bool):
                raise TypeError("replacement should be a boolean value, but got "
                                "replacement={}".format(self.replacement))

            # num_samples only makes sense when drawing with replacement; without
            # replacement a full permutation of the dataset is always produced.
            if self._num_samples is not None and not replacement:
                raise ValueError("With replacement=False, num_samples should not be specified, "
                                 "since a random permute will be performed.")

            if not isinstance(self.num_samples, int) or self.num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples={}".format(self.num_samples))

        @property
        def num_samples(self) -> int:
            # dataset size might change at runtime, so resolve the default lazily
            if self._num_samples is None:
                return len(self.data_source)
            return self._num_samples

        def __iter__(self):
            n = len(self.data_source)
            if self.generator is None:
                # No generator supplied: create a local one seeded from the
                # global RNG so each fresh iteration is independently random.
                generator = torch.Generator()
                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
            else:
                generator = self.generator
            if self.replacement:
                # Draw indices in chunks of 32 to amortize the cost of randint,
                # then a final partial chunk of num_samples % 32.
                for _ in range(self.num_samples // 32):
                    # https://pytorch.org/docs/stable/generated/torch.randint.html?highlight=randint#torch.randint
                    yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
                yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
            else:
                # BUG FIX vs the v1.7.0 snippet: use the local `generator`
                # (which may be the freshly seeded one built above), not
                # `self.generator` — otherwise, when self.generator is None,
                # randperm silently falls back to the global RNG and the
                # seeding work above is discarded.
                yield from torch.randperm(n, generator=generator).tolist()

        def __len__(self):
            return self.num_samples
    

    SubsetRandomSampler

    [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L131]
    

    WeightedRandomSampler

    [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L151]
    

    BatchSampler

    [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L194]
    class BatchSampler(Sampler[List[int]]):
        r"""Wraps another sampler to yield a mini-batch of indices.

        Args:
            sampler (Sampler or Iterable): Base sampler. Can be any iterable object
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``

        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """

        def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
            # Since collections.abc.Iterable does not check for `__getitem__`, which
            # is one way for an object to be an iterable, we don't do an `isinstance`
            # check on `sampler` here.
            # Reject bools explicitly: bool is a subclass of int in Python.
            if (not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool)
                    or batch_size <= 0):
                raise ValueError("batch_size should be a positive integer value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            current: List[int] = []
            for index in self.sampler:
                current.append(index)
                if len(current) == self.batch_size:
                    yield current
                    current = []
            # A trailing partial batch is emitted only when drop_last is False.
            if current and not self.drop_last:
                yield current

        def __len__(self):
            # Can only be called if self.sampler has __len__ implemented
            # We cannot enforce this condition, so we turn off typechecking for the
            # implementation below.
            # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
            full_batches, remainder = divmod(len(self.sampler), self.batch_size)  # type: ignore
            if self.drop_last or remainder == 0:
                return full_batches
            return full_batches + 1