- Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity.
PyTorch has two primitives for working with data, which let you use pre-loaded datasets as well as your own data: `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`. `Dataset` stores the samples and their corresponding labels, and `DataLoader` wraps an iterable around the `Dataset` to enable easy access to the samples.

- `torch.utils.data` API reference: https://pytorch.org/docs/stable/data.html

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
```
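To see how the two primitives fit together without downloading anything, here is a minimal sketch using the built-in `TensorDataset`, a ready-made `Dataset` that wraps in-memory tensors. The toy tensors are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (illustrative only): 6 "images" of shape 1x2x2 and 6 integer labels.
features = torch.arange(24, dtype=torch.float32).reshape(6, 1, 2, 2)
labels = torch.tensor([0, 1, 0, 1, 0, 1])

# The Dataset stores samples and labels; indexing returns one (sample, label) pair.
dataset = TensorDataset(features, labels)
sample, label = dataset[0]
print(len(dataset), sample.shape, label.item())

# The DataLoader wraps an iterable around the Dataset and yields batches.
loader = DataLoader(dataset, batch_size=4)
batch_shapes = [tuple(X.shape) for X, y in loader]
print(batch_shapes)  # [(4, 1, 2, 2), (2, 1, 2, 2)]
```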
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio (each implementing functionality specific to its kind of data), all of which include datasets. For this tutorial, we will be using a TorchVision dataset.
- TorchText:https://pytorch.org/text/stable/index.html
- TorchVision:https://pytorch.org/vision/stable/index.html
- TorchAudio:https://pytorch.org/audio/stable/index.html
- The `torchvision.datasets` module contains `Dataset` objects for many real-world vision datasets, such as CIFAR and COCO. In this tutorial, we use the FashionMNIST dataset. Every TorchVision `Dataset` includes two arguments, `transform` and `target_transform`, to modify the samples and labels respectively.
- TorchVision `Dataset`: https://pytorch.org/vision/stable/datasets.html

## 1、Loading a Dataset
```python
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
"""
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
"""
```
- We pass the `Dataset` as an argument to `DataLoader`. This wraps an iterable over our dataset, and supports automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of 64, i.e. each element in the dataloader iterable will return a batch of 64 features and labels.
```python
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Inspect the shapes of the first batch only.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
"""
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
"""
```
## 2、Iterating and Visualizing the Dataset
We can index `Datasets` manually like a list: `training_data[index]`. We use `matplotlib` to visualize some samples in our training data.

```python
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    # Pick a random sample index and fetch the (image, label) pair.
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    # squeeze() drops the channel axis: [1, 28, 28] -> [28, 28].
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```
## 3、Creating a Custom Dataset for Your Files
A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`. Take a look at this implementation: the FashionMNIST images are stored in a directory `img_dir`, and their labels are stored separately in a CSV file `annotations_file`.

- In the next sections, we'll break down what's happening in each of these functions.

```python
import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # Note: read_csv treats the first row as a header by default;
        # pass header=None if your CSV has no header row.
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```
### (1) `__init__`
- The `__init__` function runs once when the Dataset object is instantiated. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the PyTorch Transforms tutorial).
- The `labels.csv` file looks like:
```text
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
```

```python
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
```
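One detail worth checking: `pd.read_csv` treats the first line of a file as column names by default, whereas the `labels.csv` sample above has no header row. The sketch below uses an in-memory copy of three sample rows to show the difference. Passing `header=None` is one way to keep every row if your file looks like this; `skipinitialspace=True` is an optional extra that strips the space after each comma.

```python
import io

import pandas as pd

# Three rows in the same (header-less) format as the labels.csv sample.
csv_text = "tshirt1.jpg, 0\ntshirt2.jpg, 0\nankleboot999.jpg, 9\n"

# Default: the first row is consumed as column names, so one sample is lost.
default_df = pd.read_csv(io.StringIO(csv_text))
# header=None keeps every row as data.
no_header_df = pd.read_csv(io.StringIO(csv_text), header=None, skipinitialspace=True)

print(len(default_df), len(no_header_df))                # 2 3
print(no_header_df.iloc[2, 0], no_header_df.iloc[2, 1])  # ankleboot999.jpg 9
```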
### (2) `__len__`
The `__len__` function returns the number of samples in our dataset.

```python
def __len__(self):
    return len(self.img_labels)
```
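`DataLoader` relies on `__len__` to know how many indices to draw, so the number of batches follows directly from the dataset length: `ceil(len(dataset) / batch_size)`, or `floor(len(dataset) / batch_size)` when `drop_last=True` discards the short final batch. A minimal sketch with a 10-sample toy dataset, plus the same arithmetic for FashionMNIST's 60,000 training and 10,000 test images at `batch_size=64`:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset with 10 samples (illustrative values).
dataset = TensorDataset(torch.arange(10))

# len(loader) is derived from len(dataset) and batch_size.
loader = DataLoader(dataset, batch_size=4)                       # batches of 4, 4, 2
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)  # batches of 4, 4
print(len(dataset), len(loader), len(loader_drop))  # 10 3 2

# Same arithmetic for FashionMNIST with batch_size=64.
print(math.ceil(60000 / 64), math.ceil(10000 / 64))  # 938 157
```

So the final test batch holds only 16 images (10,000 - 156 * 64), which is why code that assumes every batch has exactly 64 samples can misbehave on the last step.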
### (3) `__getitem__`
The `__getitem__` function loads and returns a sample from the dataset at the given index `idx`. Based on the index, it identifies the image's location on disk, converts that to a tensor using `read_image`, retrieves the corresponding label from the CSV data in `self.img_labels`, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a tuple.

```python
def __getitem__(self, idx):
    # Column 0 of the CSV holds the image file name.
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)  # uint8 tensor of shape [C, H, W]
    label = self.img_labels.iloc[idx, 1]  # column 1 holds the label
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
```
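Transforms are just callables applied inside `__getitem__`. Because `read_image` returns a `uint8` tensor of shape `[C, H, W]`, and `ToTensor` expects a PIL image or NumPy array, the sketch below uses two hypothetical plain functions instead: one scales pixels into `[0, 1]`, the other one-hot encodes the label (torchvision's `ConvertImageDtype` and `Lambda` transforms can express the same ideas). A random tensor stands in for an image read from disk.

```python
import torch
import torch.nn.functional as F


def to_float_image(image: torch.Tensor) -> torch.Tensor:
    """Hypothetical transform: scale a uint8 [C, H, W] image into float32 [0, 1]."""
    return image.to(torch.float32) / 255.0


def one_hot_label(label: int, num_classes: int = 10) -> torch.Tensor:
    """Hypothetical target_transform: encode an integer class as a float one-hot vector."""
    return F.one_hot(torch.tensor(label), num_classes=num_classes).to(torch.float32)


# Simulate the tail of __getitem__ for one sample.
image = torch.randint(0, 256, (1, 28, 28), dtype=torch.uint8)
label = 9
transform, target_transform = to_float_image, one_hot_label
if transform:
    image = transform(image)
if target_transform:
    label = target_transform(label)

print(image.dtype, image.shape)  # torch.float32 torch.Size([1, 28, 28])
print(label)                     # tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])
```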
## 4、Preparing Your Data for Training with DataLoaders
The `Dataset` retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's `multiprocessing` to speed up data retrieval. `DataLoader` is an iterable that abstracts this complexity for us behind a simple API.

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
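With `shuffle=True`, the `DataLoader` draws a new random permutation of the indices each time it is iterated, i.e. once per epoch. A small sketch on a toy dataset of 8 integers; the seeded `torch.Generator` is an illustrative choice for reproducibility. Shuffling the test loader, as above, is harmless, but evaluation order usually does not matter, so `shuffle=False` is common there.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: the samples are just the integers 0..7.
dataset = TensorDataset(torch.arange(8))

# Seeded generator (illustrative) so the shuffled order is reproducible.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)

# Two full passes over the loader, i.e. two epochs; each pass uses a fresh permutation.
epoch_orders = []
for epoch in range(2):
    order = [x for (batch,) in loader for x in batch.tolist()]
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")

# Without shuffle, samples come back in index order.
ordered = [x for (batch,) in DataLoader(dataset, batch_size=4) for x in batch.tolist()]
print(ordered)  # [0, 1, 2, 3, 4, 5, 6, 7]
```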
## 5、Iterate through the DataLoader
- We have loaded that dataset into the `DataLoader` and can iterate through the dataset as needed. Each iteration below returns a batch of `train_features` and `train_labels` (containing `batch_size=64` features and labels respectively). Because we specified `shuffle=True`, after we iterate over all batches the data is shuffled.
- For finer-grained control over the data loading order, take a look at Samplers:
- Samplers:[https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler)
```python
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
"""
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 9
"""
```
## 6、References
- `torch.utils.data` API: https://pytorch.org/docs/stable/data.html