https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
Create a custom dataset for your files
A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
__init__
The __init__ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms.
__len__
The __len__ function returns the number of samples in our dataset.
__getitem__
The __getitem__ function loads and returns a sample from the dataset at the given index idx.
Based on the index, it identifies the image's location on disk, converts it to a tensor using read_image, retrieves the corresponding label from the CSV data in self.img_labels, calls the transform functions on them (if any), and returns the tensor image and its corresponding label as a tuple.
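The same three-method protocol can be sketched without any image files or torchvision. Below is a minimal in-memory dataset (the class name and data are made up for illustration); PyTorch's DataLoader only needs __len__ and __getitem__ for a map-style dataset, so this plain class shows how indexing and the optional transforms interact:

```python
# Minimal in-memory dataset following the same protocol:
# __init__, __len__, __getitem__. Hypothetical data, no files needed.
class SquaresDataset:
    def __init__(self, n, transform=None, target_transform=None):
        self.samples = list(range(n))   # stand-in for image tensors
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx]
        y = x * x                       # stand-in for a label
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return x, y


ds = SquaresDataset(5, transform=lambda v: v + 10)
print(len(ds))   # 5
print(ds[3])     # (13, 9): sample 3 shifted by the transform, label 3*3
```

Indexing with ds[3] calls __getitem__ under the hood, exactly as the DataLoader does when it assembles a batch.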
DataLoader
The Dataset retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval.
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
We have loaded the dataset into the DataLoader and can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels.
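What the DataLoader does each epoch can be sketched in plain Python (a simplified illustration of map-style iteration, not the real implementation, which also handles collation and multiprocessing): shuffle the indices, then yield fixed-size batches.

```python
import random


def simple_batches(dataset, batch_size, shuffle=True, seed=0):
    # Simplified sketch of what a DataLoader does with a map-style dataset:
    # shuffle the indices once per epoch, then slice them into batches.
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        features = [x for x, _ in batch]
        labels = [y for _, y in batch]
        yield features, labels


data = [(i, i % 2) for i in range(10)]   # toy (feature, label) pairs
for feats, labs in simple_batches(data, batch_size=4):
    print(len(feats), labs)              # batches of size 4, 4, 2
```

Because the indices are reshuffled before slicing, every epoch sees the samples grouped into different batches, which is the overfitting-reduction effect shuffle=True provides.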
Load a random batch of data from the instantiated DataLoader object:
# get some random training images
dataiter = iter(train_dataloader)
images, labels = next(dataiter)
