1. from torch.utils.data import Dataset, DataLoader
    class Example:
        """Minimal sample object carrying one string field and one list field."""

        def __init__(self):
            # Fixed placeholder payloads; the list literal is evaluated on every
            # call, so each instance owns its own list object.
            self.gg = 'hhh'
            self.hh = [2, 2, 2]
    class test_dataset(Dataset):
        """Map-style dataset that stores 16 ``Example`` objects in a list."""

        def __init__(self):
            super().__init__()
            # One freshly constructed Example per slot.
            self.data = [Example() for _ in range(16)]

        def __len__(self):
            """Return the number of stored samples."""
            return len(self.data)

        def __getitem__(self, i):
            """Return the stored sample at position ``i``."""
            return self.data[i]
    def collate_fn(batch):
        """Split a batch of samples into two parallel lists of their fields.

        Args:
            batch: Iterable of objects exposing ``gg`` and ``hh`` attributes.

        Returns:
            A ``(gg_values, hh_values)`` tuple of lists, each in batch order.
        """
        gg_values = []
        hh_values = []
        # Single pass, so any iterable (including a one-shot iterator) works.
        for sample in batch:
            gg_values.append(sample.gg)
            hh_values.append(sample.hh)
        return gg_values, hh_values
    # Build the dataset and a loader that yields one sample per batch, in order,
    # using the custom collate function instead of the default collation.
    data = test_dataset()
    loader = DataLoader(data, batch_size=1, shuffle=False, collate_fn=collate_fn)

    # Print only the first batch, then stop iterating.
    for x, y in loader:
        print(x)
        print(y)
        # `break` rather than `exit()`: `exit` is a helper injected by the
        # `site` module for interactive use and is not guaranteed to exist.
        break
    # Output:
    # ['hhh']
    # [[2, 2, 2]]