Java class name: com.alibaba.alink.pipeline.classification.OneVsRest
Python class name: OneVsRest

Description

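This component performs multi-class classification with the One-vs-Rest (OvR) strategy. For each of the numClass classes it trains a copy of the binary classifier passed to setClassifier, which separates that class from all other classes. At prediction time, the class whose classifier gives the highest score becomes the prediction.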
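To make the strategy concrete, here is a minimal pure-Python sketch of One-vs-Rest. It is not Alink's implementation. The base learner is a hand-written logistic regression trained by gradient descent, standing in for the LogisticRegression passed to setClassifier. The helper names (train_binary_lr, fit_one_vs_rest, predict_detail, predict) and the toy data are hypothetical. Rescaling the per-class scores so they sum to 1 is an assumption suggested by the pred_detail values in the output below, not a documented detail of the library.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]
BinaryModel = Tuple[List[float], float]  # (weights, bias)


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(w: Vector, x: Vector) -> float:
    return sum(wi * xi for wi, xi in zip(w, x))


def train_binary_lr(X: List[Vector], y: List[int],
                    learning_rate: float = 0.1, epochs: int = 2000) -> BinaryModel:
    """Binary logistic regression (labels 0/1) fitted by full-batch gradient descent."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for xi, yi in zip(X, y):
            err = sigmoid(dot(w, xi) + b) - yi  # d(log-loss)/dz for this sample
            for j in range(d):
                grad_w[j] += err * xi[j]
            grad_b += err
        w = [wj - learning_rate * gj / n for wj, gj in zip(w, grad_w)]
        b -= learning_rate * grad_b / n
    return w, b


def fit_one_vs_rest(X: List[Vector], labels: List[str]) -> Dict[str, BinaryModel]:
    """Train one binary classifier per class: that class (1) vs. all other classes (0)."""
    return {
        c: train_binary_lr(X, [1 if lab == c else 0 for lab in labels])
        for c in sorted(set(labels))
    }


def predict_detail(models: Dict[str, BinaryModel], x: Vector) -> Dict[str, float]:
    """Per-class positive scores, rescaled to sum to 1 (assumed normalization)."""
    scores = {c: sigmoid(dot(w, x) + b) for c, (w, b) in models.items()}
    total = sum(scores.values())
    return {c: s / total for c, s in scores.items()}


def predict(models: Dict[str, BinaryModel], x: Vector) -> str:
    """Return the class whose one-vs-rest classifier scores highest."""
    detail = predict_detail(models, x)
    return max(detail, key=detail.get)


# Hypothetical toy data: three well-separated 2-D clusters.
X = [[0.0, 0.0], [0.2, 0.1], [0.1, 0.3],
     [3.0, 0.0], [3.2, 0.2], [2.9, 0.1],
     [0.0, 3.0], [0.2, 3.1], [0.1, 2.8]]
labels = ["a"] * 3 + ["b"] * 3 + ["c"] * 3

models = fit_one_vs_rest(X, labels)
print(predict(models, [0.1, 0.1]), predict(models, [3.0, 0.1]), predict(models, [0.1, 3.0]))
print(predict_detail(models, [0.1, 0.1]))
```

Dividing every score by the same total does not change which class wins. It only makes the scores readable as a distribution over classes.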

Parameters

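| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| numClass | Number of classes | Number of classes in the multi-class problem; required | Integer | yes | | |
| predictionCol | Prediction column name | Name of the column that holds the predicted label | String | | | |
| modelFilePath | Model file path | File path of the model | String | | | null |
| overwriteSink | Overwrite existing data | Whether to overwrite existing data | Boolean | | | false |
| predictionDetailCol | Prediction detail column name | Name of the column that holds the per-class prediction details | String | | | |
| reservedCols | Reserved columns | Names of the input columns to keep in the output | String[] | | | null |
| numThreads | Number of threads | Number of threads used by the component | Integer | | | 1 |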

Code example

Python code

    from pyalink.alink import *

    useLocalEnv(1)

    URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv"
    SCHEMA_STR = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string"
    data = CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR)

    # Base binary classifier; OneVsRest trains one copy of it per class.
    lr = LogisticRegression() \
        .setFeatureCols(["sepal_length", "sepal_width", "petal_length", "petal_width"]) \
        .setLabelCol("category") \
        .setPredictionCol("pred_result") \
        .setMaxIter(100)

    # The iris data has 3 classes, so numClass is 3.
    oneVsRest = OneVsRest().setClassifier(lr).setNumClass(3)
    model = oneVsRest.fit(data)
    model.setPredictionCol("pred_result").setPredictionDetailCol("pred_detail")
    model.transform(data).print()

Java code

    import com.alibaba.alink.operator.batch.BatchOperator;
    import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
    import com.alibaba.alink.pipeline.classification.LogisticRegression;
    import com.alibaba.alink.pipeline.classification.OneVsRest;
    import com.alibaba.alink.pipeline.classification.OneVsRestModel;
    import org.junit.Test;

    public class OneVsRestTest {
        @Test
        public void testOneVsRest() throws Exception {
            String URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv";
            String SCHEMA_STR =
                "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
            BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR);

            // Base binary classifier; OneVsRest trains one copy of it per class.
            LogisticRegression lr = new LogisticRegression()
                .setFeatureCols("sepal_length", "sepal_width", "petal_length", "petal_width")
                .setLabelCol("category")
                .setPredictionCol("pred_result")
                .setMaxIter(100);

            // The iris data has 3 classes, so numClass is 3.
            OneVsRest oneVsRest = new OneVsRest().setClassifier(lr).setNumClass(3);
            OneVsRestModel model = oneVsRest.fit(data);
            model.setPredictionCol("pred_result").setPredictionDetailCol("pred_detail");
            model.transform(data).print();
        }
    }

Output

| sepal_length | sepal_width | petal_length | petal_width | category | pred_result | pred_detail |
| --- | --- | --- | --- | --- | --- | --- |
| 6.7000 | 3.1000 | 4.4000 | 1.4000 | Iris-versicolor | Iris-versicolor | {"Iris-versicolor":0.9999890601537083,"Iris-virginica":1.0939842119301402E-5,"Iris-setosa":4.1724971938972156E-12} |
| 5.4000 | 3.0000 | 4.5000 | 1.5000 | Iris-versicolor | Iris-versicolor | {"Iris-versicolor":0.9939699721610056,"Iris-virginica":0.006030026623291463,"Iris-setosa":1.2157029667713158E-9} |
| 5.4000 | 3.9000 | 1.7000 | 0.4000 | Iris-setosa | Iris-setosa | {"Iris-versicolor":0.02236524089333592,"Iris-virginica":0.0,"Iris-setosa":0.9776347591066641} |
| 5.0000 | 3.4000 | 1.6000 | 0.4000 | Iris-setosa | Iris-setosa | {"Iris-versicolor":0.07720412400682967,"Iris-virginica":0.0,"Iris-setosa":0.9227958759931704} |
| 5.6000 | 3.0000 | 4.5000 | 1.5000 | Iris-versicolor | Iris-versicolor | {"Iris-versicolor":0.9961816818708689,"Iris-virginica":0.003818317908880254,"Iris-setosa":2.2025091271297693E-10} |
| … | … | … | … | … | … | … |
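Each pred_detail value is a JSON string that maps every class label to a score. In the rows shown, pred_result is the label with the highest score and the scores sum to 1. The following small Python sketch reads such strings back, for example after the transformed data has been collected on the Python side. The function name parse_pred_detail is hypothetical, and the two input strings are copied from the table above.

```python
import json
from typing import Dict, Tuple


def parse_pred_detail(detail_json: str) -> Tuple[str, Dict[str, float]]:
    """Parse a pred_detail JSON string into (highest-scoring label, label -> score map)."""
    detail = json.loads(detail_json)  # JSON numbers such as 1.09E-5 parse as floats
    return max(detail, key=detail.get), detail


# Two pred_detail strings copied from the output table above.
rows = [
    '{"Iris-versicolor":0.9999890601537083,"Iris-virginica":1.0939842119301402E-5,'
    '"Iris-setosa":4.1724971938972156E-12}',
    '{"Iris-versicolor":0.02236524089333592,"Iris-virginica":0.0,'
    '"Iris-setosa":0.9776347591066641}',
]

for row in rows:
    label, detail = parse_pred_detail(row)
    # In the rows shown, the per-class scores sum to (approximately) 1.
    print(label, round(sum(detail.values()), 6))
```

The highest-scoring label recovered this way matches the pred_result column for both rows.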