主要包括三个类:Dataset、sampler.Sampler、DataLoader
pytorch读取训练集需要使用到2个torch.utils.data类:Dataset、DataLoader
Dataset
init:初始化,转化为张量
getitem(self, index)函数:根据索引序号获取item
len(self)函数:获取数据集的长度
class COVID19Dataset(Dataset):
'''
x: Features.
y: Targets, if none, do prediction.
'''
def __init__(self, x, y=None):
if y is None:
self.y = y
else:
self.y = torch.FloatTensor(y)
self.x = torch.FloatTensor(x)
def __getitem__(self, idx):
if self.y is None:
return self.x[idx]
else:
return self.x[idx], self.y[idx]
def __len__(self):
return len(self.x)
sampler.Sampler
创建一个采样器
__iter(self):获取一个迭代器,对数据集中元素的索引进行迭代
**len**(self)__:返回迭代器中包含元素的长度
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)