Java class name: com.alibaba.alink.operator.batch.regression.KerasSequentialRegressorPredictBatchOp
Python class name: KerasSequentialRegressorPredictBatchOp

Description

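Prediction operator paired with the KerasSequential regression training operator (KerasSequentialRegressorTrainBatchOp). It takes the trained model as its first input and the data to be predicted as its second input.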

Parameters

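Name | Display name | Description | Type | Required? | Value range | Default
predictionCol | Prediction column name | Name of the column that holds the prediction result | String | | |
inferBatchSize | Inference batch size | Number of samples per inference batch | Integer | | | 256
reservedCols | Reserved columns | Input columns carried over to the output | String[] | | | null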
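To make the roles of inferBatchSize and reservedCols more concrete, here is a minimal pure-Python sketch of batched prediction. It is not Alink's implementation: the helper `predict_in_batches`, the row-as-dict representation, and the toy model are hypothetical, chosen only to show that rows are scored in chunks of at most inferBatchSize and that only the reserved columns are copied next to the prediction column.

```python
from typing import Callable, Dict, List, Sequence


def predict_in_batches(
    rows: Sequence[Dict[str, object]],
    model: Callable[[List[object]], List[float]],
    tensor_col: str,
    prediction_col: str,
    reserved_cols: Sequence[str],
    infer_batch_size: int = 256,
) -> List[Dict[str, object]]:
    """Hypothetical sketch: score rows in chunks and keep only reserved columns."""
    if infer_batch_size <= 0:
        raise ValueError("infer_batch_size must be positive")
    output: List[Dict[str, object]] = []
    for start in range(0, len(rows), infer_batch_size):
        batch = rows[start:start + infer_batch_size]
        # One model call per batch of at most infer_batch_size inputs.
        preds = model([row[tensor_col] for row in batch])
        for row, pred in zip(batch, preds):
            # Copy only the reserved columns, then attach the prediction.
            out = {col: row[col] for col in reserved_cols}
            out[prediction_col] = pred
            output.append(out)
    return output


if __name__ == "__main__":
    calls: List[int] = []

    def toy_model(tensors: List[object]) -> List[float]:
        # Records the size of each batch and returns the mean of each input.
        calls.append(len(tensors))
        return [sum(t) / len(t) for t in tensors]

    data = [{"tensor": [float(i), float(i + 1)], "label": float(i % 2)} for i in range(600)]
    result = predict_in_batches(data, toy_model, "tensor", "pred", ["label"])
    print(calls)      # [256, 256, 88]
    print(result[0])  # {'label': 0.0, 'pred': 0.5}
```

With 600 input rows and the default batch size of 256, the toy model is called three times (256, 256 and 88 rows), and each output row contains only `label` and `pred`, mirroring the `setReservedCols(["label"])` setting in the examples.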

Code Example

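The following code is for illustration only; you may need to modify parts of it or configure your environment before it runs correctly.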

Python Code

  from pyalink.alink import *

  # Use a local execution environment with parallelism 1
  useLocalEnv(1)

  # Read the CSV: each row holds a serialized tensor string and a numeric label
  source = CsvSourceBatchOp() \
      .setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/random_tensor.csv") \
      .setSchemaStr("tensor string, label double")

  # Parse the string column into DOUBLE tensors of shape [200, 3]
  source = ToTensorBatchOp() \
      .setSelectedCol("tensor") \
      .setTensorDataType("DOUBLE") \
      .setTensorShape([200, 3]) \
      .linkFrom(source)

  # Train a Keras Sequential regression model on the tensor column (1 epoch for the demo)
  trainBatchOp = KerasSequentialRegressorTrainBatchOp() \
      .setTensorCol("tensor") \
      .setLabelCol("label") \
      .setLayers([
          "Conv1D(256, 5, padding='same', activation='relu')",
          "Conv1D(128, 5, padding='same', activation='relu')",
          "Dropout(0.1)",
          "MaxPooling1D(pool_size=8)",
          "Conv1D(128, 5, padding='same', activation='relu')",
          "Conv1D(128, 5, padding='same', activation='relu')",
          "Flatten()"
      ]) \
      .setOptimizer("Adam()") \
      .setNumEpochs(1) \
      .linkFrom(source)

  # Predict: the first input is the model, the second is the data; keep "label" next to "pred"
  predictBatchOp = KerasSequentialRegressorPredictBatchOp() \
      .setPredictionCol("pred") \
      .setReservedCols(["label"]) \
      .linkFrom(trainBatchOp, source)
  predictBatchOp.lazyPrint(10)
  BatchOperator.execute()

Java Code

  import com.alibaba.alink.operator.batch.BatchOperator;
  import com.alibaba.alink.operator.batch.dataproc.ToTensorBatchOp;
  import com.alibaba.alink.operator.batch.regression.KerasSequentialRegressorPredictBatchOp;
  import com.alibaba.alink.operator.batch.regression.KerasSequentialRegressorTrainBatchOp;
  import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
  import org.junit.Test;

  public class KerasSequentialRegressorTrainBatchOpTest {
      @Test
      public void testKerasSequentialRegressorTrainBatchOp() throws Exception {
          BatchOperator<?> source = new CsvSourceBatchOp()
              .setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/random_tensor.csv")
              .setSchemaStr("tensor string, label double");
          source = new ToTensorBatchOp()
              .setSelectedCol("tensor")
              .setTensorDataType("DOUBLE")
              .setTensorShape(200, 3)
              .linkFrom(source);
          KerasSequentialRegressorTrainBatchOp trainBatchOp = new KerasSequentialRegressorTrainBatchOp()
              .setTensorCol("tensor")
              .setLabelCol("label")
              .setLayers(new String[] {
                  "Conv1D(256, 5, padding='same', activation='relu')",
                  "Conv1D(128, 5, padding='same', activation='relu')",
                  "Dropout(0.1)",
                  "MaxPooling1D(pool_size=8)",
                  "Conv1D(128, 5, padding='same', activation='relu')",
                  "Conv1D(128, 5, padding='same', activation='relu')",
                  "Flatten()"
              })
              .setOptimizer("Adam()")
              .setNumEpochs(1)
              .linkFrom(source);
          KerasSequentialRegressorPredictBatchOp predictBatchOp = new KerasSequentialRegressorPredictBatchOp()
              .setPredictionCol("pred")
              .setReservedCols("label")
              .linkFrom(trainBatchOp, source);
          predictBatchOp.lazyPrint(10);
          BatchOperator.execute();
      }
  }
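As a rough check of what the layer list in the examples does to each [200, 3] tensor produced by ToTensorBatchOp, the sketch below traces the per-sample output shape layer by layer using standard Keras shape rules: Conv1D with padding='same' and the default stride of 1 keeps the sequence length and sets the channel count to the number of filters; MaxPooling1D defaults to strides equal to pool_size with padding='valid'; Dropout leaves the shape unchanged; Flatten multiplies the remaining dimensions. The helper functions are hypothetical, and the sketch stops at Flatten because the output layer the trainer adds for regression is not described on this page.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


def conv1d_same(shape: Shape, filters: int) -> Shape:
    # padding='same', strides=1: length unchanged, channels become `filters`.
    length, _ = shape
    return (length, filters)


def max_pooling1d(shape: Shape, pool_size: int) -> Shape:
    # Keras defaults: strides=pool_size, padding='valid'.
    length, channels = shape
    return ((length - pool_size) // pool_size + 1, channels)


def flatten(shape: Shape) -> Shape:
    # Collapse all remaining dimensions into one feature vector.
    size = 1
    for dim in shape:
        size *= dim
    return (size,)


def trace_example_layers(input_shape: Shape) -> List[Tuple[str, Shape]]:
    """Per-sample shape after each layer of the example's layer list."""
    layers: List[Tuple[str, Callable[[Shape], Shape]]] = [
        ("Conv1D(256, 5, padding='same')", lambda s: conv1d_same(s, 256)),
        ("Conv1D(128, 5, padding='same')", lambda s: conv1d_same(s, 128)),
        ("Dropout(0.1)", lambda s: s),
        ("MaxPooling1D(pool_size=8)", lambda s: max_pooling1d(s, 8)),
        ("Conv1D(128, 5, padding='same')", lambda s: conv1d_same(s, 128)),
        ("Conv1D(128, 5, padding='same')", lambda s: conv1d_same(s, 128)),
        ("Flatten()", flatten),
    ]
    steps: List[Tuple[str, Shape]] = []
    shape = input_shape
    for name, fn in layers:
        shape = fn(shape)
        steps.append((name, shape))
    return steps


if __name__ == "__main__":
    for name, shape in trace_example_layers((200, 3)):
        print(f"{name:<32} -> {shape}")
```

For the example input this gives (200, 256), then (200, 128) through Dropout, (25, 128) after pooling since (200 - 8) // 8 + 1 = 25, and 25 × 128 = 3200 features after Flatten.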

Output

label pred
1.0000 0.4822
0.0000 0.4826
0.0000 0.4752
0.0000 0.4702
1.0000 0.4907
1.0000 0.4992
0.0000 0.4866
1.0000 0.5045
0.0000 0.4994
1.0000 0.4837
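All printed predictions fall between roughly 0.47 and 0.50, plausibly because the example trains for a single epoch (setNumEpochs(1)). As a quick sanity check, the sketch below computes the mean squared error and RMSE over just these ten printed rows, with the values copied from the output above. It is a hand calculation, not an Alink API call; for evaluating a full prediction result, a dedicated regression-evaluation operator (for example EvalRegressionBatchOp, if available in your Alink version) is the more appropriate tool.

```python
import math
from typing import List, Sequence, Tuple

# (label, pred) pairs copied from the ten printed rows above.
ROWS: List[Tuple[float, float]] = [
    (1.0, 0.4822), (0.0, 0.4826), (0.0, 0.4752), (0.0, 0.4702), (1.0, 0.4907),
    (1.0, 0.4992), (0.0, 0.4866), (1.0, 0.5045), (0.0, 0.4994), (1.0, 0.4837),
]


def mse(pairs: Sequence[Tuple[float, float]]) -> float:
    # Mean of squared differences between label and prediction.
    if not pairs:
        raise ValueError("need at least one (label, pred) pair")
    return sum((y - p) ** 2 for y, p in pairs) / len(pairs)


if __name__ == "__main__":
    error = mse(ROWS)
    print(f"MSE = {error:.4f}, RMSE = {math.sqrt(error):.4f}")
    # Baseline: always predicting 0.5 on 0/1 labels gives a squared error of 0.25 per row.
    baseline = mse([(y, 0.5) for y, _ in ROWS])
    print(f"baseline MSE (constant 0.5) = {baseline:.4f}")
```

On these ten rows the MSE is about 0.246, only slightly below the 0.25 a constant prediction of 0.5 would score, which is consistent with a barely trained model.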