加载官方数据集(以 CIFAR-10 为例)
https://pytorch.org/vision/stable/datasets.html
import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
print('==> Preparing data..')

# Example batch sizes. The snippet previously read these from an `args`
# object that is never defined here (NameError); adjust as needed or wire
# them to your own argument parser.
train_batch_size = 128
eval_batch_size = 100

# Per-channel mean/std used by Normalize for both splits.
# NOTE(review): these std values are widely copied between projects; confirm
# them against statistics computed from the CIFAR-10 training set.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop from a 4-pixel padded image plus a
# random horizontal flip (augmentation), then tensor conversion + normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])
# Test pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# download=True fetches the dataset into ./data/ if it is not present yet.
trainset = torchvision.datasets.CIFAR10(root="./data/", train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root="./data/", train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=eval_batch_size, shuffle=False, num_workers=2)
从文件夹中读取数据集(torchvision.datasets.ImageFolder)
下面演示如何读取训练集数据,读取验证集数据的方法相同。
图片需按类别存放在不同的子文件夹中,子文件夹名即为类别名,目录结构如下所示:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalized 3-channel image tensor with matplotlib.

    Undoes ``transforms.Normalize(mean, std)`` (``x * std + mean``), clips the
    result to [0, 1] and draws it in the current figure.

    Args:
        inp: image tensor laid out as (C, H, W) with 3 channels, e.g. the
            output of ``torchvision.utils.make_grid``.
        title: optional title for the plot.
        mean: per-channel mean that was used for normalization. Defaults to
            the ImageNet mean; pass the values your transform actually used.
        std: per-channel std that was used for normalization. Defaults to the
            ImageNet std.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can still
    # be converted; then (C, H, W) -> (H, W, C), the layout plt.imshow takes.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    # Rounding can push values slightly outside the displayable range.
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
# Normalize with the same ImageNet mean/std that imshow() uses to undo the
# normalization, so the preview shows the images with their original colors.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ImageFolder treats each sub-directory of ./data/ as one class.
# NOTE: with num_workers > 0, platforms that start DataLoader workers with
# "spawn" (e.g. Windows, macOS) need this code under
# `if __name__ == "__main__":`.
dataset = torchvision.datasets.ImageFolder("./data/", transform=data_transforms)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

# Take one batch, tile it into a single grid image and title it with the
# class names of its samples.
inputs, classes = next(iter(data_loader))
out = torchvision.utils.make_grid(inputs)
class_name = dataset.classes
imshow(out, title=[class_name[idx] for idx in classes.tolist()])
自定义数据集读取
从 txt 文件中读取每张图片的路径和标签:每行一个样本,路径与标签之间用空格分隔,例如 `cat/123.png 1`。
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import datasets, models, transforms
import os
# A custom dataset class: subclass torch.utils.data.Dataset.
class My_DataSet(Dataset):
    """Image-classification dataset described by a text list file.

    Each non-blank line of ``txt_file`` is ``<image path> <integer label>``,
    separated by whitespace. The label is the last token on the line, so the
    image path itself may contain spaces. Image paths are joined onto
    ``root``, so ``root`` must be set to match the paths written in
    ``txt_file``.

    Args:
        root: directory the image paths in ``txt_file`` are relative to.
        txt_file: path of the list file.
        transform: optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: if a non-blank line lacks a label or its label is not an
            integer.
    """

    def __init__(self, root, txt_file, transform=None):
        super().__init__()
        self.root = root
        self.txt_file = txt_file
        self.transform = transform
        with open(self.txt_file, "r") as f:
            lines = f.read().splitlines()
        # Raw non-blank lines (kept for callers that inspect them) and the
        # parsed (path, label) pairs. Blank lines, e.g. a trailing empty
        # line, are skipped instead of becoming bogus samples.
        self.img_label_list = []
        self.samples = []
        for lineno, line in enumerate(lines, start=1):
            if not line.strip():
                continue
            self.img_label_list.append(line)
            self.samples.append(self._parse_line(line, lineno))

    def _parse_line(self, line, lineno):
        """Split one list-file line into ``(image_path, int_label)``."""
        # rsplit on whitespace once: everything before the last token is the path.
        parts = line.strip().rsplit(maxsplit=1)
        if len(parts) != 2:
            raise ValueError(f"{self.txt_file}:{lineno}: expected '<path> <label>', got {line!r}")
        img_path, label = parts
        try:
            return img_path, int(label)  # the label is read as text, so convert it
        except ValueError as err:
            raise ValueError(f"{self.txt_file}:{lineno}: label {label!r} is not an integer") from err

    def __getitem__(self, index):
        img_path, label = self.samples[index]
        # The context manager closes the file handle; convert() reads the
        # pixel data into a new image before the file is closed.
        with Image.open(os.path.join(self.root, img_path)) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.samples)
if __name__ == "__main__":
    # Per-split preprocessing; both splits normalize with ImageNet mean/std.
    #   'train': random 224x224 crop + random horizontal flip (augmentation).
    #   'val':   deterministic resize to 256x256, then a 224x224 center crop.
    # NOTE(review): RandomCrop without a preceding Resize requires every
    # training image to be at least 224x224 — confirm for the dataset in use.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }
    data_set = My_DataSet("./data/", "./data/GHIM-20/list_train.txt", transform=data_transforms["train"])
    data_loader = DataLoader(data_set, batch_size=3, shuffle=True)
    # Pull a single batch to check that loading works end to end.
    data, label = next(iter(data_loader))
    print(label)
PyTorch 如何加载尺寸不同的数据?
思路:默认的 collate_fn 会用 torch.stack 把一个 batch 中的样本堆叠成一个张量,样本尺寸不同时会报错;自定义 collate_fn,直接返回样本而不做堆叠即可。
https://www.zhihu.com/question/395888465/answer/1234563938
from torch.utils.data import DataLoader, Dataset
class OurDataset(Dataset):
    """Dataset whose samples are the positional arguments, in order.

    Sample ``i`` is the ``i``-th argument returned unchanged, so samples may
    have different shapes (they are never stacked by this class).
    """

    def __init__(self, *tensors):
        # Stored as the tuple of samples, one per constructor argument.
        self.tensors = tensors

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, index):
        sample = self.tensors[index]
        return sample
def collate_wrapper(batch):
    """Collate a batch without stacking, so samples may differ in shape.

    PyTorch's default collate function stacks tensor samples, which fails
    when their shapes differ; this one returns the samples as-is.

    Args:
        batch: sequence of samples produced by the dataset for one batch.

    Returns:
        A tuple of the samples in batch order. Any batch size is accepted,
        including a smaller final batch (the previous ``a, b = batch``
        unpacking raised ValueError unless the batch held exactly 2 samples).
    """
    return tuple(batch)
# Two samples with different shapes, which default collation could not stack.
a = torch.randn(3, 2, 3)
b = torch.randn(3, 3, 4)
dataset = OurDataset(a, b)
# batch_size=2 puts both samples in one batch; collate_wrapper hands them
# back unstacked, so each keeps its own shape (no extra batch dimension).
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper)
for sample in loader:
    print([x.size() for x in sample])
# Out: [torch.Size([3, 2, 3]), torch.Size([3, 3, 4])]