在Alink教材第16章,以葡萄酒的品质预测为例,演示了线性模型、随机森林、GBDT等算法的回归训练及预测。本节仍以葡萄酒的品质预测为例,重点演示如何使用深度学习进行回归训练及预测。

25.3.1 线性回归算法

首先,我们先使用线性回归算法,得到一个baseline,具体代码如下:

  1. new LinearRegression()
  2. .setFeatureCols(Chap16.FEATURE_COL_NAMES)
  3. .setLabelCol("quality")
  4. .setPredictionCol("pred")
  5. .enableLazyPrintModelInfo()
  6. .fit(train_set)
  7. .transform(test_set)
  8. .lazyPrintStatistics()
  9. .link(
  10. new EvalRegressionBatchOp()
  11. .setLabelCol("quality")
  12. .setPredictionCol("pred")
  13. .lazyPrintMetrics()
  14. );
  15. BatchOperator.execute();

其中使用了enableLazyPrintModelInfo方法,会打印出模型的信息,如下所示,“intercept”对应的是线性回归的常数项,每个特征列对应一个线性回归的参数。

  1. ----------------------------- model meta info -----------------------------
  2. {hasInterception: true, model name: Linear Regression, num feature: 11}
  3. ---------------------------- model weight info ----------------------------
  4. | colName[0,9]| intercept|fixedAcidity|volatileAcidity| citricAcid|residualSugar| chlorides|freeSulfurDioxide|totalSulfurDioxide| density| pH|
  5. | weight[0,9]| 147.2227| 0.05561785| -1.88292677|-0.02573757| 0.08013539|-0.32426272| 0.00370713| -0.00037885|-147.10297328|0.64837239|
  6. |colName[10,11]| sulphates| alcohol| | | | | | | | |
  7. | weight[10,11]|0.64099725| 0.19730515|

线性回归预测结果的统计如下,可以看到quality列的均值为5.8735,方差为0.7235;预测结果的均值为5.8767,方差为0.2083。

  1. Summary:
  2. | colName|count|missing| sum| mean| variance| min| max|
  3. |------------------|-----|-------|----------|--------|---------|------|------|
  4. | fixedAcidity| 980| 0| 6722.4| 6.8596| 0.73| 4.7| 10|
  5. | volatileAcidity| 980| 0| 271.38| 0.2769| 0.0099| 0.08| 1.1|
  6. | citricAcid| 980| 0| 327.46| 0.3341| 0.014| 0| 1|
  7. | residualSugar| 980| 0| 6280.75| 6.4089| 26.1442| 0.7| 26.05|
  8. | chlorides| 980| 0| 44.462| 0.0454| 0.0004| 0.012| 0.239|
  9. | freeSulfurDioxide| 980| 0| 34358.5| 35.0597| 286.5621| 2| 138.5|
  10. |totalSulfurDioxide| 980| 0| 135518.5|138.2842|1948.7059| 9| 303|
  11. | density| 980| 0| 974.1799| 0.9941| 0|0.9872| 1.003|
  12. | pH| 980| 0| 3125.47| 3.1893| 0.0219| 2.85| 3.82|
  13. | sulphates| 980| 0| 483.09| 0.4929| 0.0123| 0.27| 0.98|
  14. | alcohol| 980| 0|10289.4933| 10.4995| 1.5211| 8.4| 14.05|
  15. | quality| 980| 0| 5756| 5.8735| 0.7235| 3| 9|
  16. | pred| 980| 0| 5759.1695| 5.8767| 0.2083|4.1547|7.2093|

使用EvalRegressionBatchOp组件,计算显示回归统计指标如下:

  1. -------------------------------- Metrics: --------------------------------
  2. MSE:0.5309 RMSE:0.7286 MAE:0.5748 MAPE:10.0995 R2:0.2655

25.3.2 深度回归算法

使用深度回归模型的代码如下:

  1. new Pipeline()
  2. .add(
  3. new StandardScaler()
  4. .setSelectedCols(Chap16.FEATURE_COL_NAMES)
  5. )
  6. .add(
  7. new VectorAssembler()
  8. .setSelectedCols(Chap16.FEATURE_COL_NAMES)
  9. .setOutputCol("vec")
  10. )
  11. .add(
  12. new VectorToTensor()
  13. .setSelectedCol("vec")
  14. .setOutputCol("tensor")
  15. .setReservedCols("quality")
  16. )
  17. .add(
  18. new KerasSequentialRegressor()
  19. .setTensorCol("tensor")
  20. .setLabelCol("quality")
  21. .setPredictionCol("pred")
  22. .setLayers(
  23. "Dense(64, activation='relu')",
  24. "Dense(64, activation='relu')",
  25. "Dense(64, activation='relu')",
  26. "Dense(64, activation='relu')",
  27. "Dense(64, activation='relu')"
  28. )
  29. .setNumEpochs(20)
  30. .setNumWorkers(1)
  31. .setNumPSs(0)
  32. )
  33. .fit(train_set)
  34. .transform(test_set)
  35. .lazyPrintStatistics()
  36. .link(
  37. new EvalRegressionBatchOp()
  38. .setLabelCol("quality")
  39. .setPredictionCol("pred")
  40. .lazyPrintMetrics()
  41. );
  42. BatchOperator.execute();

在Pipeline中使用了多个组件:
1、数据标准化组件StandardScaler,因为数据中各列的数值范围差异较大,标准化后,有助于提升深度模型的效果
2、拼接向量组件VectorAssembler,将多列数值数据转化为一列向量数据。
3、向量转化为张量的组件VectorToTensor,后面的Keras组件的输入格式为张量
4、Keras回归器组件KerasSequentialRegressor,定义了深度模型

Pipeline使用fit方法对训练集train_set进行训练,然后使用transform方法对测试集test_set进行预测。预测结果的统计如下

  1. Summary:
  2. |colName|count|missing| sum| mean|variance| min| max|
  3. |-------|-----|-------|---------|------|--------|------|------|
  4. |quality| 980| 0| 5756|5.8735| 0.7235| 3| 9|
  5. | tensor| 980| 0| NaN| NaN| NaN| NaN| NaN|
  6. | pred| 980| 0|5766.3137| 5.884| 0.3881|3.8637|7.5729|

相应的评估指标如下,可以看到均方误差MSE和平均绝对误差MAE等指标,相对线性回归算法有明显改进。

  1. -------------------------------- Metrics: --------------------------------
  2. MSE:0.485 RMSE:0.6964 MAE:0.5323 MAPE:9.3756 R2:0.3289