MNIST数据集是一组由美国高中生和人口调查局员工手写的70000个数字的图片。
Scikit-Learn提供了许多助手功能来帮助你下载流行的数据集。MNIST也是其中之一
下面是获取MNIST数据集的代码:
>>> from sklearn.datasets import fetch_openml>>> mnist= fetch_openml('mnist_784',version=1)>>> mnist.keys()dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'DESCR', 'details', 'categories', 'url'])
Scikit-Learn加载的数据集通常具有类似的字典结构,包括:
- DESCR键,描述数据集
- data键,包含一个数组,每个实例为一行,每个特征为一列。
- target键,包含一个带有标记的数组。
我们来看看这些数组:
>>> X,y = mnist["data"],mnist["target"]>>> X.shape(70000, 784)>>>y.shape(70000,)
共有7万张图片,每张图片有784个特征。因为图片是28x28像素,每个特征代表了一个像素点的强度,从0(白色)到255(黑色)。先来看看数据集中的一个数字,你只需要随手抓取一个实例的特征向量,将其重新形成一个28x28的数组,然后使用Matplotlib的imshow()函数将其显示出来:
import matplotlib as mpl
import matplotlib.pyplot as plt
some_digit = X[0]
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()

图3-1:第一个数据
看起来像5,而标签告诉我们没错:
>>> y[0]
'5'
注意标签是字符,大部分机器学习算法希望是数字,让我们把y转换成整数:
import numpy as np
>>> y = y.astype(np.uint8)
在开始深入研究这些数据之前,你还是应该先创建一个测试集,并将其放在一边。事实上,MNIST数据集已经分成训练集(前6万张图片)和测试集(最后1万张图片)了:
X_train,X_test,y_train,y_test = X[:60000],X[60000:],y[:60000],y[60000:]
