Java class name: com.alibaba.alink.operator.batch.evaluation.EvalRegressionBatchOp
Python class name: EvalRegressionBatchOp

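Description

Regression evaluation measures the quality of a regression algorithm's predictions by comparing them with the true labels. The following metrics are supported.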

SST: Total sum of squares (Sum of Squares for Total)

SSE: Error sum of squares (Sum of Squares for Error)

SSR: Regression sum of squares (Sum of Squares for Regression)

R^2: Coefficient of Determination

R: Multiple Correlation Coefficient

MSE: Mean Squared Error

RMSE: Root Mean Squared Error

SAE/SAD: Sum of Absolute Error/Difference

MAE/MAD: Mean Absolute Error/Difference

MAPE: Mean Absolute Percentage Error

count: Number of rows

explained variance: Explained variance
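For reference, the usual textbook definitions of most of these metrics are sketched below. The symbols are introduced here for illustration only: $y_i$ is the label, $\hat{y}_i$ the prediction, $\bar{y}$ the mean of the labels, and $n$ the row count. The operator's exact conventions are not stated in this document (for example, whether MAPE is scaled by 100, or how explained variance is computed), so treat these as assumptions to check against the implementation.

```latex
\begin{aligned}
\mathrm{SST} &= \sum_{i=1}^{n} (y_i - \bar{y})^2, &
\mathrm{SSE} &= \sum_{i=1}^{n} (y_i - \hat{y}_i)^2, &
\mathrm{SSR} &= \sum_{i=1}^{n} (\hat{y}_i - \bar{y})^2, \\
R^2 &= 1 - \frac{\mathrm{SSE}}{\mathrm{SST}}, &
R &= \sqrt{R^2}, &
\mathrm{MSE} &= \frac{\mathrm{SSE}}{n}, \\
\mathrm{RMSE} &= \sqrt{\mathrm{MSE}}, &
\mathrm{SAE} &= \sum_{i=1}^{n} \lvert y_i - \hat{y}_i \rvert, &
\mathrm{MAE} &= \frac{\mathrm{SAE}}{n}, \\
\mathrm{MAPE} &= \frac{100}{n} \sum_{i=1}^{n} \left\lvert \frac{y_i - \hat{y}_i}{y_i} \right\rvert
\end{aligned}
```

Note that MAPE is undefined for any row whose label is zero.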

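Parameters

| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| labelCol | Label column name | Name of the label column in the input table | String | ✓ | | |
| predictionCol | Prediction column name | Name of the prediction result column | String | ✓ | | |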

Code example

Python code

    from pyalink.alink import *
    import pandas as pd

    # Run in a local environment with parallelism 1
    useLocalEnv(1)

    # Each row is (prediction, label)
    df = pd.DataFrame([
        [0, 0],
        [8, 8],
        [1, 2],
        [9, 10],
        [3, 1],
        [10, 7]
    ])

    inOp = BatchOperator.fromDataframe(df, schemaStr='pred int, label int')

    # Evaluate the predictions against the labels and collect the metrics
    metrics = EvalRegressionBatchOp() \
        .setPredictionCol("pred") \
        .setLabelCol("label") \
        .linkFrom(inOp) \
        .collectMetrics()

    print("Total Samples Number:", metrics.getCount())
    print("SSE:", metrics.getSse())
    print("SAE:", metrics.getSae())
    print("RMSE:", metrics.getRmse())
    print("R2:", metrics.getR2())

Java code

    import org.apache.flink.types.Row;

    import com.alibaba.alink.operator.batch.BatchOperator;
    import com.alibaba.alink.operator.batch.evaluation.EvalRegressionBatchOp;
    import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
    import com.alibaba.alink.operator.common.evaluation.RegressionMetrics;
    import org.junit.Test;

    import java.util.Arrays;
    import java.util.List;

    public class EvalRegressionBatchOpTest {
        @Test
        public void testEvalRegressionBatchOp() throws Exception {
            // Each row is (prediction, label)
            List<Row> df = Arrays.asList(
                Row.of(0, 0),
                Row.of(8, 8),
                Row.of(1, 2),
                Row.of(9, 10),
                Row.of(3, 1),
                Row.of(10, 7)
            );
            BatchOperator<?> inOp = new MemSourceBatchOp(df, "pred int, label int");

            // Evaluate the predictions against the labels and collect the metrics
            RegressionMetrics metrics = new EvalRegressionBatchOp()
                .setPredictionCol("pred")
                .setLabelCol("label")
                .linkFrom(inOp)
                .collectMetrics();

            System.out.println("Total Samples Number:" + metrics.getCount());
            System.out.println("SSE:" + metrics.getSse());
            System.out.println("SAE:" + metrics.getSae());
            System.out.println("RMSE:" + metrics.getRmse());
            System.out.println("R2:" + metrics.getR2());
        }
    }

Output

    Total Samples Number: 6.0
    SSE: 15.0
    SAE: 7.0
    RMSE: 1.5811388300841898
    R2: 0.8282442748091603
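These values can be cross-checked without Alink. The sketch below recomputes them in plain Python from the same six (pred, label) rows, assuming the standard definitions SSE = Σ(label − pred)², SAE = Σ|label − pred|, RMSE = √(SSE / n), and R² = 1 − SSE / SST. The helper `regression_metrics` is a hypothetical name used for illustration and is not part of the Alink API.

```python
import math

# (pred, label) rows copied from the example above
rows = [(0, 0), (8, 8), (1, 2), (9, 10), (3, 1), (10, 7)]


def regression_metrics(rows):
    """Compute a few regression metrics from (prediction, label) pairs.

    Hypothetical helper based on textbook formulas; it is not Alink code.
    """
    n = len(rows)
    label_mean = sum(label for _, label in rows) / n

    # Sum of squared and absolute prediction errors
    sse = sum((label - pred) ** 2 for pred, label in rows)
    sae = sum(abs(label - pred) for pred, label in rows)

    # Total sum of squares of the labels around their mean
    sst = sum((label - label_mean) ** 2 for _, label in rows)

    return {
        "count": n,
        "sse": sse,
        "sae": sae,
        "mse": sse / n,
        "rmse": math.sqrt(sse / n),
        "mae": sae / n,
        "r2": 1 - sse / sst,
    }


metrics = regression_metrics(rows)
for name, value in metrics.items():
    print(f"{name}: {value}")
```

Under these definitions the SSE, SAE, RMSE, and R2 values agree with the operator's output for this data set. MAPE is left out because one label is zero, which makes the percentage error undefined for that row.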