Java class name: com.alibaba.alink.operator.batch.regression.BertTextPairRegressorPredictBatchOp
Python class name: BertTextPairRegressorPredictBatchOp

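Overview

This is the prediction component that pairs with the BERT text pair regression training component. It takes a trained model and a dataset of text pairs, and appends a predicted numeric value to each row.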

Parameters

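| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | Prediction result column name | Name of the column that holds the prediction result | String | ✓ | | |
| inferBatchSize | Inference batch size | Number of rows processed per inference batch | Integer | | | 256 |
| reservedCols | Reserved column names | Input columns kept in the output | String[] | | | null |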
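
To make the effect of `inferBatchSize` concrete, here is a minimal sketch in plain Python of how a batch size partitions rows into chunks. It only illustrates the general idea of batching; it is not Alink's internal implementation, and the helper `iter_batches` is a hypothetical name introduced for this example. The row count of 300 is borrowed from the `firstN(300)` call in the code examples on this page.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(rows: Sequence[T], batch_size: int = 256) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most batch_size rows (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(rows), batch_size):
        yield list(rows[start:start + batch_size])


# Stand-in for the 300 rows kept by firstN(300); real rows would be text pairs.
rows = list(range(300))

# With the default batch size of 256, 300 rows split into one full batch
# and one partial batch.
batch_sizes = [len(batch) for batch in iter_batches(rows, batch_size=256)]
print(batch_sizes)  # prints [256, 44]
```

A smaller batch size generally lowers peak memory use per inference step at the cost of more steps; the best value depends on the hardware and model, so treat 256 as a starting point rather than a recommendation.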

Code examples

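The following code is for illustration only. You may need to modify parts of it or configure the environment before it runs correctly.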

Python code

    from pyalink.alink import *

    # Start a local PyAlink environment; adjust this to match your own setup.
    useLocalEnv(1)

    # Load the MRPC training data: a quality label plus two text columns.
    url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv"
    schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string"
    data = CsvSourceBatchOp() \
        .setFilePath(url) \
        .setSchemaStr(schemaStr) \
        .setFieldDelimiter("\t") \
        .setIgnoreFirstLine(True) \
        .setQuoteChar(None)
    # Keep only the first 300 rows so the example runs quickly.
    data = data.firstN(300)

    # Load a previously trained BERT text pair regression model.
    model = CsvSourceBatchOp() \
        .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_pair_regressor_model.csv") \
        .setSchemaStr("model_id bigint, model_info string, label_value double")

    # Predict, keeping the true label next to the prediction for comparison.
    predict = BertTextPairRegressorPredictBatchOp() \
        .setPredictionCol("pred") \
        .setReservedCols(["f_quality"]) \
        .linkFrom(model, data)
    predict.print()

Java code

    import com.alibaba.alink.operator.batch.BatchOperator;
    import com.alibaba.alink.operator.batch.regression.BertTextPairRegressorPredictBatchOp;
    import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
    import org.junit.Test;

    public class BertTextPairRegressorPredictBatchOpTest {
        @Test
        public void test() throws Exception {
            // Load the MRPC training data: a quality label plus two text columns.
            String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
            String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
            BatchOperator<?> data = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schemaStr)
                .setFieldDelimiter("\t")
                .setIgnoreFirstLine(true)
                .setQuoteChar(null);
            // Keep only the first 300 rows so the example runs quickly.
            data = data.firstN(300);

            // Load a previously trained BERT text pair regression model.
            BatchOperator<?> model = new CsvSourceBatchOp()
                .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_pair_regressor_model.csv")
                .setSchemaStr("model_id bigint, model_info string, label_value double");

            // Predict, keeping the true label next to the prediction for comparison.
            BertTextPairRegressorPredictBatchOp predict = new BertTextPairRegressorPredictBatchOp()
                .setPredictionCol("pred")
                .setReservedCols("f_quality")
                .linkFrom(model, data);
            predict.print();
        }
    }

Results

| f_quality | pred |
| --- | --- |
| 0.0 | 1.404307 |
| 0.0 | 1.404307 |
| 1.0 | 1.404307 |
| 0.0 | 1.404307 |
| 1.0 | 1.404307 |
| … | … |
| 0.0 | 1.404392 |
| 1.0 | 1.404392 |
| 0.0 | 1.404392 |
| 1.0 | 1.404392 |
| 1.0 | 1.404392 |
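
As a rough way to read this output, the sketch below computes the mean squared error and the spread of the predictions over the ten rows displayed in the table. It uses plain Python rather than an Alink evaluation operator, and the helper `mean_squared_error` is a hypothetical name introduced here. The ten rows are only an excerpt of the 300 predicted rows, so the numbers it prints describe the excerpt, not the model as a whole.

```python
from typing import Sequence


def mean_squared_error(labels: Sequence[float], preds: Sequence[float]) -> float:
    """Average of squared differences between labels and predictions."""
    if len(labels) != len(preds) or not labels:
        raise ValueError("labels and preds must be non-empty and of equal length")
    return sum((y - p) ** 2 for y, p in zip(labels, preds)) / len(labels)


# (f_quality, pred) pairs copied from the rows shown in the results table.
shown_rows = [
    (0.0, 1.404307), (0.0, 1.404307), (1.0, 1.404307), (0.0, 1.404307), (1.0, 1.404307),
    (0.0, 1.404392), (1.0, 1.404392), (0.0, 1.404392), (1.0, 1.404392), (1.0, 1.404392),
]
labels = [label for label, _ in shown_rows]
preds = [pred for _, pred in shown_rows]

mse = mean_squared_error(labels, preds)
# Difference between the largest and smallest displayed prediction.
pred_spread = max(preds) - min(preds)

print(f"MSE on shown rows: {mse:.6f}")
print(f"Prediction spread on shown rows: {pred_spread:.6f}")
```

In the displayed excerpt the predictions differ by less than 0.0001 while the labels are 0.0 or 1.0, so a check like this can help spot a model whose outputs barely depend on the input.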