在进行softmax回归前,先引入一个图像分类数据集。相比用烂的MNIST数据集,图像分类数据集分类的是现实生活中的实物。大部分模型在对MNIST的分类精度都超过95%,但对于这个数据集而言,精度往往会下降5~10个百分点不等。因此,该数据集可以作为衡量模型的一个很好的评估标准。
引入torchvision,一个服务于pytorch框架的包:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;torchvision.utils: 其他的一些有用的方法。
获取数据集
第一次调用torchvision.datasets.FashionMNIST()函数会自动下载数据集。可以指定下载训练集或是测试集。
transform = transforms.ToTensor()的作用是使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor。
import torchvisionimport torchvision.transforms as transformsmnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True,transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True,transform=transforms.ToTensor())print(type(mnist_train))print(len(mnist_train), len(mnist_test))feature, label = mnist_train[0]print(feature, label)结果:<class 'torchvision.datasets.mnist.FashionMNIST'>60000 10000FashionMNIST是一个单独的类,用于存储mnist数据集训练集长度为60000,测试集长度为10000
访问样本:
- mnist列表中每个元素也是一个列表,包含一个特征值以及一个标签,即输出值。
- feature张量的大小是
,第一维是通道数,灰度图像通道数为1,H表高,W表宽。 ```python feature, label = mnist_train[0] print(feature, label)
结果: 因为使用了to_tensor()函数,因此feature为一个三维的张量,大小为1X28X28 label的值为9,代表图像类别编号
给数字labels打上具体中文标签,要得知哪个样本是哪类图像查询该list即可:<br />这里labels是一个数据集中label的子集,比如说前十个```python# 本函数已保存在d2lzh包中方便以后使用def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
在一行中画出多张图像和对应标签的函数:
# 本函数已保存在d2lzh包中方便以后使用def show_fashion_mnist(images, labels): # labels对应上个函数中输出的labelsd2l.use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()
例子:查看数据集中前10个样本的图像内容和文本标签:
X, y = [], []for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1])show_fashion_mnist(X, get_fashion_mnist_labels(y))
读取小批量
补充:
python获取操作系统类型:
import syssys.platform结果:'win32'
python判断字符串起始是否包含某子串:
# 若上述代码获取到的字符串以win开头,返回Trueif sys.platform.startwith("win"):
我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,
mnist_train是torch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。 在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。
batch_size = 256if sys.platform.startswith('win'):num_workers = 0 # 不使用额外进程加速,会略慢else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)# 开始遍历,计算总共需要多少时间start = time.time()for X,y in train_iter:continueprint("%.2f sec" % (time.time() - start))
可以看到,train_iter迭代器极为好用,作者将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_iter和test_iter两个变量。
完整代码如下:
import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport torch.utils.dataimport timeimport sysimport d2lzh as d2lmnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True,transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True,transform=transforms.ToTensor())print(type(mnist_train))print(len(mnist_train), len(mnist_test))feature, label = mnist_train[0]print(feature, label)# 本函数已保存在d2lzh包中方便以后使用def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 本函数已保存在d2lzh包中方便以后使用def show_fashion_mnist(images, labels):d2l.use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()if __name__ == "__main__":X, y = [], []for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1])show_fashion_mnist(X, get_fashion_mnist_labels(y))batch_size = 256if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)# 开始遍历start = time.time()for X, y in train_iter:continueprint("%.2f sec" % (time.time() - start))
