在实现 softmax 回归前,先做一下数据集的工作。关于 FashionMNIST
以下是本节用到的模块

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as trans
  4. import matplotlib.pyplot as plt
  5. import time
  6. import d2lzh_pytorch

torchvision包,是服务于 PyTorch 深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

    3.5.1 获取数据集

    我们用torchvision.datasets来获取数据集,包括两部分:

  5. 训练集(training set):用于训练模型。

  6. 测试集(testing set):用于测试模型的学习效果。

torchvision.datasets 已经提供了获取 MNISTFashionMNIST 等常用各种数据集的接口。root 用于指定下载路径, trai=True对应训练集,False 则对应测试集,download=True 表示从互联网上下载,已经下载好的不会再重复下载。
另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor
关于教程作者的提醒,没有遇到,先mark一下。

  1. # 获取数据集
  2. DATA_SETS_PATH = "~/My-Project/Python学习/PyTorch学习/知乎马卡斯扬-动手学深度学习PyTorch版/Data-Sets"
  3. # 训练集
  4. training_set = torchvision.datasets.FashionMNIST(root=DATA_SETS_PATH, train=True, download=True, transform=trans.ToTensor())
  5. # 测试集
  6. testing_set = torchvision.datasets.FashionMNIST(root=DATA_SETS_PATH, train=False, download=True, transform=trans.ToTensor())
  7. print(type(training_set), type(testing_set))
  8. print(len(training_set), len(testing_set))

下载链接是国外的,所以下载速度实在堪忧。下面提供一下这两个数据集。
Data-Sets.zip

看一下数据集。

  1. # 看一下数据集
  2. print(type(training_set), type(testing_set))
  3. print(len(training_set), len(testing_set))
  4. feature, label = training_set[0]
  5. label = torch.tensor(label)
  6. print(feature.size(), label)

运行结果

  1. <class 'torchvision.datasets.mnist.FashionMNIST'> <class 'torchvision.datasets.mnist.FashionMNIST'>
  2. 60000 10000
  3. torch.Size([1, 28, 28]) tensor(9)
  4. <class 'torch.Tensor'> <class 'torch.Tensor'>

特征为 28*28 的 8 位灰度图, 像素值范围已映射到 [0, 1]。
FashionMNIST 一共包含 10 个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
接下来我们对数据集做一个简单的可视化。

  1. # 已在 d2lzh_pytorch 中的 get_fashion_mnist_labels 实现
  2. def labels_number2txt(src_labels):
  3. """
  4. 将数值标签转换为文本标签, 便于阅读
  5. Args:
  6. src_labels: 数据集的数值标签
  7. Returns:
  8. 文本标签
  9. Raises:
  10. """
  11. TEXT_LABELS = ["t-shirt", "trouser", "pullover", "dress", "coat",
  12. "sandal", "shirt", "sneaker", "bag", "ankle boot"]
  13. return [TEXT_LABELS[int(i)] for i in src_labels]
  14. # 已在 d2lzh_pytorch 中的 show_fashion_mnist 实现
  15. def show_fashion_mnist(features, labels):
  16. """
  17. 画出原始图像和对应标签
  18. Args:
  19. features: 数据集的特征
  20. labels: 数据集的标签
  21. Returns:
  22. Raises:
  23. """
  24. # d2lzh_pytorch.use_svg_display
  25. # 按样本数量建立子图
  26. figs = plt.subplots(1, len(features), figsize=(12, 12))[1]
  27. for f, img, label in zip(figs, features, labels):
  28. # 显示原始图像
  29. f.imshow(img.view(28, 28).numpy())
  30. # 显示标签
  31. f.set_title(label)
  32. f.axes.get_xaxis().set_visible(False)
  33. f.axes.get_yaxis().set_visible(False)
  34. plt.show()
  35. # 数据集的可视化
  36. x, y = [], []
  37. for i in range(10):
  38. x.append(training_set[i][0])
  39. y.append(training_set[i][1])
  40. show_fashion_mnist(x, labels_number2txt(y))

运行结果

图片.png

3.5.2 读取小批次数据

我们上述获得的 training_settesting_settorch.utils.data.Dataset的子类,因此因此可以用 torch.utils.data.DataLoader() 来创建一个用于读取小批次数据的迭代器 DataLoader 实例。
其中 num_workers 用于指定 进程数量 来加速数据的读取。

  1. # 读取小批次数据
  2. batch_size = 256
  3. # 创建读取小批次数据的迭代器
  4. training_set_iter = torch_data.DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=10)
  5. testing_set_iter = torch_data.DataLoader(testing_set, batch_size=batch_size, shuffle=True, num_workers=10)

看一下我们读取整个训练集共 60000 个样本需要的时间。

  1. # 完整地读取一次数据
  2. start = time.time()
  3. for x, y in training_set_iter:
  4. pass
  5. print("读取全部数据需要的时间: {0} s".format(time.time() - start))

运行结果

  1. 读取全部数据需要的时间: 0.9038164615631104 s

小结

FashionMNIST 和 MNIST 这两个数据集是完全兼容的,用的时候主要不要搞混。相比已经被用烂了的MNIST,FashionMNIST 更加合理。

3.5 图像分类数据集.py