Java class name: com.alibaba.alink.operator.stream.onlinelearning.FtrlPredictStreamOp
Python class name: FtrlPredictStreamOp

Description

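Receives the model stream produced by FTRL training, updates the model in real time, and uses the latest model to make predictions on the streaming data.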
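The core idea, scoring a data stream with a model that is itself a stream, can be sketched in plain Python. This is a conceptual sketch, not Alink's implementation: it assumes a logistic-regression-style linear model (the code example initializes FTRL from `LogisticRegressionTrainBatchOp`), merges both inputs into one interleaved list, and the event format and the names `predict_stream` and `sigmoid` are hypothetical. In the real operator the model stream and the data stream are separate Flink streams.

```python
import math
from typing import Iterable, List


def sigmoid(z: float) -> float:
    """Logistic function mapping a linear score to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def predict_stream(
    initial_weights: List[float],
    initial_intercept: float,
    events: Iterable[tuple],
) -> List[float]:
    """Score each data event with the most recent model seen so far.

    Hypothetical event format:
      ("model", weights, intercept)  -> replaces the current model
      ("data", features)             -> is scored with the current model
    """
    weights, intercept = list(initial_weights), initial_intercept
    scores = []
    for event in events:
        if event[0] == "model":
            # A newer model from the training stream replaces the old one.
            _, weights, intercept = event
        else:
            _, features = event
            z = intercept + sum(w * x for w, x in zip(weights, features))
            scores.append(sigmoid(z))
    return scores


events = [
    ("data", [1.0, 0.0]),        # scored with the initial all-zero model
    ("model", [2.0, 0.0], 0.0),  # a model update arrives
    ("data", [1.0, 0.0]),        # scored with the updated model
]
print(predict_stream([0.0, 0.0], 0.0, events))  # roughly [0.5, 0.881]
```

The same record gets a different score once the new model has arrived, which is the behavior the operator provides on unbounded streams.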

Parameters

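| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | Prediction column | Name of the column that holds the prediction result | String | ✓ | | |
| predictionDetailCol | Prediction detail column | Name of the column that holds the prediction details (per-label probabilities) | String | | | |
| reservedCols | Reserved columns | Names of the input columns to keep in the output | String[] | | | null |
| vectorCol | Vector column | Name of the vector column; defaults to null | String | | Column type must be one of [DENSE_VECTOR, SPARSE_VECTOR, STRING, VECTOR] | null |
| numThreads | Number of threads | Number of threads used by the component | Integer | | | 1 |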
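The parameters above determine the shape of each output row. The following is a hypothetical sketch, not Alink's code, of how a prediction row could be assembled from them. `build_output_row` is an invented helper, the probabilities are illustrative, labels are kept as strings for simplicity, and treating `reserved_cols=None` as "keep every input column" is an assumption. `vectorCol` and `numThreads` are not modeled.

```python
import json
from typing import Dict, List, Optional


def build_output_row(
    row: Dict[str, object],
    prediction_col: str,
    probabilities: Dict[str, float],
    reserved_cols: Optional[List[str]] = None,
    prediction_detail_col: Optional[str] = None,
) -> Dict[str, object]:
    """Hypothetical helper: combine kept input columns with prediction columns."""
    # Assumption: None means every input column is kept.
    kept = list(row) if reserved_cols is None else reserved_cols
    out = {name: row[name] for name in kept}
    # The predicted label is the one with the highest probability.
    out[prediction_col] = max(probabilities, key=probabilities.get)
    if prediction_detail_col is not None:
        # Compact JSON of label -> probability, both as strings, like the sample output.
        out[prediction_detail_col] = json.dumps(
            {label: str(p) for label, p in probabilities.items()},
            separators=(",", ":"),
        )
    return out


row = {"f0": 0.5, "f1": 0.1, "f2": 0.9, "f3": 0.4, "label": 2.0}
print(build_output_row(row, "pred", {"2.0": 0.84, "1.0": 0.16},
                       reserved_cols=["label"], prediction_detail_col="details"))
# {'label': 2.0, 'pred': '2.0', 'details': '{"2.0":"0.84","1.0":"0.16"}'}
```

With `reservedCols` set to `["label"]`, as in the code example, only `label` survives from the input, followed by the prediction and detail columns.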

Code example

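The following code is for illustration only. You may need to modify parts of it or configure your environment before it runs correctly.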

Python code

```python
from pyalink.alink import *

# Start a local environment (assumes PyAlink is installed).
useLocalEnv(1)

# Batch data used to train the initial model.
trainData0 = RandomTableSourceBatchOp() \
    .setNumCols(5) \
    .setNumRows(100) \
    .setOutputCols(["f0", "f1", "f2", "f3", "label"]) \
    .setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)")

# Initial logistic regression model, used to warm-start FTRL.
model = LogisticRegressionTrainBatchOp() \
    .setFeatureCols(["f0", "f1", "f2", "f3"]) \
    .setLabelCol("label") \
    .setMaxIter(10) \
    .linkFrom(trainData0)

# Streaming data source.
trainData1 = RandomTableSourceStreamOp() \
    .setNumCols(5) \
    .setMaxRows(10000) \
    .setOutputCols(["f0", "f1", "f2", "f3", "label"]) \
    .setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)") \
    .setTimePerSample(0.1)

# FTRL training on the stream; outputs a stream of updated models.
models = FtrlTrainStreamOp(model, None) \
    .setFeatureCols(["f0", "f1", "f2", "f3"]) \
    .setLabelCol("label") \
    .setTimeInterval(10) \
    .setAlpha(0.1) \
    .setBeta(0.1) \
    .setL1(0.1) \
    .setL2(0.1) \
    .setVectorSize(4) \
    .setWithIntercept(True) \
    .linkFrom(trainData1)

# Predict the data stream with the latest model from the model stream.
FtrlPredictStreamOp(model) \
    .setPredictionCol("pred") \
    .setReservedCols(["label"]) \
    .setPredictionDetailCol("details") \
    .linkFrom(models, trainData1) \
    .print()

StreamOperator.execute()
```
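The example sets the FTRL hyperparameters `alpha`, `beta`, `L1` and `L2` on `FtrlTrainStreamOp`. To show what they control, here is a sketch of the textbook per-coordinate FTRL-Proximal update for logistic regression as published by McMahan et al. (2013). It is the published algorithm, not Alink's source code: Alink may differ in details such as label encoding (the sketch expects 0/1 labels, whereas the example uses 1.0/2.0) and intercept handling (here the intercept is a constant 1.0 feature). The class name `FtrlProximal` is hypothetical.

```python
import math
from typing import List


def _sigmoid(z: float) -> float:
    # Clamp the score to keep math.exp from overflowing.
    z = max(-35.0, min(35.0, z))
    return 1.0 / (1.0 + math.exp(-z))


class FtrlProximal:
    """Textbook per-coordinate FTRL-Proximal for logistic regression."""

    def __init__(self, dim: int, alpha: float, beta: float, l1: float, l2: float):
        self.alpha, self.beta, self.l1, self.l2 = alpha, beta, l1, l2
        self.z = [0.0] * dim  # accumulated (adjusted) gradients
        self.n = [0.0] * dim  # accumulated squared gradients

    def weights(self) -> List[float]:
        w = []
        for z_i, n_i in zip(self.z, self.n):
            if abs(z_i) <= self.l1:
                # L1 sets coordinates with a small accumulated gradient to exactly 0.
                w.append(0.0)
            else:
                sign = 1.0 if z_i > 0 else -1.0
                # Learning rate alpha / (beta + sqrt(n_i)); l2 shrinks the weight.
                w.append(-(z_i - sign * self.l1)
                         / ((self.beta + math.sqrt(n_i)) / self.alpha + self.l2))
        return w

    def predict(self, x: List[float]) -> float:
        return _sigmoid(sum(w_i * x_i for w_i, x_i in zip(self.weights(), x)))

    def update(self, x: List[float], y: int) -> None:
        """One online step on feature vector x with label y in {0, 1}."""
        w = self.weights()
        p = _sigmoid(sum(w_i * x_i for w_i, x_i in zip(w, x)))
        for i, x_i in enumerate(x):
            g = (p - y) * x_i  # log-loss gradient for coordinate i
            sigma = (math.sqrt(self.n[i] + g * g) - math.sqrt(self.n[i])) / self.alpha
            self.z[i] += g - sigma * w[i]
            self.n[i] += g * g


# First coordinate is a constant 1.0 acting as the intercept.
model = FtrlProximal(dim=3, alpha=0.1, beta=0.1, l1=0.1, l2=0.1)
for _ in range(100):
    model.update([1.0, 1.0, 0.0], 1)
    model.update([1.0, -1.0, 0.0], 0)
print(model.weights())  # the third weight stays 0.0: that feature is always zero
```

Because each coordinate keeps its own `z` and `n`, features that never appear keep a zero weight, which is why FTRL tends to produce sparse models on high-dimensional data.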

Java code

```java
package com.alibaba.alink.operator.stream.ml.onlinelearning;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.source.RandomTableSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlPredictStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp;
import org.junit.Test;

public class FtrlPredictStreamOpTest {
    @Test
    public void testFtrlClassification() throws Exception {
        StreamOperator.setParallelism(2);
        // Batch data used to train the initial model.
        BatchOperator<?> trainData0 = new RandomTableSourceBatchOp()
            .setNumCols(5)
            .setNumRows(100L)
            .setOutputCols(new String[] {"f0", "f1", "f2", "f3", "label"})
            .setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)");
        // Initial logistic regression model, used to warm-start FTRL.
        BatchOperator<?> model = new LogisticRegressionTrainBatchOp()
            .setFeatureCols(new String[] {"f0", "f1", "f2", "f3"})
            .setLabelCol("label")
            .setMaxIter(10)
            .linkFrom(trainData0);
        // Streaming data source.
        StreamOperator<?> trainData1 = new RandomTableSourceStreamOp()
            .setNumCols(5)
            .setMaxRows(1000L)
            .setOutputCols(new String[] {"f0", "f1", "f2", "f3", "label"})
            .setOutputColConfs("label:weight_set(1.0,1.0,2.0,5.0)")
            .setTimePerSample(0.1);
        // FTRL training on the stream; outputs a stream of updated models.
        StreamOperator<?> smodel = new FtrlTrainStreamOp(model)
            .setFeatureCols(new String[] {"f0", "f1", "f2", "f3"})
            .setLabelCol("label")
            .setTimeInterval(10)
            .setAlpha(0.1)
            .setBeta(0.1)
            .setL1(0.1)
            .setL2(0.1)
            .setVectorSize(4)
            .setWithIntercept(true)
            .linkFrom(trainData1);
        // Predict the data stream with the latest model from the model stream.
        new FtrlPredictStreamOp(model)
            .setPredictionCol("pred")
            .setReservedCols(new String[] {"label"})
            .setPredictionDetailCol("details")
            .linkFrom(smodel, trainData1)
            .print();
        StreamOperator.execute();
    }
}
```

Results

Sample output (the input data are randomly generated, so the values may differ between runs):

| label | pred | details |
| --- | --- | --- |
| 2.0000 | 2.0000 | {"2.0":"0.8407811313273308","1.0":"0.1592188686726692"} |
| 2.0000 | 2.0000 | {"2.0":"0.8094960632541983","1.0":"0.19050393674580168"} |
| 2.0000 | 2.0000 | {"2.0":"0.8685396820088952","1.0":"0.1314603179911048"} |
| 2.0000 | 2.0000 | {"2.0":"0.781050184076571","1.0":"0.218949815923429"} |
| 1.0000 | 2.0000 | {"2.0":"0.8347637657816113","1.0":"0.16523623421838873"} |
| 2.0000 | 2.0000 | {"2.0":"0.9211808843291631","1.0":"0.07881911567083688"} |
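The `details` column is a JSON string mapping each label to its probability. As a sanity check on the sample output, the sketch below decodes it with Python's standard `json` module, confirms that `pred` is the most probable label and that the probabilities sum to 1, and computes the accuracy over the six rows. The helper names `parse_details` and `accuracy` are hypothetical; in a real job the rows would come from the output stream rather than being hard-coded.

```python
import json
from typing import Dict, List, Tuple

# (label, pred, details) rows copied from the sample output above.
SAMPLE_ROWS: List[Tuple[float, float, str]] = [
    (2.0, 2.0, '{"2.0":"0.8407811313273308","1.0":"0.1592188686726692"}'),
    (2.0, 2.0, '{"2.0":"0.8094960632541983","1.0":"0.19050393674580168"}'),
    (2.0, 2.0, '{"2.0":"0.8685396820088952","1.0":"0.1314603179911048"}'),
    (2.0, 2.0, '{"2.0":"0.781050184076571","1.0":"0.218949815923429"}'),
    (1.0, 2.0, '{"2.0":"0.8347637657816113","1.0":"0.16523623421838873"}'),
    (2.0, 2.0, '{"2.0":"0.9211808843291631","1.0":"0.07881911567083688"}'),
]


def parse_details(details: str) -> Dict[float, float]:
    """Decode the details JSON (label string -> probability string) into floats."""
    return {float(label): float(p) for label, p in json.loads(details).items()}


def accuracy(rows: List[Tuple[float, float, str]]) -> float:
    correct = 0
    for label, pred, details in rows:
        probs = parse_details(details)
        # pred is the label with the highest probability in details.
        assert max(probs, key=probs.get) == pred
        # The two class probabilities sum to 1 (up to rounding).
        assert abs(sum(probs.values()) - 1.0) < 1e-9
        correct += int(label == pred)
    return correct / len(rows)


print(accuracy(SAMPLE_ROWS))  # 0.8333333333333334, i.e. 5 of 6 rows correct
```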