For a custom Dataloader, the solutions given online all involve subclassing Dataset and overriding its methods, for example like this:
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

preprocess = transforms.Compose([
    # transforms.Resize(256),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

def default_loader(path):
    # convert('RGB') guards against grayscale/RGBA images breaking the 3-channel normalize
    img_pil = Image.open(path).convert('RGB')
    img_pil = img_pil.resize((224, 224))
    img_tensor = preprocess(img_pil)
    return img_tensor

class trainset(Dataset):
    def __init__(self, loader=default_loader):
        # file_train / number_train: lists of image paths and labels,
        # assumed to be defined earlier in your script
        self.images = file_train
        self.target = number_train
        self.loader = loader

    def __getitem__(self, index):
        fn = self.images[index]
        img = self.loader(fn)
        target = self.target[index]
        return img, target  # make sure both are Tensors with consistent shapes

    def __len__(self):
        return len(self.images)
Usage then looks like this:
train_data = trainset()
trainloader = DataLoader(train_data, batch_size=4, shuffle=True)
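To see what the DataLoader actually produces, here is a minimal self-contained sketch that swaps in random tensors for real images (the `ToyDataset` class and its random data are stand-ins, not part of the original code): because every item has the same shape, the default collation stacks them into one batched tensor.

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Stand-in for trainset: 10 fake "images" (3x224x224 tensors) and integer labels.
# All items share one shape, so the default collate_fn can stack them.
class ToyDataset(Dataset):
    def __init__(self, n=10):
        self.images = [torch.randn(3, 224, 224) for _ in range(n)]
        self.targets = list(range(n))

    def __getitem__(self, index):
        return self.images[index], self.targets[index]

    def __len__(self):
        return len(self.images)

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
imgs, targets = next(iter(loader))
print(imgs.shape)     # torch.Size([4, 3, 224, 224])
print(targets.shape)  # torch.Size([4]) -- ints are collated into a tensor too
```

Note that even the plain Python integer labels come back as a batched tensor: the default collate function converts numbers and stacks them alongside the image tensors.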
The key point is the __getitem__ method: make sure everything it returns is a Tensor, and that the shapes are consistent across items. If it returns objects or strings, batched training will not work the way you want.
If your training data consists of strings, first find a way to convert them into equal-length vectors, then proceed as above.
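One way to do that conversion is sketched below. The character-level vocabulary, the `encode` helper, and the `max_len=8` padding length are all illustrative choices, not from the original post; a real project would more likely use a proper tokenizer. The point is only that every string becomes a LongTensor of identical length, so the default collation can batch them.

```python
import torch
from torch.utils.data import Dataset, DataLoader

texts = ["hello", "hi", "pytorch"]
labels = [0, 1, 0]

# Character-level vocabulary; index 0 is reserved for padding.
vocab = {ch: i + 1 for i, ch in enumerate(sorted(set("".join(texts))))}

def encode(text, max_len=8):
    # Map characters to ids, truncate to max_len, pad with zeros to equal length.
    ids = [vocab[ch] for ch in text][:max_len]
    ids += [0] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long)

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __getitem__(self, i):
        return encode(self.texts[i]), self.labels[i]

    def __len__(self):
        return len(self.texts)

loader = DataLoader(TextDataset(texts, labels), batch_size=2, shuffle=False)
batch_x, batch_y = next(iter(loader))
print(batch_x.shape)  # torch.Size([2, 8])
```

Because every encoded string is padded to the same length, the batch comes out as a single [batch_size, max_len] tensor, which is exactly the "equal-length vectors" requirement above.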
References
Tokenizer
PyTorch: building a custom Dataloader to read your data, step by step
https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/
Learning PyTorch: custom Datasets and DataLoader
https://www.cnblogs.com/kk17/p/10105862.html
https://github.com/miraclewkf/ImageClassification-PyTorch/blob/master/level2/train_customData.py