Java 类名:com.alibaba.alink.operator.batch.regression.XGBoostRegTrainBatchOp
Python 类名:XGBoostRegTrainBatchOp

功能介绍

XGBoost 组件是在开源社区的基础上进行包装,使功能和 PAI 更兼容,更易用。
XGBoost 算法在 Boosting 算法的基础上进行了扩展和升级,具有较好的易用性和鲁棒性,被广泛用在各种机器学习生产系统和竞赛领域。
当前支持分类,回归和排序。

参数说明

名称 中文名称 描述 类型 是否必须? 取值范围 默认值
labelCol 标签列名 输入表中的标签列名 String
numRound 树的棵树 树的棵树 Integer
alpha L1正则项 L1正则项 Double 1.0
baseScore Base score Base score Double 0.5
colSampleByLevel 每个树列采样 每个树列采样 Double 1.0
colSampleByNode 每个结点列采样 每个结点采样 Double 1.0
colSampleByTree 每个树列采样 每个树列采样 Double 1.0
eta 学习率 学习率 Double 0.3
featureCols 特征列名数组 特征列名数组,默认全选 String[] [BIGDECIMAL, BIGINTEGER, BYTE, DOUBLE, FLOAT, INTEGER, LONG,

代码示例

以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!

Python 代码

  1. df = pd.DataFrame([
  2. [0, 1, 1.1, 1.0],
  3. [1, -2, 0.9, 2.0],
  4. [0, 100, -0.01, 3.0],
  5. [1, -99, 0.1, 4.0],
  6. [0, 1, 1.1, 5.0],
  7. [1, -2, 0.9, 6.0]
  8. ])
  9. batchSource = BatchOperator.fromDataframe(
  10. df, schemaStr='y int, x1 double, x2 double, x3 double'
  11. )
  12. streamSource = StreamOperator.fromDataframe(
  13. df, schemaStr='y int, x1 double, x2 double, x3 double'
  14. )
  15. trainOp = XGBoostRegTrainBatchOp()\
  16. .setNumRound(1)\
  17. .setPluginVersion('1.5.1')\
  18. .setLabelCol('y')\
  19. .linkFrom(batchSource)
  20. predictBatchOp = XGBoostRegPredictBatchOp()\
  21. .setPredictionCol('pred')\
  22. .setPluginVersion('1.5.1')
  23. predictStreamOp = XGBoostRegPredictStreamOp(trainOp)\
  24. .setPredictionCol('pred')\
  25. .setPluginVersion('1.5.1')
  26. predictBatchOp.linkFrom(trainOp, batchSource).print()
  27. predictStreamOp.linkFrom(streamSource).print()
  28. StreamOperator.execute()

Java 代码

  1. import org.apache.flink.types.Row;
  2. import com.alibaba.alink.operator.batch.BatchOperator;
  3. import com.alibaba.alink.operator.batch.regression.XGBoostRegPredictBatchOp;
  4. import com.alibaba.alink.operator.batch.regression.XGBoostRegTrainBatchOp;
  5. import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
  6. import com.alibaba.alink.operator.stream.StreamOperator;
  7. import com.alibaba.alink.operator.stream.regression.XGBoostRegPredictStreamOp;
  8. import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
  9. import org.junit.Test;
  10. import java.util.Arrays;
  11. import java.util.List;
  12. public class XGBoostRegTrainBatchOpTest {
  13. @Test
  14. public void testXGBoostTrainBatchOp() throws Exception {
  15. List <Row> data = Arrays.asList(
  16. Row.of(0, 1, 1.1, 1.0),
  17. Row.of(1, -2, 0.9, 2.0),
  18. Row.of(0, 100, -0.01, 3.0),
  19. Row.of(1, -99, 0.1, 4.0),
  20. Row.of(0, 1, 1.1, 5.0),
  21. Row.of(1, -2, 0.9, 6.0)
  22. );
  23. BatchOperator <?> batchSource = new MemSourceBatchOp(data, "y int, x1 int, x2 double, x3 double");
  24. StreamOperator <?> streamSource = new MemSourceStreamOp(data, "y int, x1 int, x2 double, x3 double");
  25. BatchOperator <?> trainOp = new XGBoostRegTrainBatchOp()
  26. .setNumRound(1)
  27. .setPluginVersion("1.5.1")
  28. .setLabelCol("y")
  29. .linkFrom(batchSource);
  30. BatchOperator <?> predictBatchOp = new XGBoostRegPredictBatchOp()
  31. .setPredictionCol("pred")
  32. .setPluginVersion("1.5.1");
  33. StreamOperator <?> predictStreamOp = new XGBoostRegPredictStreamOp(trainOp)
  34. .setPredictionCol("pred")
  35. .setPluginVersion("1.5.1");
  36. predictBatchOp.linkFrom(trainOp, batchSource).print();
  37. predictStreamOp.linkFrom(streamSource).print();
  38. StreamOperator.execute();
  39. }
  40. }

运行结果