参考来源:
Pytorch 划分数据集的方法
Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler
1. torch.utils.data
torch 的这个文件包含了一些关于数据集处理的类:
**class torch.utils.data.Dataset**
:一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len()
(提供数据集的大小)、getitem()
(支持整数索引)。**class torch.utils.data.TensorDataset**
:封装成tensor
的数据集,每一个样本都通过索引张量来获得。**class torch.utils.data.ConcatDataset**
:连接不同的数据集以构成更大的新数据集。**class torch.utils.data.Subset(dataset, indices)**
:获取指定一个索引序列对应的子数据集。**class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)**
:数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。**torch.utils.data.random_split(dataset, lengths)**
:按照给定的长度将数据集划分成没有重叠的新数据集组合。**class torch.utils.data.Sampler(data_source)**
:所有采样的器的基类。每个采样器子类都需要提供iter()
方法以方便迭代器进行索引 和一个len()
方法 以方便返回迭代器的长度。**class torch.utils.data.SequentialSampler(data_source)**
:顺序采样样本,始终按照同一个顺序。**class torch.utils.data.RandomSampler(data_source)**
:无放回地随机采样样本元素。**class torch.utils.data.SubsetRandomSampler(indices)**
:无放回地按照给定的索引列表采样样本元素。**class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)**
:按照给定的概率来采样样本。**class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)**
:在一个batch
中封装一个其他的采样器。**class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)**
:采样器可以约束数据加载进数据集的子集。