Java class name: com.alibaba.alink.operator.batch.classification.RandomForestPredictBatchOp
Python class name: RandomForestPredictBatchOp

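Description

Random forest is a classic supervised, nonlinear learning model built from decision trees. It can solve classification, regression, and other problems that decision tree models can handle, and it usually achieves better results than a single decision tree.

How it works

Random forest combines multiple decision trees with the Bagging method to produce the final model.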
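The sketch below is a toy Python illustration of the general Bagging idea. Every base learner is trained on its own bootstrap sample (rows drawn with replacement), and the learners are combined by voting.

All function names are hypothetical, and several simplifications apply:

- The base learner is a single-split "stump" on one numeric feature, whereas Alink grows full decision trees.
- The per-split feature sampling that random forests usually add is omitted.
- Vote counting is an assumed aggregation rule, not necessarily how Alink combines tree outputs.

```python
import random
from collections import Counter


def bootstrap_sample(rows, rng):
    # Bagging: draw len(rows) rows uniformly at random, with replacement.
    return [rows[rng.randrange(len(rows))] for _ in rows]


def fit_stump(rows):
    # Toy base learner (a one-split "tree") on a single numeric feature:
    # try each observed value t as the split "x <= t" and keep the one with the fewest errors.
    best = None
    for t in sorted({x for x, _ in rows}):
        left = [y for x, y in rows if x <= t]
        right = [y for x, y in rows if x > t]
        left_label = Counter(left).most_common(1)[0][0]
        right_label = Counter(right).most_common(1)[0][0] if right else left_label
        errors = sum(y != left_label for y in left) + sum(y != right_label for y in right)
        if best is None or errors < best[0]:
            best = (errors, t, left_label, right_label)
    _, t, left_label, right_label = best
    return lambda x: left_label if x <= t else right_label


def bagging_fit(rows, n_trees, seed=0):
    # Each base learner is trained on its own bootstrap sample.
    rng = random.Random(seed)
    return [fit_stump(bootstrap_sample(rows, rng)) for _ in range(n_trees)]


def bagging_predict(models, x):
    # Combine the learners by voting: the vote shares form a per-label
    # probability map, and the prediction is the label with the most votes.
    votes = Counter(m(x) for m in models)
    detail = {label: count / len(models) for label, count in votes.items()}
    return max(detail, key=detail.get), detail


rows = [(1, 0), (2, 0), (3, 0), (4, 0), (5, 1), (6, 1), (7, 1), (8, 1)]
models = bagging_fit(rows, n_trees=25)
print(bagging_predict(models, 2))
print(bagging_predict(models, 7))
```

Because each learner sees a different resample, the individual stumps disagree near the class boundary. The vote shares then reflect that uncertainty, which is the intuition behind the per-label probabilities that appear in `prediction_detail`.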

Usage

Using the Adult dataset as an example, this section walks through the steps of using random forest.

Dataset

Adult

Training set

The basic statistics of the training set are:
Adult train
Summary:

| colName | count | missing | sum | mean | variance | min | max |
| --- | --- | --- | --- | --- | --- | --- | --- |
| age | 32560 | 0 | 1256214 | 38.5815 | 186.0665 | 17 | 90 |
| workclass | 32560 | 1836 | NaN | NaN | NaN | NaN | NaN |
| fnlwgt | 32560 | 0 | 6179243539 | 189780.207 | 11141029667.4508 | 12285 | 1484705 |
| education | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
| education_num | 32560 | 0 | 328231 | 10.0808 | 6.6186 | 1 | 16 |
| marital_status | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
| occupation | 32560 | 1843 | NaN | NaN | NaN | NaN | NaN |
| relationship | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
| race | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
| sex | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |
| capital_gain | 32560 | 0 | 35089324 | 1077.6819 | 54544178.6998 | 0 | 99999 |
| capital_loss | 32560 | 0 | 2842700 | 87.3065 | 162381.6909 | 0 | 4356 |
| hours_per_week | 32560 | 0 | 1316644 | 40.4375 | 152.4637 | 1 | 99 |
| native_country | 32560 | 583 | NaN | NaN | NaN | NaN | NaN |
| label | 32560 | 0 | NaN | NaN | NaN | NaN | NaN |

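The data can be read as follows:

```java
// Assumed schema: column names follow the statistics table above; the types are
// inferred from which columns have numeric statistics.
String schemaStr = "age bigint, workclass string, fnlwgt bigint, education string, "
    + "education_num bigint, marital_status string, occupation string, relationship string, "
    + "race string, sex string, capital_gain bigint, capital_loss bigint, "
    + "hours_per_week bigint, native_country string, label string";
CsvSourceBatchOp trainData = new CsvSourceBatchOp()
    .setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_train.csv")
    .setIgnoreFirstLine(true)          // skip the header line of the CSV file
    .setSchemaStr(schemaStr)
    .lazyPrintStatistics("Adult train");
```

In the code above, calling `lazyPrintStatistics("Adult train")` is all that is needed to print the statistics of the data.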

Test set

The basic statistics of the test set are:
Adult test
Summary:

| colName | count | missing | sum | mean | variance | min | max |
| --- | --- | --- | --- | --- | --- | --- | --- |
| age | 16280 | 0 | 631146 | 38.7682 | 191.8033 | 17 | 90 |
| workclass | 16280 | 963 | NaN | NaN | NaN | NaN | NaN |
| fnlwgt | 16280 | 0 | 3083900756 | 189428.7934 | 11175556521.7039 | 13492 | 1490400 |
| education | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
| education_num | 16280 | 0 | 163987 | 10.0729 | 6.5927 | 1 | 16 |
| marital_status | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
| occupation | 16280 | 966 | NaN | NaN | NaN | NaN | NaN |
| relationship | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
| race | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
| sex | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |
| capital_gain | 16280 | 0 | 17614497 | 1081.9716 | 57519546.0031 | 0 | 99999 |
| capital_loss | 16280 | 0 | 1431088 | 87.9047 | 162503.3785 | 0 | 3770 |
| hours_per_week | 16280 | 0 | 657586 | 40.3923 | 155.7433 | 1 | 99 |
| native_country | 16280 | 274 | NaN | NaN | NaN | NaN | NaN |
| label | 16280 | 0 | NaN | NaN | NaN | NaN | NaN |

The data can be read as follows:

```java
// schemaStr: the same schema string used to read the training data
CsvSourceBatchOp testData = new CsvSourceBatchOp()
    .setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/adult_test.csv")
    .setIgnoreFirstLine(true)          // skip the header line of the CSV file
    .setSchemaStr(schemaStr)
    .lazyPrintStatistics("Adult test");
```

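Training

Train the model with RandomForestTrainBatchOp. It supports several commonly used decision tree pruning parameters, and tuning them may produce a better model. See the parameter documentation of RandomForestTrainBatchOp for details.

```java
// ArrayUtils.addAll comes from Apache Commons Lang (e.g. org.apache.commons.lang3.ArrayUtils).
String[] numericalFeatureColNames = new String[] {"age", "fnlwgt", "education_num", "capital_gain",
    "capital_loss", "hours_per_week"};
String[] categoryFeatureColNames = new String[] {"workclass", "education", "marital_status", "occupation",
    "relationship", "race", "sex", "native_country"};
RandomForestTrainBatchOp randomForestBatchOp = new RandomForestTrainBatchOp()
    // all 14 columns are used as features ...
    .setFeatureCols(ArrayUtils.addAll(numericalFeatureColNames, categoryFeatureColNames))
    // ... of which these 8 are categorical
    .setCategoricalCols(categoryFeatureColNames)
    .setSubsamplingRatio(0.6)          // sampling ratio of the training rows for each tree
    .setMaxLeaves(32)                  // maximum number of leaves per tree
    .setLabelCol("label");
```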

Prediction

```java
RandomForestPredictBatchOp prediction = new RandomForestPredictBatchOp()
    .setPredictionCol("prediction")                // predicted label
    .setPredictionDetailCol("prediction_detail");  // per-label probabilities
```
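The prediction column holds the predicted label, and prediction_detail holds a per-label probability map, as in the example output at the end of this page. The following Python sketch shows how a downstream consumer could recover the top label and its probability from a prediction_detail value. It assumes the value is a JSON object keyed by label strings; the helper `parse_prediction_detail` is hypothetical.

```python
import json


def parse_prediction_detail(detail_json):
    """Return (label, probability) of the most likely label in a prediction_detail value.

    Assumes the value is a JSON object mapping label strings to probabilities,
    e.g. {"0":1.0,"1":0.0}.
    """
    probabilities = json.loads(detail_json)
    label = max(probabilities, key=probabilities.get)
    return label, probabilities[label]


print(parse_prediction_detail('{"0":0.0,"1":1.0}'))  # ('1', 1.0)
```

The label comes back as a string because JSON object keys are always strings.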

Evaluation

```java
EvalBinaryClassBatchOp eval = new EvalBinaryClassBatchOp()
    .setLabelCol("label")                          // ground-truth column of the test data
    .setPredictionDetailCol("prediction_detail");  // probabilities written by the predictor
```

Building the training and prediction pipeline

```java
prediction
    .linkFrom(
        randomForestBatchOp
            .linkFrom(trainData)
            // print the model summary when the job runs
            .lazyPrintModelInfo("Adult random forest model")
            // collect the model info and save tree #0 as an image
            .lazyCollectModelInfo(new Consumer<RandomForestModelInfo>() {
                @Override
                public void accept(RandomForestModelInfo randomForestModelInfo) {
                    try {
                        randomForestModelInfo
                            .saveTreeAsImage("/tmp/rf_adult_model.png", 0, true);
                    } catch (IOException e) {
                        throw new IllegalStateException(e);
                    }
                }
            }),
        testData
    )
    .link(eval)
    .lazyPrintMetrics("Adult random forest evaluation");
```

Execution

```java
BatchOperator.execute();
```

Results

Model information

Adult random forest model
Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]
Categorical feature info:

| feature | number of categorical value |
| --- | --- |
| workclass | 8 |
| education | 16 |
| marital_status | 7 |
| … | … |
| race | 5 |
| sex | 2 |
| native_country | 41 |

Table of feature importance Top 14:

| feature | importance |
| --- | --- |
| age | 0.1997 |
| fnlwgt | 0.1992 |
| capital_gain | 0.1447 |
| hours_per_week | 0.1091 |
| education_num | 0.0889 |
| occupation | 0.0553 |
| relationship | 0.0423 |
| capital_loss | 0.0336 |
| workclass | 0.0306 |
| sex | 0.0299 |
| race | 0.0188 |
| marital_status | 0.0176 |
| native_country | 0.0158 |
| education | 0.0144 |

Classification trees modelInfo:
Number of trees: 10
Number of features: 14
Number of categorical features: 8
Labels: [<=50K, >50K]
Categorical feature info:

| feature | number of categorical value |
| --- | --- |
| workclass | 8 |
| education | 16 |
| marital_status | 7 |
| … | … |
| race | 5 |
| sex | 2 |
| native_country | 41 |

Table of feature importance Top 14:

| feature | importance |
| --- | --- |
| fnlwgt | 0.2318 |
| age | 0.2286 |
| hours_per_week | 0.1382 |
| education_num | 0.0706 |
| occupation | 0.0645 |
| capital_gain | 0.0568 |
| workclass | 0.0516 |
| sex | 0.033 |
| relationship | 0.0299 |
| capital_loss | 0.0222 |
| education | 0.0218 |
| native_country | 0.0199 |
| race | 0.0175 |
| marital_status | 0.0136 |

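The model information includes basic information about the training input data, the features, and the model.
Statistics of the categorical features are listed under Categorical feature info.
Feature importance is a commonly used metric for selecting features; it is listed under Table of feature importance Top 14.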
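Feature importance is often used to shortlist features. The following Python sketch keeps the highest-ranked features until their importances add up to a given share of the total. It uses the values from the first importance table above.

The helper `select_by_cumulative_importance` and the 80% cut-off are hypothetical choices for illustration, not an Alink API.

```python
def select_by_cumulative_importance(importance, ratio):
    """Hypothetical helper: return the most important features whose importances
    together reach `ratio` of the total importance."""
    ranked = sorted(importance.items(), key=lambda kv: kv[1], reverse=True)
    total = sum(importance.values())
    selected, cumulative = [], 0.0
    for name, value in ranked:
        if cumulative >= ratio * total:
            break
        selected.append(name)
        cumulative += value
    return selected


# Values copied from the first "Table of feature importance Top 14" above.
importance = {
    "age": 0.1997, "fnlwgt": 0.1992, "capital_gain": 0.1447,
    "hours_per_week": 0.1091, "education_num": 0.0889, "occupation": 0.0553,
    "relationship": 0.0423, "capital_loss": 0.0336, "workclass": 0.0306,
    "sex": 0.0299, "race": 0.0188, "marital_status": 0.0176,
    "native_country": 0.0158, "education": 0.0144,
}

print(select_by_cumulative_importance(importance, 0.8))
# ['age', 'fnlwgt', 'capital_gain', 'hours_per_week', 'education_num', 'occupation', 'relationship']
```

The two importance tables above rank the features differently (for example, capital_gain falls from third to sixth), so a selection like this is best repeated across runs before features are dropped.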

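Model visualization

We also visualize tree #0 of the random forest. After lazyCollectModelInfo in the code collects the model information, the saveTreeAsImage method of the model info writes an image of the tree to the specified path.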

```java
.lazyCollectModelInfo(new Consumer<RandomForestModelInfo>() {
    @Override
    public void accept(RandomForestModelInfo randomForestModelInfo) {
        try {
            // save tree #0 to /tmp/rf_adult_model.png (the boolean is presumably the overwrite flag)
            randomForestModelInfo
                .saveTreeAsImage("/tmp/rf_adult_model.png", 0, true);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }
})
```

(Image: rf_model_img, visualization of tree #0 of the random forest)

Evaluation results

Adult random forest evaluation
-------------------- Metrics: --------------------
Auc:1 Accuracy:0.9995 Precision:0.9965 Recall:1 F1:0.9982 LogLoss:0.2584

| Pred\Real | >50K | <=50K |
| --- | --- | --- |
| >50K | 2273 | 8 |
| <=50K | 0 | 13999 |

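The evaluation results include commonly used binary classification metrics (AUC, accuracy, precision, recall, F1, and log loss) as well as the confusion matrix. The near-perfect scores here are consistent with the evaluator having been configured with `setLabelCol("prediction")`, which compares the predictions with themselves rather than with the true `label` column.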
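The Accuracy, Precision, Recall, and F1 values above follow from the confusion matrix when >50K is treated as the positive class. The short Python check below recomputes them with the standard definitions; the function name `binary_metrics` is hypothetical. AUC and LogLoss depend on the per-row probabilities, so they cannot be derived from the matrix alone.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 from a 2x2 confusion matrix."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return {"Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1": f1}


# Positive class: >50K. Rows of the matrix are predictions, columns are real labels.
tp, fp = 2273, 8      # predicted >50K: really >50K / really <=50K
fn, tn = 0, 13999     # predicted <=50K: really >50K / really <=50K

for name, value in binary_metrics(tp, fp, fn, tn).items():
    print(f"{name}: {value:.4f}")
# Accuracy: 0.9995
# Precision: 0.9965
# Recall: 1.0000
# F1: 0.9982
```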


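Parameters

| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | Prediction column name | Name of the prediction result column | String | ✓ | | |
| modelFilePath | Model file path | File path of the model | String | | | null |
| predictionDetailCol | Prediction detail column name | Name of the prediction detail column | String | | | |
| reservedCols | Reserved columns | Columns retained by the algorithm | String[] | | | null |
| numThreads | Number of threads | Number of threads used by the component | Integer | | | 1 |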

Code examples

Python code

```python
from pyalink.alink import *
import pandas as pd

useLocalEnv(1)

df = pd.DataFrame([
    [1.0, "A", 0, 0, 0],
    [2.0, "B", 1, 1, 0],
    [3.0, "C", 2, 2, 1],
    [4.0, "D", 3, 3, 1]
])
batchSource = BatchOperator.fromDataframe(
    df, schemaStr=' f0 double, f1 string, f2 int, f3 int, label int')
streamSource = StreamOperator.fromDataframe(
    df, schemaStr=' f0 double, f1 string, f2 int, f3 int, label int')

# Train a random forest model on the batch data.
trainOp = RandomForestTrainBatchOp()\
    .setLabelCol('label')\
    .setFeatureCols(['f0', 'f1', 'f2', 'f3'])\
    .linkFrom(batchSource)

# Batch prediction: linked to the model and the data below.
predictBatchOp = RandomForestPredictBatchOp()\
    .setPredictionDetailCol('pred_detail')\
    .setPredictionCol('pred')
# Stream prediction: the trained model is passed to the constructor.
predictStreamOp = RandomForestPredictStreamOp(trainOp)\
    .setPredictionDetailCol('pred_detail')\
    .setPredictionCol('pred')

predictBatchOp.linkFrom(trainOp, batchSource).print()
predictStreamOp.linkFrom(streamSource).print()
StreamOperator.execute()
```

Java code

```java
import org.apache.flink.types.Row;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.RandomForestPredictBatchOp;
import com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.classification.RandomForestPredictStreamOp;
import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
import org.junit.Test;
import java.util.Arrays;
import java.util.List;

public class RandomForestPredictBatchOpTest {
    @Test
    public void testRandomForestPredictBatchOp() throws Exception {
        List<Row> df = Arrays.asList(
            Row.of(1.0, "A", 0, 0, 0),
            Row.of(2.0, "B", 1, 1, 0),
            Row.of(3.0, "C", 2, 2, 1),
            Row.of(4.0, "D", 3, 3, 1)
        );
        BatchOperator<?> batchSource = new MemSourceBatchOp(
            df, " f0 double, f1 string, f2 int, f3 int, label int");
        StreamOperator<?> streamSource = new MemSourceStreamOp(
            df, " f0 double, f1 string, f2 int, f3 int, label int");
        // Train a random forest model on the batch data.
        BatchOperator<?> trainOp = new RandomForestTrainBatchOp()
            .setLabelCol("label")
            .setFeatureCols("f0", "f1", "f2", "f3")
            .linkFrom(batchSource);
        // Batch prediction: linked to the model and the data below.
        BatchOperator<?> predictBatchOp = new RandomForestPredictBatchOp()
            .setPredictionDetailCol("pred_detail")
            .setPredictionCol("pred");
        // Stream prediction: the trained model is passed to the constructor.
        StreamOperator<?> predictStreamOp = new RandomForestPredictStreamOp(trainOp)
            .setPredictionDetailCol("pred_detail")
            .setPredictionCol("pred");
        predictBatchOp.linkFrom(trainOp, batchSource).print();
        predictStreamOp.linkFrom(streamSource).print();
        StreamOperator.execute();
    }
}
```

Output

| f0 | f1 | f2 | f3 | label | pred | pred_detail |
| --- | --- | --- | --- | --- | --- | --- |
| 1.0000 | A | 0 | 0 | 0 | 0 | {"0":1.0,"1":0.0} |
| 2.0000 | B | 1 | 1 | 0 | 0 | {"0":1.0,"1":0.0} |
| 3.0000 | C | 2 | 2 | 1 | 1 | {"0":0.0,"1":1.0} |
| 4.0000 | D | 3 | 3 | 1 | 1 | {"0":0.0,"1":1.0} |