Download the ants vs. bees classification dataset: https://download.pytorch.org/tutorial/hymenoptera_data.zip

Dataset

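Dataset provides a way to fetch data samples together with their labels. It answers two questions:

  1. How do we fetch each sample and its label? With __getitem__
  2. How many samples are there in total? With __len__

Every subclass of Dataset must override __getitem__. Overriding __len__ is technically optional, but in practice it should be overridden too, because the default samplers used by DataLoader rely on it.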

    from torch.utils.data import Dataset
    from PIL import Image
    import os


    class MyDataset(Dataset):
        def __init__(self, root_dir, label_dir):
            self.root_dir = root_dir
            self.label_dir = label_dir
            # Folder holding all images of one class; the folder name is the label
            self.path = os.path.join(root_dir, label_dir)
            # File names of every entry in that folder
            self.img_path = os.listdir(self.path)

        def __getitem__(self, idx):
            img_name = self.img_path[idx]
            img_item_path = os.path.join(self.path, img_name)
            img = Image.open(img_item_path)
            label = self.label_dir
            return img, label

        def __len__(self):
            return len(self.img_path)


    root_dir = "./data/hymenoptera_data/train"
    ants_label_dir = "ants"
    ants_dataset = MyDataset(root_dir, ants_label_dir)
    bees_label_dir = "bees"
    bees_dataset = MyDataset(root_dir, bees_label_dir)

    # Merge the datasets of the two classes into one training set
    train_dataset = ants_dataset + bees_dataset
    len(ants_dataset), len(bees_dataset), len(train_dataset)  # (124, 121, 245)

    # Index 123 is the last sample of the ants part
    img2, label2 = train_dataset[123]
    img2.show()
    label2  # 'ants'
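The + operator used above works because Dataset defines __add__, which wraps both operands in a ConcatDataset. The sketch below shows how indices map across the two parts without needing any image files. ToyDataset and its string samples are hypothetical stand-ins, not part of the original notes.

```python
from torch.utils.data import ConcatDataset, Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: every sample shares one label."""

    def __init__(self, samples, label):
        self.samples = list(samples)
        self.label = label

    def __getitem__(self, idx):
        return self.samples[idx], self.label

    def __len__(self):
        return len(self.samples)


ants = ToyDataset(["a0", "a1", "a2"], "ants")
bees = ToyDataset(["b0", "b1"], "bees")

# Dataset.__add__ returns a ConcatDataset over both operands
combined = ants + bees
print(isinstance(combined, ConcatDataset))  # True

print(len(combined))  # 5
print(combined[2])    # ('a2', 'ants'): last index of the first part
print(combined[3])    # ('b0', 'bees'): first index of the second part
```

Indices run through the first dataset and then continue into the second, which is why index 123 in the 124-image ants set still returns the label 'ants'.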

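DataLoader

DataLoader delivers data to the downstream network in different forms. It controls how samples are drawn from a Dataset: how many per batch (batch_size), whether the order is shuffled each epoch (shuffle), how many worker processes load data (num_workers), and whether a final incomplete batch is discarded (drop_last).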
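To make these controls concrete, here is a minimal sketch on toy tensors, so nothing has to be downloaded. The 10-sample TensorDataset and the variable names are assumptions made for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (assumed for illustration): 10 samples with 3 features each
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
toy_dataset = TensorDataset(features, labels)

# batch_size=4 over 10 samples yields batches of 4, 4, and 2
keep_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=False)
# drop_last=True discards the final incomplete batch
drop_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=True)

keep_sizes = [x.shape[0] for x, _ in keep_loader]
drop_sizes = [x.shape[0] for x, _ in drop_loader]
print(keep_sizes)  # [4, 4, 2]
print(drop_sizes)  # [4, 4]

# Each batch stacks the individual samples along a new first dimension
first_x, first_y = next(iter(keep_loader))
print(first_x.shape)     # torch.Size([4, 3])
print(first_y.tolist())  # [0, 1, 2, 3]
```

With shuffle=True the batch sizes stay the same, but the samples inside each batch are drawn in a new random order every epoch.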

    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    # CIFAR10 test split, converted to tensors.
    # download=True fetches the data if it is not already under root.
    test_data = torchvision.datasets.CIFAR10(
        root="./dataset",
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

    writer = SummaryWriter("dataloader_logs")
    for epoch in range(2):
        step = 0
        for data in test_loader:
            # imgs is one batch of images, targets the matching class indices
            imgs, targets = data
            writer.add_images("Epoch: {}".format(epoch), imgs, step)
            step = step + 1

    writer.close()
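As a quick check on how many steps the inner loop above logs per epoch: the CIFAR10 test split has 10,000 images, so the count follows from batch_size and drop_last. The variable names below are illustrative.

```python
import math

num_images = 10_000  # size of the CIFAR10 test split
batch_size = 64

steps_keep = math.ceil(num_images / batch_size)  # drop_last=False
steps_drop = num_images // batch_size            # drop_last=True
last_batch = num_images - (steps_keep - 1) * batch_size

print(steps_keep, steps_drop, last_batch)  # 157 156 16
```

With drop_last=False, the final step of each epoch therefore shows only 16 images in TensorBoard instead of 64.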
