官方数据集加载(以 CIFAR10 为例)

官方文档:https://pytorch.org/vision/stable/datasets.html

  1. import torch
  2. import torchvision
  3. from torchvision import transforms
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. print('==> Preparing data..')
  7. transform_train = transforms.Compose([
  8. transforms.RandomCrop(32, padding=4),
  9. transforms.RandomHorizontalFlip(),
  10. transforms.ToTensor(),
  11. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
  12. transform_test = transforms.Compose([
  13. transforms.ToTensor(),
  14. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
  15. trainset = torchvision.datasets.CIFAR10(root ="./data/", train = True, download = True, transform = transform_train)
  16. trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=2)
  17. testset = torchvision.datasets.CIFAR10(root ="./data/", train = False, download = True, transform = transform_test)
  18. testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2)

从文件夹中读取数据集

下面演示的是读取训练集的数据,读取验证集数据同理

其中,图片的目录结构如下所示(每个子文件夹的名称即为该类别的名称,ImageFolder 会据此自动生成类别索引):

  1. root/dog/xxx.png
  2. root/dog/xxy.png
  3. root/dog/xxz.png
  4. root/cat/123.png
  5. root/cat/nsdf3.png
  6. root/cat/asd932_.png
  1. import torch
  2. import torchvision
  3. from torchvision import transforms
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. def imshow(inp, title=None):
  7. """Imshow for Tensor."""
  8. inp = inp.numpy().transpose((1, 2, 0))
  9. mean = np.array([0.485, 0.456, 0.406])
  10. std = np.array([0.229, 0.224, 0.225])
  11. inp = std * inp + mean
  12. inp = np.clip(inp, 0, 1)
  13. plt.imshow(inp)
  14. if title is not None:
  15. plt.title(title)
  16. plt.pause(0.001) # pause a bit so that plots are updated
  17. data_transforms = transforms.Compose([
  18. transforms.ToTensor(),
  19. transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
  20. ])
  21. dataset = torchvision.datasets.ImageFolder("./data/", transform=data_transforms)
  22. data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
  23. inputs, classes = next(iter(data_loader))
  24. # Make a grid from batch
  25. out = torchvision.utils.make_grid(inputs)
  26. class_name = dataset.classes
  27. imshow(out, title=[class_name[x] for x in classes])

自定义数据读取

从 txt 文件读取每一张图片的路径和标签:文件的每一行对应一张图片,格式为"图片路径 标签",两者之间使用空格作为分隔符(原文此处为示意图 image.png)。

  1. from torch.utils.data import Dataset, DataLoader
  2. from PIL import Image
  3. from torchvision import datasets, models, transforms
  4. import os
  5. # 创建一个数据集类:继承 Dataset
  6. class My_DataSet(Dataset):
  7. """
  8. root 是图片的路径,要结合 txt_file 文件的内容来设置
  9. """
  10. def __init__(self, root, txt_file, transform=None):
  11. super(My_DataSet, self).__init__()
  12. self.root = root
  13. self.txt_file = txt_file
  14. with open(self.txt_file, "r") as f:
  15. c = f.read()
  16. self.img_label_list = c.splitlines()
  17. self.transform = transform
  18. def __getitem__(self, index):
  19. img_path = self.img_label_list[index].split(" ")[0]
  20. # print(img_path)
  21. label = self.img_label_list[index].split(" ")[-1]
  22. img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
  23. if self.transform is not None:
  24. img = self.transform(img)
  25. return img, int(label) # 得到的是字符串,故要进行类型转换
  26. def __len__(self):
  27. return len(self.img_label_list)
  28. if __name__ == "__main__":
  29. data_transforms = {
  30. 'train': transforms.Compose([
  31. transforms.RandomCrop((224, 224)),
  32. transforms.RandomHorizontalFlip(p=0.5),
  33. transforms.ToTensor(),
  34. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  35. ]),
  36. 'val': transforms.Compose([
  37. transforms.Resize((256, 256)),
  38. transforms.CenterCrop((224, 224)),
  39. transforms.ToTensor(),
  40. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  41. ]),
  42. }
  43. data_set = My_DataSet("./data/", "./data/GHIM-20/list_train.txt", transform=data_transforms["train"])
  44. data_loader = DataLoader(data_set, batch_size=3, shuffle=True)
  45. for data, label in DataLoader(data_set, batch_size=64, shuffle=True):
  46. break
  47. # # data, label =next(iter(data_loader))
  48. print(label)

PyTorch 如何加载不同尺寸的数据集?

https://www.zhihu.com/question/395888465/answer/1234563938

  1. from torch.utils.data import DataLoader, Dataset
  2. class OurDataset(Dataset):
  3. def __init__(self, *tensors):
  4. self.tensors = tensors
  5. def __getitem__(self, index):
  6. return self.tensors[index]
  7. def __len__(self):
  8. return len(self.tensors)
  9. def collate_wrapper(batch):
  10. a, b = batch
  11. return a, b
  12. a = torch.randn(3, 2, 3)
  13. b = torch.randn(3, 3, 4)
  14. dataset = OurDataset(a, b)
  15. loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper)
  16. for sample in loader:
  17. print([x.size() for x in sample])
  18. # Out: [torch.Size([1, 3, 2, 3]), torch.Size([1, 3, 3, 4])]