主要包括三个类:Dataset、sampler.Sampler、DataLoader
pytorch读取训练集需要使用到2个torch.utils.data类:Dataset、DataLoader
Dataset
init:初始化,转化为张量
getitem(self, index)函数:根据索引序号获取item
len(self)函数:获取数据集的长度
class COVID19Dataset(Dataset):'''x: Features.y: Targets, if none, do prediction.'''def __init__(self, x, y=None):if y is None:self.y = yelse:self.y = torch.FloatTensor(y)self.x = torch.FloatTensor(x)def __getitem__(self, idx):if self.y is None:return self.x[idx]else:return self.x[idx], self.y[idx]def __len__(self):return len(self.x)
sampler.Sampler
创建一个采样器
__iter(self):获取一个迭代器,对数据集中元素的索引进行迭代
**len**(self)__:返回迭代器中包含元素的长度
class RandomSampler(Sampler):def __init__(self, data_source):self.data_source = data_sourcedef __iter__(self):return iter(torch.randperm(len(self.data_source)).long())def __len__(self):return len(self.data_source)
