自定义Dataloader,网上给出的解决方案都是继承 Dataset,然后重写里面的方法,比如这样:
from torch.utils.data import Dataset, DataLoaderfrom torchvision import transforms, utilsnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])preprocess = transforms.Compose([#transforms.Scale(256),#transforms.CenterCrop(224),transforms.ToTensor(),normalize])def default_loader(path):img_pil = Image.open(path)img_pil = img_pil.resize((224,224))img_tensor = preprocess(img_pil)return img_tensorclass trainset(Dataset):def __init__(self, loader=default_loader):#定义好 image 的路径self.images = file_trainself.target = number_trainself.loader = loaderdef __getitem__(self, index):fn = self.images[index]img = self.loader(fn)target = self.target[index]return (img,target) # 确定返回的都是 Tensor,而且 shape 要一致def __len__(self):return len(self.images)
用的时候就是这样:
train_data = trainset()trainloader = DataLoader(train_data, batch_size=4,shuffle=True)
其中重点是getitem这个方法,要确定返回的都是 Tensor,而且 shape 要一致。如果返回的是对象或者字符串,都不会达到你想批训练的效果的。
如果你的训练数据是字符串的话,那么就想办法先转换成等长向量,再处理。
参考
Tokenizer
PyTorch手把手自定义Dataloader读取数据
https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/
pytorch学习:自定义Datasets以及DataLoader
https://www.cnblogs.com/kk17/p/10105862.html
https://github.com/miraclewkf/ImageClassification-PyTorch/blob/master/level2/train_customData.py
