参考来源:
Pytorch 划分数据集的方法
Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler

1. torch.utils.data

torch 的这个文件包含了一些关于数据集处理的类:

  • **class torch.utils.data.Dataset**:一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len()(提供数据集的大小)、getitem() (支持整数索引)。
  • **class torch.utils.data.TensorDataset**:封装成 tensor 的数据集,每一个样本都通过索引张量来获得。
  • **class torch.utils.data.ConcatDataset**:连接不同的数据集以构成更大的新数据集。
  • **class torch.utils.data.Subset(dataset, indices)**:获取指定一个索引序列对应的子数据集。
  • **class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)**:数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
  • **torch.utils.data.random_split(dataset, lengths)**:按照给定的长度将数据集划分成没有重叠的新数据集组合。
  • **class torch.utils.data.Sampler(data_source)**:所有采样的器的基类。每个采样器子类都需要提供 iter() 方法以方便迭代器进行索引 和一个 len()方法 以方便返回迭代器的长度。
  • **class torch.utils.data.SequentialSampler(data_source)**:顺序采样样本,始终按照同一个顺序。
  • **class torch.utils.data.RandomSampler(data_source)**:无放回地随机采样样本元素。
  • **class torch.utils.data.SubsetRandomSampler(indices)**:无放回地按照给定的索引列表采样样本元素。
  • **class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)**:按照给定的概率来采样样本。
  • **class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)**:在一个 batch 中封装一个其他的采样器。
  • **class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)**:采样器可以约束数据加载进数据集的子集。