主要包括三个类:Dataset、sampler.Sampler、DataLoader
pytorch读取训练集需要使用到2个torch.utils.data类:Dataset、DataLoader

Dataset

init:初始化,转化为张量
getitem(self, index)函数:根据索引序号获取item
len(self)函数:获取数据集的长度

  1. class COVID19Dataset(Dataset):
  2. '''
  3. x: Features.
  4. y: Targets, if none, do prediction.
  5. '''
  6. def __init__(self, x, y=None):
  7. if y is None:
  8. self.y = y
  9. else:
  10. self.y = torch.FloatTensor(y)
  11. self.x = torch.FloatTensor(x)
  12. def __getitem__(self, idx):
  13. if self.y is None:
  14. return self.x[idx]
  15. else:
  16. return self.x[idx], self.y[idx]
  17. def __len__(self):
  18. return len(self.x)

sampler.Sampler

创建一个采样器
__iter(self):获取一个迭代器,对数据集中元素的索引进行迭代
**
len**(self)__:返回迭代器中包含元素的长度

  1. class RandomSampler(Sampler):
  2. def __init__(self, data_source):
  3. self.data_source = data_source
  4. def __iter__(self):
  5. return iter(torch.randperm(len(self.data_source)).long())
  6. def __len__(self):
  7. return len(self.data_source)

DataLoader