数据读取

MNIST数据本身是图像格式的,我们用mode="vector"去读取,将其转变成矢量格式。

  1. def LoadData():
  2. print("reading data...")
  3. dr = MnistImageDataReader(mode="vector")
  4. ......

搭建模型

一共4个隐层,都用ReLU激活函数连接,最后的输出层接Softmax分类函数。

多分类任务 - MNIST手写体识别 - 图1

以下是主要的参数设置:

  1. if __name__ == '__main__':
  2. dataReader = LoadData()
  3. num_feature = dataReader.num_feature
  4. num_example = dataReader.num_example
  5. num_input = num_feature
  6. num_hidden1 = 128
  7. num_hidden2 = 64
  8. num_hidden3 = 32
  9. num_hidden4 = 16
  10. num_output = 10
  11. max_epoch = 10
  12. batch_size = 64
  13. learning_rate = 0.1
  14. params = HyperParameters_4_0(
  15. learning_rate, max_epoch, batch_size,
  16. net_type=NetType.MultipleClassifier,
  17. init_method=InitialMethod.MSRA,
  18. stopper=Stopper(StopCondition.StopLoss, 0.12))
  19. net = NeuralNet_4_0(params, "MNIST")
  20. fc1 = FcLayer_1_0(num_input, num_hidden1, params)
  21. net.add_layer(fc1, "fc1")
  22. r1 = ActivationLayer(Relu())
  23. net.add_layer(r1, "r1")
  24. ......
  25. fc5 = FcLayer_1_0(num_hidden4, num_output, params)
  26. net.add_layer(fc5, "fc5")
  27. softmax = ClassificationLayer(Softmax())
  28. net.add_layer(softmax, "softmax")
  29. net.train(dataReader, checkpoint=0.05, need_test=True)
  30. net.ShowLossHistory(xcoord=XCoordinate.Iteration)

运行结果

我们设计的停止条件是绝对Loss值达到0.12,所以迭代到6个epoch、验证集损失值达到0.119时,就停止训练了。

图14-19 训练过程中损失函数值和准确率的变化

图14-19是训练过程图示,下面是最后几行的打印输出。

  1. ......
  2. epoch=6, total_iteration=5763
  3. loss_train=0.005559, accuracy_train=1.000000
  4. loss_valid=0.119701, accuracy_valid=0.971667
  5. time used: 17.500738859176636
  6. save parameters
  7. testing...
  8. 0.9697

最后用测试集得到的准确率为96.97%。

代码位置

原代码位置:ch14, Level6

个人代码:**MnistClassifier**

Keras实现

  1. from ExtendedDataReader.MnistImageDataReader import *
  2. from keras.models import Sequential
  3. from keras.layers import Dense
  4. import matplotlib.pyplot as plt
  5. import os
  6. os.environ['KMP_DUPLICATE_LIB_OK']='True'
  def load_data():
      """Load MNIST through the project's data reader and return flattened splits.

      Creates a ``MnistImageDataReader`` in "vector" mode, then calls
      ``ReadData``, ``NormalizeX``, ``NormalizeY(NetType.MultipleClassifier)``
      and ``GenerateValidationSet(k=20)`` on it. Each feature array is
      reshaped to ``(num_samples, 28 * 28)`` before being returned.

      Returns:
          A tuple ``(x_train, y_train, x_test, y_test, x_val, y_val)``.
      """
      reader = MnistImageDataReader(mode="vector")
      reader.ReadData()
      reader.NormalizeX()
      reader.NormalizeY(NetType.MultipleClassifier)
      reader.GenerateValidationSet(k=20)

      flat_size = 28 * 28

      def flatten(images):
          # One row per sample, one column per pixel.
          return images.reshape(images.shape[0], flat_size)

      return (
          flatten(reader.XTrain), reader.YTrain,
          flatten(reader.XTest), reader.YTest,
          flatten(reader.XDev), reader.YDev,
      )
  def build_model():
      """Build and compile the fully connected classifier.

      The network takes 784-dimensional inputs, has four ReLU hidden layers
      of 128, 64, 32 and 16 units, and ends with a 10-unit softmax layer.
      It is compiled with the Adam optimizer, categorical cross-entropy
      loss and the accuracy metric.

      Returns:
          The compiled ``Sequential`` model.
      """
      first_hidden, *other_hidden = (128, 64, 32, 16)

      network = Sequential()
      # Only the first layer needs to declare the input shape.
      network.add(Dense(first_hidden, activation='relu', input_shape=(784, )))
      for units in other_hidden:
          network.add(Dense(units, activation='relu'))
      network.add(Dense(10, activation='softmax'))

      network.compile(
          optimizer='Adam',
          loss='categorical_crossentropy',
          metrics=['accuracy'],
      )
      return network
  def draw_train_history(history):
      """Plot training and validation accuracy and loss over the epochs.

      Draws two stacked subplots on figure 1: accuracy on top, loss below.
      Each subplot shows the training curve and the validation curve, read
      from ``history.history``.

      Args:
          history: Object whose ``history`` mapping holds the 'accuracy',
              'val_accuracy', 'loss' and 'val_loss' series.
      """
      curves = history.history
      panels = (
          (211, 'accuracy', 'val_accuracy', 'model accuracy'),
          (212, 'loss', 'val_loss', 'model loss'),
      )

      plt.figure(1)
      for position, train_key, val_key, heading in panels:
          plt.subplot(position)
          plt.plot(curves[train_key])
          plt.plot(curves[val_key])
          plt.title(heading)
          plt.ylabel(train_key)
          plt.xlabel('epoch')
          plt.legend(['train', 'validation'])
      plt.show()
  if __name__ == '__main__':
      # Load the data, train for 20 epochs, plot the curves, then report
      # the test-set metrics and the learned weights.
      x_train, y_train, x_test, y_test, x_val, y_val = load_data()
      # print(x_train.shape)
      # print(x_test.shape)
      # print(x_val.shape)

      model = build_model()
      history = model.fit(
          x_train,
          y_train,
          epochs=20,
          batch_size=64,
          validation_data=(x_val, y_val),
      )
      draw_train_history(history)

      test_loss, test_accuracy = model.evaluate(x_test, y_test)
      print("test loss: {}, test accuracy: {}".format(test_loss, test_accuracy))

      print("weights: ", model.get_weights())

模型输出

  1. test loss: 0.11646892445675121, test accuracy: 0.9768999814987183

模型损失以及准确率曲线

多分类任务 - MNIST手写体识别 - 图2