本文基于PyTorch1.7.0,https://github.com/pytorch/pytorch/tree/v1.7.0
Sampler决定数据加载的顺序,可能是顺序,也可能是乱序(如果设置了shuffle为True)。
Sampler只适用于map-style类型的数据集。
Sampler
[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L9]
class Sampler(Generic[T_co]):
    r"""Common base class of every sampler.

    A subclass must implement :meth:`__iter__`, which yields the indices of
    dataset elements in the order they should be visited. It should also
    implement :meth:`__len__`, returning how many indices one iteration
    produces.

    .. note:: :class:`~torch.utils.data.DataLoader` does not strictly need
              :meth:`__len__`, but any computation that involves the length
              of a :class:`~torch.utils.data.DataLoader` expects it.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        """Accept the data source; the base class does not store it."""

    def __iter__(self) -> Iterator[T_co]:
        """Yield dataset indices; concrete samplers must override this."""
        raise NotImplementedError
SequentialSampler
[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L55]
class SequentialSampler(Sampler[int]):
    r"""Yields indices ``0 .. len(data_source) - 1`` in ascending order.

    The order is the same on every iteration.

    Arguments:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source):
        """Remember the data source whose length defines the index range."""
        self.data_source = data_source

    def __iter__(self):
        """Return an iterator over every valid index, smallest first."""
        total = len(self.data_source)
        return iter(range(total))

    def __len__(self) -> int:
        """Return the number of indices produced by one pass."""
        return len(self.data_source)
RandomSampler
- 如果replacement为False,则把所有样本的下标随机打乱,然后依次读取
- 如果replacement为True,则从所有样本中有放回地一共读取num_samples个样本:每次取32个样本,一共取 num_samples // 32 次,最后再取剩余的 num_samples % 32 个样本;由于是有放回采样,同一个样本可能被多次取到
# [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L73]
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly.

    Without replacement, the indices of the whole dataset are shuffled and
    yielded once each. With replacement, :attr:`num_samples` indices are drawn
    independently, so the same index may appear more than once.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if
            ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
            This argument is supposed to be specified only when `replacement`
            is ``True``.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if ``num_samples`` is given while ``replacement`` is
            ``False``, or if the effective number of samples is not a
            positive integer.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        # Validates the effective value, i.e. len(data_source) when
        # num_samples was not given.
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        """Number of indices produced by one iteration."""
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        """Yield random indices into ``data_source``."""
        n = len(self.data_source)
        if self.generator is None:
            # No generator supplied: create a private one and seed it with a
            # value drawn from the global RNG.
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32 indices, then one final chunk holding the
            # remainder (which may be empty).
            for _ in range(self.num_samples // 32):
                # https://pytorch.org/docs/stable/generated/torch.randint.html?highlight=randint#torch.randint
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64,
                                         generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
                                     generator=generator).tolist()
        else:
            # Use the resolved local generator (not self.generator) so that the
            # freshly seeded fallback above is honored on this path as well.
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self):
        """Return the number of indices produced by one iteration."""
        return self.num_samples
SubsetRandomSampler
[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L131]
WeightedRandomSampler
[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L151]
BatchSampler
# [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/sampler.py#L194]
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``batch_size`` is not a positive integer (booleans are
            rejected) or ``drop_last`` is not a ``bool``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # `sampler` is deliberately not type-checked: collections.abc.Iterable
        # does not recognise objects that are iterable only through
        # `__getitem__`, so an `isinstance` test would be too strict.
        size_ok = (
            isinstance(batch_size, _int_classes)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_ok:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of at most ``batch_size`` indices from the base sampler."""
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        # A trailing partial batch is emitted only when it is not to be dropped.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        """Return the number of batches one iteration produces."""
        # Can only be called if self.sampler has __len__ implemented.
        # That cannot be enforced here, so type checking is switched off for
        # the length lookup below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        full_batches, leftover = divmod(len(self.sampler), self.batch_size)  # type: ignore
        if self.drop_last or leftover == 0:
            return full_batches
        return full_batches + 1