Custom DataLoader Example
2) Custom Dataset Class
torch.utils.data.Dataset 是一个抽象类。自定义数据集类需要继承这个类,并实现两个方法:一个是 __len__(返回数据集的大小),另一个是 __getitem__(根据给定索引返回对应的数据和标签)。
import torch
from PIL import Image
from torch.utils.data import Dataset
import os
import pandas as pd
class MyDataset(Dataset):
    """Image dataset indexed by a CSV file.

    The CSV file is read with ``pandas.read_csv`` and must contain a
    ``File Name`` column (image file names, joined onto ``img_dir``) and a
    ``Class Label`` column (the label returned with each image).

    Args:
        csv_path: Path of the CSV file listing file names and labels.
        img_dir: Directory that each file name is joined onto.
        transform: Optional callable applied to every opened image.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.img_names = df['File Name']
        self.y = df['Class Label']
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``.

        Raises:
            IndexError: If ``index`` is outside the dataset.
        """
        # Positional (.iloc) lookup instead of label-based Series[index]:
        # an out-of-range position raises IndexError (which is what plain
        # ``for sample in dataset`` iteration needs in order to stop) and
        # negative positions count from the end like any Python sequence.
        # The lookup happens before the file is opened so a bad index
        # fails without touching the disk.
        img_path = os.path.join(self.img_dir, self.img_names.iloc[index])
        label = self.y.iloc[index]

        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of rows (samples) listed in the CSV file."""
        return self.y.shape[0]
通过上面的方式可以定义我们需要的数据集类,并可以通过迭代的方法来逐个取得每一个数据。但是这样很难实现按 batch 取数据、shuffle(打乱顺序)或者用多进程并行加载数据,所以 PyTorch 中提供了一个简单的方法来做这些事情。
3) Custom Dataloader
__getitem__ 一次只能获取一个数据,所以需要通过 torch.utils.data.DataLoader 来定义一个新的可迭代对象,按 batch 批量读取数据。
from torchvision import transforms
from torch.utils.data import DataLoader
# Preprocessing applied to every image: convert it to a tensor, then
# normalize with mean 0.5 and std 0.5 (one value each: single channel).
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Dataset built from the CSV index and the image directory.
train_dataset = MyDataset(csv_path='mnist_train.csv',
                          img_dir='mnist_train',
                          transform=custom_transform)

# Batched, shuffled loader over the training dataset.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=4)  # number of worker processes/CPUs to use
- transforms.Compose()可以把一些转换函数组合在一起
- transforms.Normalize([0.5],[0.5])对张量进行归一化(标准化),这两个0.5分别表示归一化时使用的均值(mean)和标准差(std),即 output = (input - mean) / std。因图像是灰度图,只有一个通道,所以各只有一个数字;如果有多个通道,则每个通道都需要一个数字,如3个通道,应该是Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
# NOTE(review): `Test` is not defined in the visible part of this file; it is
# presumably the Dataset instance for the test split -- confirm.
# DataLoader is used through the name imported above; no `data` module alias
# is imported anywhere in this file.
test_loader = DataLoader(Test, batch_size=2, shuffle=True, num_workers=2)

# Each item yielded by the loader is unpacked into a (data, label) pair.
# The loop variable is called `batch` so it cannot be confused with (or
# shadow) a `torch.utils.data` module alias named `data`.
for i, batch in enumerate(test_loader):
    Data, Label = batch
DataLoader 按批量读取数据,可以像使用迭代器一样使用它,比如直接对它进行 for 循环操作。
不过它本身是可迭代对象(iterable)而不是迭代器(iterator),我们可以通过 iter() 函数将其转化为迭代器,再用 next() 逐个取出 batch。
# Turn the loader into an explicit iterator, then pull its first batch.
dataiter = iter(test_loader)
first_batch = next(dataiter)
# The batch unpacks into the images and their labels.
imgs, labels = first_batch
可视化
def imshow(inp, title=None, mean=None, std=None):
    """Display a normalized image tensor with matplotlib.

    The tensor is converted to a NumPy array, moved from channels-first to
    channels-last layout, de-normalized with ``std * inp + mean``, clipped
    to ``[0, 1]`` and drawn with ``plt.imshow``.

    Relies on ``np`` (numpy) and ``plt`` (matplotlib.pyplot) being available
    at module level.

    Args:
        inp: Image tensor with three axes; axis 0 is moved to the last
            position before plotting.
        title: Optional plot title; nothing is set when ``None``.
        mean: Per-channel mean that was subtracted during normalization.
            Defaults to ``[0.485, 0.456, 0.406]``.
        std: Per-channel standard deviation that was divided out during
            normalization. Defaults to ``[0.229, 0.224, 0.225]``.
    """
    # None sentinels keep the historical defaults without using mutable
    # default arguments; callers that normalized with other statistics
    # (e.g. mean=[0.5], std=[0.5]) can now pass them in.
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        std = [0.229, 0.224, 0.225]

    # detach()/cpu() make the conversion work for tensors that track
    # gradients or live on a GPU; for a plain CPU tensor they are no-ops.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo the normalization, then clamp to the range imshow expects.
    inp = np.asarray(std) * inp + np.asarray(mean)
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
# NOTE(review): `dataloaders`, `class_names` and `torchvision` are not defined
# in the visible part of this file -- confirm they are provided elsewhere.

# Fetch the first batch from the training loader.
train_iter = iter(dataloaders['train'])
inputs, classes = next(train_iter)

# Tile the images of the batch into a single grid image.
out = torchvision.utils.make_grid(inputs)

# Title the figure with the class name of every image in the batch.
batch_titles = [class_names[x] for x in classes]
imshow(out, title=batch_titles)
4) Iterating Through the Dataset
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix the random seed before iterating.
torch.manual_seed(0)

num_epochs = 2
for epoch in range(num_epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        # One progress line per batch: epoch (1-based), batch index and
        # the number of labels in this batch.
        print(
            f"Epoch: {epoch + 1}"
            f" | Batch index: {batch_idx}"
            f" | Batch size: {y.size()[0]}"
        )

        # Move the batch to the selected device.
        x, y = x.to(device), y.to(device)