Java class name: com.alibaba.alink.operator.batch.regression.BertTextRegressorPredictBatchOp
Python class name: BertTextRegressorPredictBatchOp

Description

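Prediction operator paired with the BERT text regression training operator. It takes the trained model and a text dataset as inputs, in that order, and outputs a regression prediction for each input row.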

Parameters

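| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | Prediction column name | Name of the output column that holds the prediction | String | ✓ | | |
| inferBatchSize | Inference batch size | Number of rows fed to the model per inference batch | Integer | | | 256 |
| reservedCols | Reserved columns | Input columns carried over to the output | String[] | | | null |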
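The following plain-Python sketch (no Alink dependency) illustrates how `inferBatchSize` and `reservedCols` are assumed to shape the output. Rows are scored in chunks of at most `inferBatchSize`, and each output row carries the reserved columns plus `predictionCol`. The names `predict_rows`, `batched`, `score_batch`, and `text_col` are hypothetical, and this is not Alink's internal implementation. Treating `reservedCols = null` as "keep every input column" is an assumption based on common Alink convention.

```python
from typing import Callable, Dict, Iterator, List, Optional, Sequence

Row = Dict[str, object]


def batched(rows: Sequence, batch_size: int) -> Iterator[List]:
    """Yield consecutive chunks of at most batch_size items (cf. inferBatchSize)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(rows), batch_size):
        yield list(rows[start:start + batch_size])


def predict_rows(
    rows: Sequence[Row],
    score_batch: Callable[[List[str]], List[float]],  # hypothetical model call
    text_col: str,                                    # hypothetical text column name
    prediction_col: str,
    infer_batch_size: int = 256,
    reserved_cols: Optional[List[str]] = None,
) -> List[Row]:
    """Score rows batch by batch and assemble the output rows.

    reserved_cols=None keeps every input column (assumed default behavior);
    otherwise only the listed columns are copied before the prediction column.
    """
    output: List[Row] = []
    for batch in batched(rows, infer_batch_size):
        # One model invocation per batch instead of one per row.
        preds = score_batch([str(r[text_col]) for r in batch])
        if len(preds) != len(batch):
            raise ValueError("model returned a different number of predictions")
        for row, pred in zip(batch, preds):
            keep = list(row) if reserved_cols is None else reserved_cols
            out_row = {c: row[c] for c in keep}
            out_row[prediction_col] = pred
            output.append(out_row)
    return output


if __name__ == "__main__":
    calls: List[int] = []

    def fake_scorer(texts: List[str]) -> List[float]:
        # Stand-in for the BERT model: records batch sizes and scores by text length.
        calls.append(len(texts))
        return [float(len(t)) for t in texts]

    data = [{"label": float(i % 2), "review": "x" * i} for i in range(1, 6)]
    result = predict_rows(data, fake_scorer, "review", "pred",
                          infer_batch_size=2, reserved_cols=["label"])
    print(calls)      # [2, 2, 1]
    print(result[0])  # {'label': 1.0, 'pred': 1.0}
```

With five rows and a batch size of 2, the model is called three times. Only `label` and `pred` survive in the output, which mirrors the two-column result of the example below.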

Code Examples

The following code is for illustration only. You may need to modify parts of it or configure your environment before it runs.

Python code

```python
from pyalink.alink import *

useLocalEnv(1)

url = "http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/ChnSentiCorp_htl_small.csv"
schema = "label double, review string"

# Load the review data (label + review text)
data = CsvSourceBatchOp() \
    .setFilePath(url) \
    .setSchemaStr(schema) \
    .setIgnoreFirstLine(True)
# Drop rows without review text and keep the first 300 rows
data = data.where("review is not null")
data = data.firstN(300)

# Load a pre-trained BERT text regression model
model = CsvSourceBatchOp() \
    .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_regressor_model.csv") \
    .setSchemaStr("model_id bigint, model_info string, label_value double")

# linkFrom takes the model first, then the data to predict
predict = BertTextRegressorPredictBatchOp() \
    .setPredictionCol("pred") \
    .setReservedCols(["label"]) \
    .linkFrom(model, data)
predict.print()
```

Java code

```java
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.regression.BertTextRegressorPredictBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import org.junit.Test;

public class BertTextRegressorPredictBatchOpTest {
    @Test
    public void testBertTextRegressorPredictBatchOp() throws Exception {
        String url = "http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/ChnSentiCorp_htl_small.csv";
        String schema = "label double, review string";

        // Load the review data, drop rows without text, keep the first 300 rows
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schema)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = data.firstN(300);

        // Load a pre-trained BERT text regression model
        BatchOperator<?> model = new CsvSourceBatchOp()
            .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_regressor_model.csv")
            .setSchemaStr("model_id bigint, model_info string, label_value double");

        // linkFrom takes the model first, then the data to predict
        BertTextRegressorPredictBatchOp predict = new BertTextRegressorPredictBatchOp()
            .setPredictionCol("pred")
            .setReservedCols("label")
            .linkFrom(model, data);
        predict.print();
    }
}
```

Output

| label | pred |
| --- | --- |
| 1.0 | 5.004022 |
| 1.0 | 5.004022 |
| 1.0 | 5.004022 |
| 1.0 | 5.004022 |
| 1.0 | 5.004022 |
| … | … |
| 0.0 | 5.004022 |
| 0.0 | 5.004022 |
| 0.0 | 5.004022 |
| 0.0 | 5.004022 |
| 0.0 | 5.004022 |