自定义Dataloader,网上给出的解决方案都是继承 Dataset,然后重写里面的方法,比如这样:

    1. from torch.utils.data import Dataset, DataLoader
    2. from torchvision import transforms, utils
    3. normalize = transforms.Normalize(
    4. mean=[0.485, 0.456, 0.406],
    5. std=[0.229, 0.224, 0.225]
    6. )
    7. preprocess = transforms.Compose([
    8. #transforms.Scale(256),
    9. #transforms.CenterCrop(224),
    10. transforms.ToTensor(),
    11. normalize
    12. ])
    13. def default_loader(path):
    14. img_pil = Image.open(path)
    15. img_pil = img_pil.resize((224,224))
    16. img_tensor = preprocess(img_pil)
    17. return img_tensor
    18. class trainset(Dataset):
    19. def __init__(self, loader=default_loader):
    20. #定义好 image 的路径
    21. self.images = file_train
    22. self.target = number_train
    23. self.loader = loader
    24. def __getitem__(self, index):
    25. fn = self.images[index]
    26. img = self.loader(fn)
    27. target = self.target[index]
    28. return (img,target) # 确定返回的都是 Tensor,而且 shape 要一致
    29. def __len__(self):
    30. return len(self.images)

    用的时候就是这样:

    1. train_data = trainset()
    2. trainloader = DataLoader(train_data, batch_size=4,shuffle=True)

    其中重点是getitem这个方法,要确定返回的都是 Tensor,而且 shape 要一致。如果返回的是对象或者字符串,都不会达到你想批训练的效果的。

    如果你的训练数据是字符串的话,那么就想办法先转换成等长向量,再处理。

    参考
    Tokenizer
    PyTorch手把手自定义Dataloader读取数据
    https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/
    pytorch学习:自定义Datasets以及DataLoader
    https://www.cnblogs.com/kk17/p/10105862.html
    https://github.com/miraclewkf/ImageClassification-PyTorch/blob/master/level2/train_customData.py