from torch.utils.data import Dataset, DataLoader
class Example:
    """Toy record exposing a string field ``gg`` and a list field ``hh``."""

    def __init__(self):
        # Both attributes are set fresh per instance; ``hh`` is a new list
        # each time, so instances never share it.
        self.gg = 'hhh'
        self.hh = [2] * 3
class test_dataset(Dataset):
    """Map-style dataset holding freshly constructed ``Example`` objects.

    Args:
        size: Number of examples to create. Defaults to 16, the count that
            was previously hard-coded.
    """

    def __init__(self, size=16):
        # One new Example per slot (not ``[Example()] * size``) so items do
        # not alias a single instance and its mutable ``hh`` list.
        self.data = [Example() for _ in range(size)]

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the example at index ``i`` (plain list indexing)."""
        return self.data[i]
def collate_fn(batch):
    """Split a batch of examples into parallel lists of their fields.

    Args:
        batch: Iterable of objects exposing ``gg`` and ``hh`` attributes
            (e.g. the list of samples a DataLoader gathers for one batch).

    Returns:
        A 2-tuple ``(gg_values, hh_values)`` of lists, in batch order.
    """
    # Walk the batch exactly once, reading gg then hh for each item.
    pairs = [(item.gg, item.hh) for item in batch]
    return [gg for gg, _ in pairs], [hh for _, hh in pairs]
data = test_dataset()
loader = DataLoader(data, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Inspect only the first batch. ``break`` replaces the former ``exit()``:
# ``exit`` is a helper injected by the ``site`` module for interactive use
# (missing under ``python -S``), and raising SystemExit would also kill any
# code that follows or any module importing this one.
for x, y in loader:
    print(x)
    print(y)
    break
# Output:
# ['hhh']
# [[2, 2, 2]]