Java class name: com.alibaba.alink.operator.stream.regression.BertTextPairRegressorPredictStreamOp
Python class name: BertTextPairRegressorPredictStreamOp

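Description

Prediction component that pairs with the BERT text pair regression training component. It applies a trained model to a stream of text pairs.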

Parameters

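| Name | Display Name | Description | Type | Required? | Value Range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | Prediction column name | Name of the prediction result column | String | ✓ | | |
| inferBatchSize | Inference batch size | Batch size used for inference | Integer | | | 256 |
| modelFilePath | Model file path | File path of the model | String | | | null |
| reservedCols | Reserved columns | Names of the input columns to keep in the output | String[] | | | null |
| modelStreamFilePath | Model stream file path | File path of the model stream | String | | | null |
| modelStreamScanInterval | Model path scan interval | Interval for scanning the model path, in seconds | Integer | | | 10 |
| modelStreamStartTime | Model stream start time | Start time of the model stream. By default, reading starts from the current moment. Use the format yyyy-mm-dd hh:mm:ss.fffffffff; see Timestamp.valueOf(String s) for details. | String | | | null |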
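The table defers to Timestamp.valueOf(String s) for the modelStreamStartTime format. Assuming the string is handed to java.sql.Timestamp.valueOf as-is (an assumption about Alink's internals, not confirmed here), the minimal Java sketch below shows which values that method accepts. The class name and sample timestamps are illustrative only.

```java
import java.sql.Timestamp;

// Illustrative class: checks candidate modelStreamStartTime strings
// against java.sql.Timestamp.valueOf (assumed to be what Alink uses).
class ModelStreamStartTimeDemo {
    public static void main(String[] args) {
        // Date and time separated by a single space; fractional seconds omitted
        Timestamp ts = Timestamp.valueOf("2023-01-15 08:30:00");
        System.out.println(ts); // prints 2023-01-15 08:30:00.0

        // Fractional seconds are optional and may carry up to nanosecond precision
        Timestamp precise = Timestamp.valueOf("2023-01-15 08:30:00.123456789");
        System.out.println(precise.getNanos()); // prints 123456789

        // A date without a time part does not match the format and is rejected
        try {
            Timestamp.valueOf("2023-01-15");
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

If a value is rejected here, it would presumably also be rejected as modelStreamStartTime, so it is worth checking a start time this way before configuring the stream.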

Code Example

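The following code is for illustration only. You may need to modify parts of it or configure your environment before it runs correctly.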

Python code

```python
from pyalink.alink import *

# Start a local Alink environment; lower the parallelism if memory is limited
useLocalEnv(2)

# If OOM encountered, uncomment the following line and/or use a smaller parallelism
# get_java_class("System").setProperty("direct.reader.policy", "local_file")

# MRPC sentence pairs: quality label, two sentence ids, and the two sentences
url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv"
schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string"
data = CsvSourceStreamOp() \
    .setFilePath(url) \
    .setSchemaStr(schemaStr) \
    .setFieldDelimiter("\t") \
    .setIgnoreFirstLine(True) \
    .setQuoteChar(None)

# Load a trained BERT text pair regression model
model = CsvSourceBatchOp() \
    .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_pair_regressor_model.csv") \
    .setSchemaStr("model_id bigint, model_info string, label_value double")

# Predict on the stream, keeping f_quality in the output for comparison
predict = BertTextPairRegressorPredictStreamOp(model) \
    .setPredictionCol("pred") \
    .setReservedCols(["f_quality"]) \
    .linkFrom(data)
predict.print()
StreamOperator.execute()
```

Java code

```java
import com.alibaba.alink.common.io.directreader.DataBridgeGeneratorPolicy;
import com.alibaba.alink.common.io.directreader.LocalFileDataBridgeGenerator;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.regression.BertTextPairRegressorPredictStreamOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import org.junit.Test;

public class BertTextPairRegressorPredictStreamOpTest {
    @Test
    public void test() throws Exception {
        StreamOperator.setParallelism(2); // a larger parallelism needs much more memory
        System.setProperty("direct.reader.policy",
            LocalFileDataBridgeGenerator.class.getAnnotation(DataBridgeGeneratorPolicy.class).policy());
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        StreamOperator <?> data = new CsvSourceStreamOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setFieldDelimiter("\t")
            .setIgnoreFirstLine(true)
            .setQuoteChar(null);
        BatchOperator <?> model = new CsvSourceBatchOp()
            .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_pair_regressor_model.csv")
            .setSchemaStr("model_id bigint, model_info string, label_value double");
        BertTextPairRegressorPredictStreamOp predict = new BertTextPairRegressorPredictStreamOp(model)
            .setPredictionCol("pred")
            .setReservedCols("f_quality")
            .linkFrom(data);
        predict.print();
        StreamOperator.execute();
    }
}
```

Output

| f_quality | pred |
| --- | --- |
| 0.0000 | 1.4043 |
| 1.0000 | 1.4043 |
| 0.0000 | 1.4038 |
| 0.0000 | 1.4044 |
| 1.0000 | 1.4046 |
| … | … |