Java 类名:com.alibaba.alink.operator.batch.classification.XGBoostPredictBatchOp
Python 类名:XGBoostPredictBatchOp

功能介绍

XGBoost 组件是在开源社区的基础上进行包装,使功能和 PAI 更兼容,更易用。
XGBoost 算法在 Boosting 算法的基础上进行了扩展和升级,具有较好的易用性和鲁棒性,被广泛用在各种机器学习生产系统和竞赛领域。
当前支持分类,回归和排序。

参数说明

| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 取值范围 | 默认值 | | —- | —- | —- | —- | —- | —- | —- |

| predictionCol | 预测结果列名 | 预测结果列名 | String | ✓ | | |

| modelFilePath | 模型的文件路径 | 模型的文件路径 | String | | | null |

| pluginVersion | 插件版本号 | 插件版本号 | String | | | “1.5.1” |

| predictionDetailCol | 预测详细信息列名 | 预测详细信息列名 | String | | | |

| reservedCols | 算法保留列名 | 算法保留列 | String[] | | | null |

| numThreads | 组件多线程线程个数 | 组件多线程线程个数 | Integer | | | 1 |

代码示例

以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!

Python 代码

  1. df = pd.DataFrame([
  2. [0, 1, 1.1, 1.0],
  3. [1, -2, 0.9, 2.0],
  4. [0, 100, -0.01, 3.0],
  5. [1, -99, 0.1, 4.0],
  6. [0, 1, 1.1, 5.0],
  7. [1, -2, 0.9, 6.0]
  8. ])
  9. batchSource = BatchOperator.fromDataframe(
  10. df, schemaStr='y int, x1 int, x2 double, x3 double'
  11. )
  12. streamSource = StreamOperator.fromDataframe(
  13. df, schemaStr='y int, x1 int, x2 double, x3 double'
  14. )
  15. trainOp = XGBoostTrainBatchOp()\
  16. .setNumRound(1)\
  17. .setPluginVersion('1.5.1')\
  18. .setLabelCol('y')\
  19. .linkFrom(batchSource)
  20. predictBatchOp = XGBoostPredictBatchOp()\
  21. .setPredictionDetailCol('pred_detail')\
  22. .setPredictionCol('pred')\
  23. .setPluginVersion('1.5.1')
  24. predictStreamOp = XGBoostPredictStreamOp(trainOp)\
  25. .setPredictionDetailCol('pred_detail')\
  26. .setPredictionCol('pred')\
  27. .setPluginVersion('1.5.1')
  28. predictBatchOp.linkFrom(trainOp, batchSource).print()
  29. predictStreamOp.linkFrom(streamSource).print()
  30. StreamOperator.execute()

Java 代码

  1. import org.apache.flink.types.Row;
  2. import com.alibaba.alink.operator.batch.BatchOperator;
  3. import com.alibaba.alink.operator.batch.classification.XGBoostPredictBatchOp;
  4. import com.alibaba.alink.operator.batch.classification.XGBoostTrainBatchOp;
  5. import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
  6. import com.alibaba.alink.operator.stream.StreamOperator;
  7. import com.alibaba.alink.operator.stream.classification.XGBoostPredictStreamOp;
  8. import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
  9. import org.junit.Test;
  10. import java.util.Arrays;
  11. import java.util.List;
  12. public class XGBoostTrainBatchOpTest {
  13. @Test
  14. public void testXGBoostTrainBatchOp() throws Exception {
  15. List <Row> data = Arrays.asList(
  16. Row.of(0, 1, 1.1, 1.0),
  17. Row.of(1, -2, 0.9, 2.0),
  18. Row.of(0, 100, -0.01, 3.0),
  19. Row.of(1, -99, 0.1, 4.0),
  20. Row.of(0, 1, 1.1, 5.0),
  21. Row.of(1, -2, 0.9, 6.0)
  22. );
  23. BatchOperator <?> batchSource = new MemSourceBatchOp(data, "y int, x1 int, x2 double, x3 double");
  24. StreamOperator <?> streamSource = new MemSourceStreamOp(data, "y int, x1 int, x2 double, x3 double");
  25. BatchOperator <?> trainOp = new XGBoostTrainBatchOp()
  26. .setNumRound(1)
  27. .setPluginVersion("1.5.1")
  28. .setLabelCol("y")
  29. .linkFrom(batchSource);
  30. BatchOperator <?> predictBatchOp = new XGBoostPredictBatchOp()
  31. .setPredictionDetailCol("pred_detail")
  32. .setPredictionCol("pred")
  33. .setPluginVersion("1.5.1");
  34. StreamOperator <?> predictStreamOp = new XGBoostPredictStreamOp(trainOp)
  35. .setPredictionDetailCol("pred_detail")
  36. .setPredictionCol("pred")
  37. .setPluginVersion("1.5.1");
  38. predictBatchOp.linkFrom(trainOp, batchSource).print();
  39. predictStreamOp.linkFrom(streamSource).print();
  40. StreamOperator.execute();
  41. }
  42. }

运行结果