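Following the same idea as Section 25.2, we first treat each pixel as a feature and try the commonly used logistic regression model to see how well it classifies color images. We then experiment with the classic model for image classification: the convolutional neural network (CNN).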

26.2.1 Logistic Regression Model

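We start with logistic regression, treating each pixel as a feature. The TensorToVector component converts each image from tensor format into a vector, LogisticRegression is trained on these vectors, and the model is then evaluated on the test set.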

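  from pyalink.alink import *

  # PREDICTION_COL and PREDICTION_DETAIL_COL are column-name constants
  # assumed to be defined elsewhere (not shown in this section).

  def lr(train_set, test_set):
      Pipeline()\
          .add(
              # Convert the image tensor in column "tensor" into a vector;
              # LogisticRegression below reads it from the same column name
              TensorToVector()\
                  .setSelectedCol("tensor")\
                  .setReservedCols(["label"])
          )\
          .add(
              # Logistic regression on the flattened pixel vector
              LogisticRegression()\
                  .setVectorCol("tensor")\
                  .setLabelCol("label")\
                  .setPredictionCol(PREDICTION_COL)\
                  .setPredictionDetailCol(PREDICTION_DETAIL_COL)
          )\
          .fit(train_set)\
          .transform(test_set)\
          .link(
              # Evaluate the test-set predictions as a binary classification problem
              EvalBinaryClassBatchOp()\
                  .setLabelCol("label")\
                  .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
                  .lazyPrintMetrics()
          )
      BatchOperator.execute()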
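The pipeline above can be pictured with a small pure-Python sketch. It is not Alink code: `tensor_to_vector` and `predict_proba` are hypothetical helpers, and the (height, width, channel) flattening order is an assumption. It shows that a 32 x 32 color image becomes a 3072-dimensional vector, so every channel value of every pixel acts as one feature, and that logistic regression scores that vector with sigmoid(w · x + b).

```python
import math
import random


def tensor_to_vector(image):
    """Flatten an H x W x C nested list into a flat list in (height, width, channel) order."""
    return [value for row in image for pixel in row for value in pixel]


def sigmoid(z):
    """Numerically stable logistic function 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_proba(weights, bias, features):
    """Logistic regression score: P(label = positive | x) = sigmoid(w . x + b)."""
    z = sum(w * x for w, x in zip(weights, features)) + bias
    return sigmoid(z)


# A made-up 32 x 32 RGB image with random channel values in [0, 1)
random.seed(0)
image = [[[random.random() for _ in range(3)] for _ in range(32)] for _ in range(32)]
vector = tensor_to_vector(image)
print(len(vector))  # 3072 = 32 * 32 * 3

# Placeholder weights; in the Alink pipeline they are learned during fit()
weights = [0.0] * len(vector)
print(predict_proba(weights, 0.0, vector))  # 0.5, since w . x + b = 0
```

With 3072 independent features and no notion of neighboring pixels, a linear model like this has no built-in way to capture local shapes, which is one intuition for why the CNN below does better.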

The evaluation metrics of the LR model are shown below; its accuracy is 0.6164.

  -------------------------------- Metrics: --------------------------------
  Auc:0.6496    Accuracy:0.6164    Precision:0.6264    Recall:0.6022    F1:0.6141    LogLoss:0.6812
  |Pred\Real|dog|cat|
  |---------|---|---|
  |      dog|763|455|
  |      cat|504|778|
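The printed Accuracy, Precision, Recall, and F1 can be reproduced from the confusion matrix. The minimal sketch below assumes dog is the positive class, which is what makes the formulas match the printed values; `binary_metrics` is a hypothetical helper, not an Alink function.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 from the four cells of a 2 x 2 confusion matrix."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp)  # share of predicted dogs that really are dogs
    recall = tp / (tp + fn)     # share of real dogs that were predicted as dogs
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1


# Counts from the LR confusion matrix (rows = predicted, columns = actual):
#   predicted dog: 763 actual dogs (TP), 455 actual cats (FP)
#   predicted cat: 504 actual dogs (FN), 778 actual cats (TN)
acc, prec, rec, f1 = binary_metrics(tp=763, fp=455, fn=504, tn=778)
print(f"Accuracy:{acc:.4f} Precision:{prec:.4f} Recall:{rec:.4f} F1:{f1:.4f}")
# Accuracy:0.6164 Precision:0.6264 Recall:0.6022 F1:0.6141
```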

26.2.2 CNN Model

The structure of the CNN model, in Keras model-summary format, is as follows:

  _________________________________________________________________
  Layer (type)                 Output Shape              Param #
  =================================================================
  tensor (InputLayer)          [(None, 32, 32, 3)]       0
  _________________________________________________________________
  conv2d (Conv2D)              (None, 30, 30, 32)        896
  _________________________________________________________________
  max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0
  _________________________________________________________________
  conv2d_1 (Conv2D)            (None, 13, 13, 64)        18496
  _________________________________________________________________
  max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0
  _________________________________________________________________
  flatten (Flatten)            (None, 2304)              0
  _________________________________________________________________
  dropout (Dropout)            (None, 2304)              0
  _________________________________________________________________
  logits (Dense)               (None, 1)                 2305
  =================================================================
  Total params: 21,697
  Trainable params: 21,697
  Non-trainable params: 0
  _________________________________________________________________
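The output shapes and parameter counts in this summary follow from a few simple formulas. The sketch below assumes the Keras defaults for these layers (stride 1 and no padding for Conv2D, stride equal to the pool size for MaxPooling2D). A k x k convolution with c_in input channels and f filters has (k·k·c_in + 1)·f parameters, the +1 being one bias per filter. The helper names are hypothetical, and Dropout is left out because it changes neither the shape nor the parameter count.

```python
def conv2d_valid(shape, filters, k):
    """Conv2D with stride 1 and no padding: each spatial side shrinks by k - 1."""
    h, w, c_in = shape
    params = (k * k * c_in + 1) * filters  # k*k*c_in weights plus one bias, per filter
    return (h - k + 1, w - k + 1, filters), params


def max_pool(shape, p):
    """p x p max pooling with stride p: spatial sides are floor-divided by p; no parameters."""
    h, w, c = shape
    return (h // p, w // p, c), 0


def flatten(shape):
    """Flatten keeps every value, so the length is the product of the dimensions."""
    h, w, c = shape
    return (h * w * c,), 0


def dense(shape, units):
    """Fully connected layer: one weight per input per unit, plus one bias per unit."""
    (n_in,) = shape
    return (units,), (n_in + 1) * units


def summarize(input_shape, layers):
    """Apply each layer's shape rule in order and accumulate the parameter count."""
    shape, total, rows = input_shape, 0, []
    for name, rule, kwargs in layers:
        shape, params = rule(shape, **kwargs)
        total += params
        rows.append((name, shape, params))
    return rows, total


# Layers from the summary (Dropout omitted); "logits" is a single-unit Dense layer
LAYERS = [
    ("conv2d", conv2d_valid, {"filters": 32, "k": 3}),
    ("max_pooling2d", max_pool, {"p": 2}),
    ("conv2d_1", conv2d_valid, {"filters": 64, "k": 3}),
    ("max_pooling2d_1", max_pool, {"p": 2}),
    ("flatten", flatten, {}),
    ("logits", dense, {"units": 1}),
]

rows, total = summarize((32, 32, 3), LAYERS)
for name, shape, params in rows:
    print(f"{name:<17}{str(shape):<16}{params}")
print("Total params:", total)  # Total params: 21697
```

Note how the second pooling layer turns 13 into 6: floor division drops the last row and column, which is why Flatten yields 6 * 6 * 64 = 2304 values.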

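We train the model with KerasSequentialClassifierTrainBatchOp and save it to the file MODEL_CNN_FILE; if that file already exists, training is skipped. Judging by the parameter settings, training runs for 50 epochs, saves a checkpoint every 2 epochs, holds out 10% of the training data for validation, and keeps only the checkpoint with the best validation AUC. The code is as follows: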

  import os
  from pyalink.alink import *

  # DATA_DIR and MODEL_CNN_FILE are assumed to be defined elsewhere (not shown here).
  # Train only when no saved model file exists yet.
  if not os.path.exists(DATA_DIR + MODEL_CNN_FILE):
      train_set\
          .link(
              KerasSequentialClassifierTrainBatchOp()\
                  .setTensorCol("tensor")\
                  .setLabelCol("label")\
                  .setLayers([
                      # Hidden layers only: the input layer and the "logits" output
                      # layer shown in the model summary do not appear in this list
                      "Conv2D(32, kernel_size=(3, 3), activation='relu')",
                      "MaxPooling2D(pool_size=(2, 2))",
                      "Conv2D(64, kernel_size=(3, 3), activation='relu')",
                      "MaxPooling2D(pool_size=(2, 2))",
                      "Flatten()",
                      "Dropout(0.5)"
                  ])\
                  .setNumEpochs(50)\
                  .setSaveCheckpointsEpochs(2.0)\
                  .setValidationSplit(0.1)\
                  .setSaveBestOnly(True)\
                  .setBestMetric("auc")
          )\
          .link(
              # Save the trained model so later runs can skip training
              AkSinkBatchOp()\
                  .setFilePath(DATA_DIR + MODEL_CNN_FILE)
          )
      BatchOperator.execute()
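`setBestMetric("auc")` ranks checkpoints by AUC. As a rough illustration only, the sketch below computes AUC as the fraction of (positive, negative) pairs in which the positive example receives the higher score, counting ties as one half, and then keeps the checkpoint with the highest validation AUC. How Alink actually stores and compares checkpoints is not shown in this section; the checkpoint names, labels, and scores are made up, and the quadratic pairwise loop is only meant for small examples.

```python
def auc(labels, scores):
    """ROC AUC: probability that a random positive outscores a random negative (ties count 1/2)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def pick_best_checkpoint(val_labels, scores_by_checkpoint):
    """Return the name of the checkpoint whose validation scores give the highest AUC."""
    return max(scores_by_checkpoint, key=lambda name: auc(val_labels, scores_by_checkpoint[name]))


# Hypothetical validation labels (1 = dog, 0 = cat) and scores from two hypothetical checkpoints
val_labels = [1, 1, 0, 0, 1, 0]
checkpoints = {
    "epoch_2": [0.6, 0.4, 0.5, 0.3, 0.7, 0.6],
    "epoch_4": [0.9, 0.7, 0.2, 0.4, 0.8, 0.3],
}
print(round(auc(val_labels, checkpoints["epoch_2"]), 4))  # 0.7222
print(pick_best_checkpoint(val_labels, checkpoints))      # epoch_4
```

AUC only depends on how the scores rank the examples, not on a decision threshold, which makes it a reasonable criterion for choosing among checkpoints.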

Next, we load the trained model from the file, predict on the test set, and evaluate the results as a binary classification model.

  # Load the saved CNN model and use it to predict on the test set
  KerasSequentialClassifierPredictBatchOp()\
      .setPredictionCol(PREDICTION_COL)\
      .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
      .setReservedCols(["relative_path", "label"])\
      .linkFrom(
          AkSourceBatchOp().setFilePath(DATA_DIR + MODEL_CNN_FILE),  # the trained model
          test_set                                                   # the data to predict
      )\
      .lazyPrint(10)\
      .lazyPrintStatistics()\
      .link(
          # Binary classification evaluation based on the prediction details
          EvalBinaryClassBatchOp()\
              .setLabelCol("label")\
              .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
              .lazyPrintMetrics()
      )
  BatchOperator.execute()
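The LogLoss reported by `EvalBinaryClassBatchOp` is, by the usual definition, the mean of -log p over the test samples, where p is the probability the model assigns to the true label. The sketch below assumes each prediction-detail value is a JSON object mapping each label to its probability; that format, the clipping constant `eps`, and the helper names are assumptions for illustration, not Alink's implementation.

```python
import json
import math


def parse_detail(detail):
    """Parse a detail string into {label: probability}; values may be numbers or numeric strings."""
    return {k: float(v) for k, v in json.loads(detail).items()}


def predicted_label(detail):
    """The predicted label is the one with the largest probability."""
    probs = parse_detail(detail)
    return max(probs, key=probs.get)


def log_loss(true_labels, detail_strings, eps=1e-15):
    """Mean negative log-probability assigned to the true label (clipped to avoid log(0))."""
    total = 0.0
    for label, detail in zip(true_labels, detail_strings):
        p = min(max(parse_detail(detail)[label], eps), 1.0 - eps)
        total -= math.log(p)
    return total / len(true_labels)


# Hypothetical detail strings for two test images that are both dogs
details = ['{"dog": 0.9, "cat": 0.1}', '{"dog": 0.3, "cat": 0.7}']
labels = ["dog", "dog"]
print([predicted_label(d) for d in details])  # ['dog', 'cat']
print(round(log_loss(labels, details), 4))    # 0.6547
```

Unlike accuracy, log loss also penalizes confident mistakes: the second image alone contributes -log(0.3), about 1.204, while the correct, confident first prediction contributes only about 0.105.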

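The evaluation results are shown below. The CNN is clearly better than the logistic regression model: accuracy rises from 0.6164 to 0.8672 and AUC from 0.6496 to 0.951. To keep the training time of this experiment short, the number of training epochs is set to 50; readers who want a better model can adjust the training parameters. In addition, the next section introduces how to use a pretrained model, which can deliver better results in less time.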

  Summary:
  |      colName|count|missing|sum|mean|variance|min|max|
  |-------------|-----|-------|---|----|--------|---|---|
  |relative_path| 2500|      0|NaN| NaN|     NaN|NaN|NaN|
  |        label| 2500|      0|NaN| NaN|     NaN|NaN|NaN|
  |         pred| 2500|      0|NaN| NaN|     NaN|NaN|NaN|
  |    pred_info| 2500|      0|NaN| NaN|     NaN|NaN|NaN|
  -------------------------------- Metrics: --------------------------------
  Auc:0.951    Accuracy:0.8672    Precision:0.9057    Recall:0.812    F1:0.8563    LogLoss:0.3023
  |Pred\Real|dog| cat|
  |---------|---|----|
  |      dog|989| 103|
  |      cat|229|1179|
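To put the improvement in numbers, the sketch below derives the error rate of each model from its confusion matrix (dog taken as the positive class) and the relative reduction in error. The counts are copied from the two evaluation outputs in this section; the helper is hypothetical.

```python
def accuracy(tp, fp, fn, tn):
    """Fraction of test images whose predicted label equals the real label."""
    return (tp + tn) / (tp + fp + fn + tn)


# (tp, fp, fn, tn) read off the two confusion matrices, dog as the positive class
lr_counts = (763, 455, 504, 778)
cnn_counts = (989, 103, 229, 1179)

lr_err = 1 - accuracy(*lr_counts)
cnn_err = 1 - accuracy(*cnn_counts)
print(f"LR error rate:  {lr_err:.4f}")                                  # 0.3836
print(f"CNN error rate: {cnn_err:.4f}")                                 # 0.1328
print(f"Relative error reduction: {(lr_err - cnn_err) / lr_err:.1%}")   # 65.4%
```

In other words, the CNN misclassifies 332 of the 2500 test images versus 959 for logistic regression, removing roughly two thirds of the errors.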