在实现 softmax 回归前,先做一下数据集的工作。关于 FashionMNIST
以下是本节用到的模块
import torch
import torchvision
import torchvision.transforms as trans
import matplotlib.pyplot as plt
import time
import d2lzh_pytorch
torchvision包,是服务于 PyTorch 深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
torchvision.datasets
: 一些加载数据的函数及常用的数据集接口;torchvision.models
: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms
: 常用的图片变换,例如裁剪、旋转等;torchvision.utils
: 其他的一些有用的方法。3.5.1 获取数据集
我们用
torchvision.datasets
来获取数据集,包括两部分:训练集(training set):用于训练模型。
- 测试集(testing set):用于测试模型的学习效果。
torchvision.datasets
已经提供了获取 MNIST 、FashionMNIST 等常用各种数据集的接口。root
用于指定下载路径, trai=True
对应训练集,False 则对应测试集,download=True
表示从互联网上下载,已经下载好的不会再重复下载。
另外我们还指定了参数transform = transforms.ToTensor()
使所有数据转换为Tensor
,如果不进行转换则返回的是PIL图片。transforms.ToTensor()
将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8
的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32
且位于[0.0, 1.0]的Tensor
。
关于教程作者的提醒,没有遇到,先mark一下。
# 获取数据集
DATA_SETS_PATH = "~/My-Project/Python学习/PyTorch学习/知乎马卡斯扬-动手学深度学习PyTorch版/Data-Sets"
# 训练集
training_set = torchvision.datasets.FashionMNIST(root=DATA_SETS_PATH, train=True, download=True, transform=trans.ToTensor())
# 测试集
testing_set = torchvision.datasets.FashionMNIST(root=DATA_SETS_PATH, train=False, download=True, transform=trans.ToTensor())
print(type(training_set), type(testing_set))
print(len(training_set), len(testing_set))
下载链接是国外的,所以下载速度实在堪忧。下面提供一下这两个数据集。
Data-Sets.zip
看一下数据集。
# 看一下数据集
print(type(training_set), type(testing_set))
print(len(training_set), len(testing_set))
feature, label = training_set[0]
label = torch.tensor(label)
print(feature.size(), label)
运行结果
<class 'torchvision.datasets.mnist.FashionMNIST'> <class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
torch.Size([1, 28, 28]) tensor(9)
<class 'torch.Tensor'> <class 'torch.Tensor'>
特征为 28*28 的 8 位灰度图, 像素值范围已映射到 [0, 1]。
FashionMNIST 一共包含 10 个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
接下来我们对数据集做一个简单的可视化。
# 已在 d2lzh_pytorch 中的 get_fashion_mnist_labels 实现
def labels_number2txt(src_labels):
"""
将数值标签转换为文本标签, 便于阅读
Args:
src_labels: 数据集的数值标签
Returns:
文本标签
Raises:
无
"""
TEXT_LABELS = ["t-shirt", "trouser", "pullover", "dress", "coat",
"sandal", "shirt", "sneaker", "bag", "ankle boot"]
return [TEXT_LABELS[int(i)] for i in src_labels]
# 已在 d2lzh_pytorch 中的 show_fashion_mnist 实现
def show_fashion_mnist(features, labels):
"""
画出原始图像和对应标签
Args:
features: 数据集的特征
labels: 数据集的标签
Returns:
无
Raises:
无
"""
# d2lzh_pytorch.use_svg_display
# 按样本数量建立子图
figs = plt.subplots(1, len(features), figsize=(12, 12))[1]
for f, img, label in zip(figs, features, labels):
# 显示原始图像
f.imshow(img.view(28, 28).numpy())
# 显示标签
f.set_title(label)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
# 数据集的可视化
x, y = [], []
for i in range(10):
x.append(training_set[i][0])
y.append(training_set[i][1])
show_fashion_mnist(x, labels_number2txt(y))
运行结果
3.5.2 读取小批次数据
我们上述获得的 training_set
和 testing_set
是 torch.utils.data.Dataset
的子类,因此因此可以用 torch.utils.data.DataLoader()
来创建一个用于读取小批次数据的迭代器 DataLoader
实例。
其中 num_workers
用于指定 进程数量 来加速数据的读取。
# 读取小批次数据
batch_size = 256
# 创建读取小批次数据的迭代器
training_set_iter = torch_data.DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=10)
testing_set_iter = torch_data.DataLoader(testing_set, batch_size=batch_size, shuffle=True, num_workers=10)
看一下我们读取整个训练集共 60000 个样本需要的时间。
# 完整地读取一次数据
start = time.time()
for x, y in training_set_iter:
pass
print("读取全部数据需要的时间: {0} s".format(time.time() - start))
运行结果
读取全部数据需要的时间: 0.9038164615631104 s
小结
FashionMNIST 和 MNIST 这两个数据集是完全兼容的,用的时候主要不要搞混。相比已经被用烂了的MNIST,FashionMNIST 更加合理。