Java 类名:com.alibaba.alink.operator.batch.regression.XGBoostRegPredictBatchOp
Python 类名:XGBoostRegPredictBatchOp

功能介绍

XGBoost 组件是在开源社区的基础上进行包装,使功能和 PAI 更兼容,更易用。
XGBoost 算法在 Boosting 算法的基础上进行了扩展和升级,具有较好的易用性和鲁棒性,被广泛用在各种机器学习生产系统和竞赛领域。
当前支持分类,回归和排序。

参数说明

名称 中文名称 描述 类型 是否必须? 取值范围 默认值
predictionCol 预测结果列名 预测结果列名 String 无限制
modelFilePath 模型的文件路径 模型的文件路径 String 无限制 null
pluginVersion 插件版本号 插件版本号 String 无限制 “1.5.1”
reservedCols 算法保留列名 算法保留列 String[] 无限制 null
numThreads 组件多线程线程个数 组件多线程线程个数 Integer 无限制 1

代码示例

以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!

Python 代码

  1. df = pd.DataFrame([
  2. [0, 1, 1.1, 1.0],
  3. [1, -2, 0.9, 2.0],
  4. [0, 100, -0.01, 3.0],
  5. [1, -99, 0.1, 4.0],
  6. [0, 1, 1.1, 5.0],
  7. [1, -2, 0.9, 6.0]
  8. ])
  9. batchSource = BatchOperator.fromDataframe(
  10. df, schemaStr='y int, x1 double, x2 double, x3 double'
  11. )
  12. streamSource = StreamOperator.fromDataframe(
  13. df, schemaStr='y int, x1 double, x2 double, x3 double'
  14. )
  15. trainOp = XGBoostRegTrainBatchOp()\
  16. .setNumRound(1)\
  17. .setPluginVersion('1.5.1')\
  18. .setLabelCol('y')\
  19. .linkFrom(batchSource)
  20. predictBatchOp = XGBoostRegPredictBatchOp()\
  21. .setPredictionCol('pred')\
  22. .setPluginVersion('1.5.1')
  23. predictStreamOp = XGBoostRegPredictStreamOp(trainOp)\
  24. .setPredictionCol('pred')\
  25. .setPluginVersion('1.5.1')
  26. predictBatchOp.linkFrom(trainOp, batchSource).print()
  27. predictStreamOp.linkFrom(streamSource).print()
  28. StreamOperator.execute()

Java 代码

  1. import org.apache.flink.types.Row;
  2. import com.alibaba.alink.operator.batch.BatchOperator;
  3. import com.alibaba.alink.operator.batch.regression.XGBoostRegPredictBatchOp;
  4. import com.alibaba.alink.operator.batch.regression.XGBoostRegTrainBatchOp;
  5. import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
  6. import com.alibaba.alink.operator.stream.StreamOperator;
  7. import com.alibaba.alink.operator.stream.regression.XGBoostRegPredictStreamOp;
  8. import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
  9. import org.junit.Test;
  10. import java.util.Arrays;
  11. import java.util.List;
  12. public class XGBoostRegTrainBatchOpTest {
  13. @Test
  14. public void testXGBoostTrainBatchOp() throws Exception {
  15. List <Row> data = Arrays.asList(
  16. Row.of(0, 1, 1.1, 1.0),
  17. Row.of(1, -2, 0.9, 2.0),
  18. Row.of(0, 100, -0.01, 3.0),
  19. Row.of(1, -99, 0.1, 4.0),
  20. Row.of(0, 1, 1.1, 5.0),
  21. Row.of(1, -2, 0.9, 6.0)
  22. );
  23. BatchOperator <?> batchSource = new MemSourceBatchOp(data, "y int, x1 int, x2 double, x3 double");
  24. StreamOperator <?> streamSource = new MemSourceStreamOp(data, "y int, x1 int, x2 double, x3 double");
  25. BatchOperator <?> trainOp = new XGBoostRegTrainBatchOp()
  26. .setNumRound(1)
  27. .setPluginVersion("1.5.1")
  28. .setLabelCol("y")
  29. .linkFrom(batchSource);
  30. BatchOperator <?> predictBatchOp = new XGBoostRegPredictBatchOp()
  31. .setPredictionCol("pred")
  32. .setPluginVersion("1.5.1");
  33. StreamOperator <?> predictStreamOp = new XGBoostRegPredictStreamOp(trainOp)
  34. .setPredictionCol("pred")
  35. .setPluginVersion("1.5.1");
  36. predictBatchOp.linkFrom(trainOp, batchSource).print();
  37. predictStreamOp.linkFrom(streamSource).print();
  38. StreamOperator.execute();
  39. }
  40. }

运行结果