MNIST数据集是一组由美国高中生和人口调查局员工手写的70000个数字的图片。
    Scikit-Learn提供了许多助手功能来帮助你下载流行的数据集。MNIST也是其中之一
    下面是获取MNIST数据集的代码:

    1. >>> from sklearn.datasets import fetch_openml
    2. >>> mnist= fetch_openml('mnist_784',version=1)
    3. >>> mnist.keys()
    4. dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'DESCR', 'details', 'categories', 'url'])

    Scikit-Learn加载的数据集通常具有类似的字典结构,包括:

    • DESCR键,描述数据集
    • data键,包含一个数组,每个实例为一行,每个特征为一列。
    • target键,包含一个带有标记的数组。

    我们来看看这些数组:

    1. >>> X,y = mnist["data"],mnist["target"]
    2. >>> X.shape
    3. (70000, 784)
    4. >>>y.shape
    5. (70000,)

    共有7万张图片,每张图片有784个特征。因为图片是28x28像素,每个特征代表了一个像素点的强度,从0(白色)到255(黑色)。先来看看数据集中的一个数字,你只需要随手抓取一个实例的特征向量,将其重新形成一个28x28的数组,然后使用Matplotlib的imshow()函数将其显示出来:

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    some_digit = X[0]
    some_digit_image = some_digit.reshape(28,28)
    plt.imshow(some_digit_image,cmap="binary")
    plt.axis("off")
    plt.show()
    

    image.png
    图3-1:第一个数据
    看起来像5,而标签告诉我们没错:

    >>> y[0]
    '5'
    

    注意标签是字符,大部分机器学习算法希望是数字,让我们把y转换成整数:

    import numpy as np
    >>> y = y.astype(np.uint8)
    

    在开始深入研究这些数据之前,你还是应该先创建一个测试集,并将其放在一边。事实上,MNIST数据集已经分成训练集(前6万张图片)和测试集(最后1万张图片)了:

    X_train,X_test,y_train,y_test = X[:60000],X[60000:],y[:60000],y[60000:]