蚂蚁-蜜蜂分类数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip
Dataset
`Dataset` 提供一种方式获取数据及其 label,包含两个功能:
- 如何获取每一个数据及其 label?——`__getitem__`
- 总共有多少条数据?——`__len__`
`Dataset` 的子类都必须重写 `__getitem__` 和 `__len__` 这两个方法。
from torch.utils.data import Dataset
from PIL import Image
import os
class MyDataset(Dataset):
    """Image dataset backed by a single class folder, labelled by folder name.

    Every entry found in ``root_dir/label_dir`` is treated as one sample, and
    ``label_dir`` itself is returned as the label of each sample.

    Args:
        root_dir: Directory that contains the per-class sub-folders.
        label_dir: Name of the class sub-folder; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        # os.listdir() returns entries in arbitrary order; sort them so a
        # given index maps to the same file on every run and platform.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is whatever ``Image.open`` returns for the file; the label
        is the ``label_dir`` string passed to the constructor.

        Raises:
            IndexError: If ``idx`` is outside the range of listed files.
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries listed in the class folder."""
        return len(self.img_path)
# Training split root; each class lives in its own sub-folder named after it.
root_dir = "./data/hymenoptera_data/train"
ants_label_dir, bees_label_dir = "ants", "bees"

# One dataset per class folder.
ants_dataset = MyDataset(root_dir, ants_label_dir)
bees_dataset = MyDataset(root_dir, bees_label_dir)

# Concatenate the two per-class datasets to form the training set.
train_dataset = ants_dataset + bees_dataset

# Notebook-style inspection of the sizes; the author recorded (124, 121, 245).
len(ants_dataset), len(bees_dataset), len(train_dataset)

# Fetch one sample from the combined dataset and display it.
img2, label2 = train_dataset[123]
img2.show()
label2  # the author recorded 'ants' for this index
Dataloader
`DataLoader` 为后面的网络提供不同的数据形式,用于控制如何从 `Dataset` 取数据。
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split (train=False), with every image converted by ToTensor().
# NOTE(review): no `download=True` is passed, so the data is presumably
# expected to already exist under ./dataset — confirm.
test_data = torchvision.datasets.CIFAR10(
    root="./dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

# Batches of 64 in shuffled order, loaded in the main process (num_workers=0);
# drop_last=False keeps the final, possibly smaller, batch.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

# Log every batch of images under a per-epoch tag so the two passes over the
# loader can be compared side by side in TensorBoard. Using the writer as a
# context manager ensures it is closed even if a batch fails part-way.
with SummaryWriter("dataloader_logs") as writer:
    for epoch in range(2):
        # `step` is the batch index within the current epoch, starting at 0.
        for step, (imgs, targets) in enumerate(test_loader):
            writer.add_images("Epoch: {}".format(epoch), imgs, step)