在进行softmax回归前,先引入一个图像分类数据集。相比用烂的MNIST数据集,图像分类数据集分类的是现实生活中的实物。大部分模型在对MNIST的分类精度都超过95%,但对于这个数据集而言,精度往往会下降5~10个百分点不等。因此,该数据集可以作为衡量模型的一个很好的评估标准。

引入torchvision,一个服务于pytorch框架的包:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

获取数据集

第一次调用torchvision.datasets.FashionMNIST()函数会自动下载数据集。可以指定下载训练集或是测试集。

transform = transforms.ToTensor()的作用是使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

  1. import torchvision
  2. import torchvision.transforms as transforms
  3. mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True,transform=transforms.ToTensor())
  4. mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True,transform=transforms.ToTensor())
  5. print(type(mnist_train))
  6. print(len(mnist_train), len(mnist_test))
  7. feature, label = mnist_train[0]
  8. print(feature, label)
  9. 结果:
  10. <class 'torchvision.datasets.mnist.FashionMNIST'>
  11. 60000 10000
  12. FashionMNIST是一个单独的类,用于存储mnist数据集
  13. 训练集长度为60000,测试集长度为10000

访问样本:

  • mnist列表中每个元素也是一个列表,包含一个特征值以及一个标签,即输出值。
  • feature张量的大小是FashionMNIST - 图1,第一维是通道数,灰度图像通道数为1,H表高,W表宽。 ```python feature, label = mnist_train[0] print(feature, label)

结果: 因为使用了to_tensor()函数,因此feature为一个三维的张量,大小为1X28X28 label的值为9,代表图像类别编号

  1. 给数字labels打上具体中文标签,要得知哪个样本是哪类图像查询该list即可:<br />这里labels是一个数据集中label的子集,比如说前十个
  2. ```python
  3. # 本函数已保存在d2lzh包中方便以后使用
  4. def get_fashion_mnist_labels(labels):
  5. text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
  6. 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
  7. return [text_labels[int(i)] for i in labels]

在一行中画出多张图像和对应标签的函数:

  1. # 本函数已保存在d2lzh包中方便以后使用
  2. def show_fashion_mnist(images, labels): # labels对应上个函数中输出的labels
  3. d2l.use_svg_display()
  4. # 这里的_表示我们忽略(不使用)的变量
  5. _, figs = plt.subplots(1, len(images), figsize=(12, 12))
  6. for f, img, lbl in zip(figs, images, labels):
  7. f.imshow(img.view((28, 28)).numpy())
  8. f.set_title(lbl)
  9. f.axes.get_xaxis().set_visible(False)
  10. f.axes.get_yaxis().set_visible(False)
  11. plt.show()

例子:查看数据集中前10个样本的图像内容和文本标签:

  1. X, y = [], []
  2. for i in range(10):
  3. X.append(mnist_train[i][0])
  4. y.append(mnist_train[i][1])
  5. show_fashion_mnist(X, get_fashion_mnist_labels(y))

读取小批量

补充:

python获取操作系统类型:

  1. import sys
  2. sys.platform
  3. 结果:
  4. 'win32'

python判断字符串起始是否包含某子串:

  1. # 若上述代码获取到的字符串以win开头,返回True
  2. if sys.platform.startwith("win"):

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。 在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

  1. batch_size = 256
  2. if sys.platform.startswith('win'):
  3. num_workers = 0 # 不使用额外进程加速,会略慢
  4. else:
  5. num_workers = 4
  6. train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  7. test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  8. # 开始遍历,计算总共需要多少时间
  9. start = time.time()
  10. for X,y in train_iter:
  11. continue
  12. print("%.2f sec" % (time.time() - start))

可以看到,train_iter迭代器极为好用,作者将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_itertest_iter两个变量。

完整代码如下:

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import matplotlib.pyplot as plt
  5. import torch.utils.data
  6. import time
  7. import sys
  8. import d2lzh as d2l
  9. mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True,
  10. transform=transforms.ToTensor())
  11. mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True,
  12. transform=transforms.ToTensor())
  13. print(type(mnist_train))
  14. print(len(mnist_train), len(mnist_test))
  15. feature, label = mnist_train[0]
  16. print(feature, label)
  17. # 本函数已保存在d2lzh包中方便以后使用
  18. def get_fashion_mnist_labels(labels):
  19. text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
  20. 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
  21. return [text_labels[int(i)] for i in labels]
  22. # 本函数已保存在d2lzh包中方便以后使用
  23. def show_fashion_mnist(images, labels):
  24. d2l.use_svg_display()
  25. # 这里的_表示我们忽略(不使用)的变量
  26. _, figs = plt.subplots(1, len(images), figsize=(12, 12))
  27. for f, img, lbl in zip(figs, images, labels):
  28. f.imshow(img.view((28, 28)).numpy())
  29. f.set_title(lbl)
  30. f.axes.get_xaxis().set_visible(False)
  31. f.axes.get_yaxis().set_visible(False)
  32. plt.show()
  33. if __name__ == "__main__":
  34. X, y = [], []
  35. for i in range(10):
  36. X.append(mnist_train[i][0])
  37. y.append(mnist_train[i][1])
  38. show_fashion_mnist(X, get_fashion_mnist_labels(y))
  39. batch_size = 256
  40. if sys.platform.startswith('win'):
  41. num_workers = 0
  42. else:
  43. num_workers = 4
  44. train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  45. test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  46. # 开始遍历
  47. start = time.time()
  48. for X, y in train_iter:
  49. continue
  50. print("%.2f sec" % (time.time() - start))