Loading an official dataset (CIFAR10 as the example)
https://pytorch.org/docs/stable/torchvision/datasets.html
import torch
import torchvision
from torchvision import transforms

print('==> Preparing data..')

# Standard CIFAR10 augmentation for training, plain normalization for testing.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root="./data/", train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)   # or args.train_batch_size when driven by argparse

testset = torchvision.datasets.CIFAR10(root="./data/", train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)    # or args.eval_batch_size when driven by argparse
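A quick sanity check on the loaders built above (a minimal sketch; the batch size of 128 is the illustrative value chosen in the snippet):

# Pull one batch and inspect it: CIFAR10 images come out as 3x32x32 tensors.
images, labels = next(iter(trainloader))
print(images.shape)      # torch.Size([128, 3, 32, 32])
print(labels.shape)      # torch.Size([128])
print(trainset.classes)  # the 10 CIFAR10 class names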
Loading a dataset from a folder
The example below reads the training set; reading the validation set works the same way.
The images are organized in the following directory structure:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the normalization; these values must match the Normalize transform used below.
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# ImageFolder infers each image's class from its parent folder name.
dataset = torchvision.datasets.ImageFolder("./data/", transform=data_transforms)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

inputs, classes = next(iter(data_loader))

# Make a grid from the batch and display it with the class names as the title.
out = torchvision.utils.make_grid(inputs)
class_name = dataset.classes
imshow(out, title=[class_name[x] for x in classes])
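If the images are not already split into separate training and validation folders, one option is to split the ImageFolder dataset in memory. This is only a sketch using torch.utils.data.random_split; the 80/20 ratio and the loader settings are illustrative, not from the text above:

from torch.utils.data import random_split

n_total = len(dataset)
n_train = int(0.8 * n_total)   # illustrative 80/20 split
n_val = n_total - n_train
train_set, val_set = random_split(dataset, [n_train, n_val])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)

Note that both subsets share the transform attached to dataset, so any augmentation in that transform also applies to the validation split.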
Custom dataset loading
Read each image's relative path and integer label from a txt file, with a space as the separator.
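The list file is assumed to contain one "path label" pair per line; the file names below are made up purely to illustrate the format:

dog/xxx.png 0
dog/xxy.png 0
cat/123.png 1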
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

# Create a dataset class by subclassing Dataset
class My_DataSet(Dataset):
    """root is the image directory; it must be consistent with the relative paths listed in txt_file."""
    def __init__(self, root, txt_file, transform=None):
        super(My_DataSet, self).__init__()
        self.root = root
        self.txt_file = txt_file
        with open(self.txt_file, "r") as f:
            c = f.read()
        self.img_label_list = c.splitlines()  # one "path label" entry per line
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.img_label_list[index].split(" ")[0]
        label = self.img_label_list[index].split(" ")[-1]
        img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)  # the label read from the file is a string, so convert it to int

    def __len__(self):
        return len(self.img_label_list)

if __name__ == "__main__":
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    data_set = My_DataSet("./data/", "./data/GHIM-20/list_train.txt", transform=data_transforms["train"])
    data_loader = DataLoader(data_set, batch_size=3, shuffle=True)

    # Grab one batch of 64 samples and print its labels.
    for data, label in DataLoader(data_set, batch_size=64, shuffle=True):
        break
    # data, label = next(iter(data_loader))
    print(label)
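The 'val' transform defined above is left unused in the snippet; wiring it up is symmetric. The list_val.txt file name below is hypothetical and only shows the pattern:

# Hypothetical validation list file; only the transform and the shuffle flag differ from the training loader.
val_set = My_DataSet("./data/", "./data/GHIM-20/list_val.txt", transform=data_transforms["val"])
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)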
How can a PyTorch DataLoader handle samples of different sizes?
https://www.zhihu.com/question/395888465/answer/1234563938
import torch
from torch.utils.data import DataLoader, Dataset

class OurDataset(Dataset):
    def __init__(self, *tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

def collate_wrapper(batch):
    # Return the samples as-is instead of stacking them, so tensors of
    # different shapes can be delivered in the same batch.
    a, b = batch
    return a, b

a = torch.randn(3, 2, 3)
b = torch.randn(3, 3, 4)
dataset = OurDataset(a, b)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper)

for sample in loader:
    print([x.size() for x in sample])
# Out: [torch.Size([3, 2, 3]), torch.Size([3, 3, 4])]
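When the differently sized samples eventually do need to be stacked (for example, variable-length sequences), another common pattern is to pad inside collate_fn. Below is a minimal sketch using torch.nn.utils.rnn.pad_sequence; the dataset and names are illustrative and not taken from the linked answer:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class VarLenDataset(Dataset):
    """Each sample is a 1-D tensor whose length differs from sample to sample."""
    def __init__(self, seqs):
        self.seqs = seqs

    def __getitem__(self, index):
        return self.seqs[index]

    def __len__(self):
        return len(self.seqs)

def pad_collate(batch):
    lengths = torch.tensor([len(s) for s in batch])
    padded = pad_sequence(batch, batch_first=True)  # zero-pad to the longest sequence in the batch
    return padded, lengths

seqs = [torch.randn(n) for n in (5, 3, 7, 2)]
loader = DataLoader(VarLenDataset(seqs), batch_size=2, collate_fn=pad_collate)
for padded, lengths in loader:
    print(padded.shape, lengths)
# First batch (shuffle is off): torch.Size([2, 5]) tensor([5, 3])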
