Java class name: com.alibaba.alink.operator.batch.classification.BertTextPairClassifierPredictBatchOp
Python class name: BertTextPairClassifierPredictBatchOp

Description

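Prediction operator paired with the BERT text-pair classifier training operator. It takes a trained model and a dataset of text pairs, and outputs the predicted label (and optionally per-label probabilities) for each pair.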

Parameters

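| Name | Chinese name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | 预测结果列名 | Name of the prediction result column | String | ✓ | | |
| inferBatchSize | 推理数据批大小 | Number of rows per inference batch | Integer | | | 256 |
| predictionDetailCol | 预测详细信息列名 | Name of the column holding prediction details | String | | | |
| reservedCols | 算法保留列名 | Input columns to keep in the output | String[] | | | null |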
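`inferBatchSize` sets how many rows are passed to the model in each inference call. The snippet below is only a conceptual sketch, assuming rows are grouped in order into chunks of at most `inferBatchSize`; the helper `split_into_batches` is hypothetical and not part of Alink, whose internal batching may differ (for example, it may batch separately within each parallel task).

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def split_into_batches(rows: Sequence[T], infer_batch_size: int = 256) -> Iterator[List[T]]:
    """Hypothetical helper: yield consecutive chunks of at most infer_batch_size rows."""
    if infer_batch_size <= 0:
        raise ValueError("infer_batch_size must be positive")
    for start in range(0, len(rows), infer_batch_size):
        yield list(rows[start:start + infer_batch_size])


# 300 rows (the example below keeps firstN(300)) with the default batch size of 256
batches = list(split_into_batches(list(range(300))))
print([len(b) for b in batches])  # [256, 44]
```

Under this assumption, a larger `inferBatchSize` means fewer, larger model calls at the cost of more memory per call.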

Code Example

The following code is for illustration only. It may need modifications, or environment setup, before it runs correctly.

Python code

    from pyalink.alink import *

    # Start a local execution environment with parallelism 1
    useLocalEnv(1)

    url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv"
    schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string"
    # Load the MRPC training set (tab-separated, with a header line, no quote character)
    data = CsvSourceBatchOp() \
        .setFilePath(url) \
        .setSchemaStr(schemaStr) \
        .setFieldDelimiter("\t") \
        .setIgnoreFirstLine(True) \
        .setQuoteChar(None)
    # Keep only the first 300 rows to shorten the run
    data = data.firstN(300)
    # Load a pre-trained BERT text-pair classification model stored as CSV
    model = CsvSourceBatchOp() \
        .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_pair_classifier_model.csv") \
        .setSchemaStr("model_id bigint, model_info string, label_value bigint")
    # Predict: the model is linked first, the data second
    predict = BertTextPairClassifierPredictBatchOp() \
        .setPredictionCol("pred") \
        .setPredictionDetailCol("pred_detail") \
        .setReservedCols(["f_quality"]) \
        .linkFrom(model, data)
    predict.print()

Java code

    import com.alibaba.alink.operator.batch.BatchOperator;
    import com.alibaba.alink.operator.batch.classification.BertTextPairClassifierPredictBatchOp;
    import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
    import org.junit.Test;

    public class BertTextPairClassifierPredictBatchOpTest {
        @Test
        public void testBertTextPairClassifierPredictBatchOpTest() throws Exception {
            String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
            String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
            // Load the MRPC training set (tab-separated, with a header line, no quote character)
            BatchOperator<?> data = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schemaStr)
                .setFieldDelimiter("\t")
                .setIgnoreFirstLine(true)
                .setQuoteChar(null);
            // Keep only the first 300 rows to shorten the run
            data = data.firstN(300);
            // Load a pre-trained BERT text-pair classification model stored as CSV
            BatchOperator<?> model = new CsvSourceBatchOp()
                .setFilePath("http://alink-test.oss-cn-beijing.aliyuncs.com/jiqi-temp/tf_ut_files/bert_text_pair_classifier_model.csv")
                .setSchemaStr("model_id bigint, model_info string, label_value bigint");
            // Predict: the model is linked first, the data second
            BertTextPairClassifierPredictBatchOp predict = new BertTextPairClassifierPredictBatchOp()
                .setPredictionCol("pred")
                .setPredictionDetailCol("pred_detail")
                .setReservedCols("f_quality")
                .linkFrom(model, data);
            predict.print();
        }
    }

Results

| f_quality | pred | pred_detail |
| --- | --- | --- |
| 1 | 1 | {"0":0.07034707069396973,"1":0.9296529293060303} |
| 0 | 1 | {"0":0.07034707069396973,"1":0.9296529293060303} |
| 1 | 1 | {"0":0.07034707069396973,"1":0.9296529293060303} |
| 0 | 1 | {"0":0.07034707069396973,"1":0.9296529293060303} |
| 1 | 1 | {"0":0.07034707069396973,"1":0.9296529293060303} |
| … | … | … |
| 1 | 1 | {"0":0.0704156756401062,"1":0.9295843243598938} |
| 1 | 1 | {"0":0.0704156756401062,"1":0.9295843243598938} |
| 1 | 1 | {"0":0.0704156756401062,"1":0.9295843243598938} |
| 1 | 1 | {"0":0.0704156756401062,"1":0.9295843243598938} |
| 0 | 1 | {"0":0.0704156756401062,"1":0.9295843243598938} |
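In the results, `pred_detail` is a JSON string that maps each label to its predicted probability, and `pred` matches the label with the highest probability. Below is a minimal sketch of reading that column back in Python using only the standard `json` module. The function name `parse_pred_detail` is hypothetical, and the sample string is copied from the first result row.

```python
import json
from typing import Dict, Tuple


def parse_pred_detail(detail: str) -> Tuple[str, Dict[str, float]]:
    """Parse a pred_detail JSON string into (top label, {label: probability})."""
    probs = {label: float(p) for label, p in json.loads(detail).items()}
    # The label with the largest probability should agree with the pred column
    top_label = max(probs, key=probs.get)
    return top_label, probs


detail = '{"0":0.07034707069396973,"1":0.9296529293060303}'
label, probs = parse_pred_detail(detail)
print(label)  # 1
print(round(sum(probs.values()), 6))  # 1.0
```

Note that the JSON keys are strings (`"0"`, `"1"`), while `pred` uses the model's `label_value` type (bigint here), so convert one side before comparing them.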