https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

Create a custom dataset for your files

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # Read the annotations CSV and store the image directory and transforms
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Column 0 of the CSV holds the image file name, column 1 holds the label
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)  # read the image file as a tensor
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

The __init__ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms.
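As a minimal sketch (the annotations CSV and image directory paths here are placeholders), instantiating the dataset might look like:

# Hypothetical paths -- replace with your own annotations file and image folder
training_data = CustomImageDataset(
    annotations_file="data/labels.csv",
    img_dir="data/images",
)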

__len__

The __len__ function returns the number of samples in our dataset.
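Because __len__ is defined, Python's built-in len() works on the dataset directly (reusing the training_data instance sketched above):

print(len(training_data))  # number of rows in the annotations CSV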

__getitem__

The __getitem__ function loads and returns a sample from the dataset at the given index idx.
Based on the index, it identifies the image's location on disk, reads it as a tensor using read_image, retrieves the corresponding label from the CSV data in self.img_labels, calls the transform functions on them (if any), and returns the tensor image and corresponding label in a tuple.
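Likewise, standard indexing dispatches to __getitem__. A short usage sketch, again with the hypothetical training_data instance from above:

image, label = training_data[0]  # first sample as a (tensor, label) tuple
print(image.shape, label)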

DataLoader

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
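The multiprocessing speed-up mentioned above is enabled through DataLoader's num_workers argument; the worker count below is only illustrative:

# Spawn 4 worker processes that load batches in parallel (0 = load in the main process)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=4)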

We have loaded the dataset into the DataLoader and can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels.

Load a random batch of data from the instantiated DataLoader object:

# get some random training samples
dataiter = iter(train_dataloader)
train_features, train_labels = next(dataiter)  # use the built-in next(); newer PyTorch removed dataiter.next()
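In a full training loop you would iterate over the whole DataLoader once per epoch; with shuffle=True the samples are reshuffled each time a new iterator is created. A minimal sketch (the epoch count is arbitrary):

for epoch in range(5):  # arbitrary number of epochs
    for train_features, train_labels in train_dataloader:
        # forward pass, loss computation, and backward pass would go here
        pass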