Java 类名:com.alibaba.alink.operator.batch.regression.BertTextPairRegressorTrainBatchOp
Python 类名:BertTextPairRegressorTrainBatchOp

功能介绍

在预训练的 BERT 模型的基础上增加一个全连接层,用于进行文本对回归任务。

参数说明

| 名称 | 中文名称 | 描述 | 类型 | 是否必须? | 取值范围 | 默认值 | | —- | —- | —- | —- | —- | —- | —- |

| labelCol | 标签列名 | 输入表中的标签列名 | String | ✓ | | |

| textCol | 文本列 | 文本列 | String | ✓ | 所选列类型为 [STRING] | |

| textPairCol | 文本对列 | 文本对列 | String | ✓ | 所选列类型为 [STRING] | |

| batchSize | 数据批大小 | 数据批大小 | Integer | | | 32 |

| bertModelName | BERT模型名字 | BERT模型名字: Base-Chinese,Base-Multilingual-Cased,Base-Uncased,Base-Cased | String | | | “Base-Chinese” |

| checkpointFilePath | 保存 checkpoint 的路径 | 用于保存中间结果的路径,将作为 TensorFlow 中 Estimatormodel_dir 传入,需要为所有 worker 都能访问到的目录 | String | | | null |

| customConfigJson | 自定义参数 | 对应 https://github.com/alibaba/EasyTransfer/blob/master/easytransfer/app_zoo/app_config.py 中的config_json | String | | | |

| intraOpParallelism | Op 间并发度 | Op 间并发度 | Integer | | | 4 |

| learningRate | 学习率 | 学习率 | Double | | | 0.001 |

| maxSeqLength | 句子截断长度 | 句子截断长度 | Integer | | | 128 |

| numEpochs | epoch 数 | epoch 数 | Double | | | 0.01 |

| numFineTunedLayers | 微调层数 | 微调层数 | Integer | | | 1 |

| numPSs | PS 角色数 | PS 角色的数量。值未设置时,如果 Worker 角色数也未设置,则为作业总并发度的 1/4(需要取整),否则为总并发度减去 Worker 角色数。 | Integer | | | null |

| numWorkers | Worker 角色数 | Worker 角色的数量。值未设置时,如果 PS 角色数也未设置,则为作业总并发度的 3/4(需要取整),否则为总并发度减去 PS 角色数。 | Integer | | | null |

| pythonEnv | Python 环境路径 | Python 环境路径,一般情况下不需要填写。如果是压缩文件,需要解压后得到一个目录,且目录名与压缩文件主文件名一致,可以使用 http://, https://, oss://, hdfs:// 等路径;如果是目录,那么只能使用本地路径,即 file://。 | String | | | “” |

| removeCheckpointBeforeTraining | 是否在训练前移除 checkpoint 相关文件 | 是否在训练前移除 checkpoint 相关文件用于重新训练,只会删除必要的文件 | Boolean | | | null |

代码示例

以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!

Python 代码

  1. url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv"
  2. schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string"
  3. data = CsvSourceBatchOp() \
  4. .setFilePath(url) \
  5. .setSchemaStr(schemaStr) \
  6. .setFieldDelimiter("\t") \
  7. .setIgnoreFirstLine(True) \
  8. .setQuoteChar(None)
  9. data = ShuffleBatchOp().linkFrom(data)
  10. train = BertTextPairRegressorTrainBatchOp() \
  11. .setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality") \
  12. .setNumEpochs(0.1) \
  13. .setMaxSeqLength(32) \
  14. .setNumFineTunedLayers(1) \
  15. .setBertModelName("Base-Uncased") \
  16. .linkFrom(data)
  17. AkSinkBatchOp() \
  18. .setFilePath("/tmp/bert_text_pair_regressor_model.ak") \
  19. .setOverwriteSink(True) \
  20. .linkFrom(train)
  21. BatchOperator.execute()

Java 代码

  1. import com.alibaba.alink.operator.batch.BatchOperator;
  2. import com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp;
  3. import com.alibaba.alink.operator.batch.regression.BertTextPairRegressorTrainBatchOp;
  4. import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
  5. import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
  6. import org.junit.Test;
  7. public class BertTextPairRegressorTrainBatchOpTest {
  8. @Test
  9. public void testBertTextPairRegressorTrainBatchOp() throws Exception {
  10. String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
  11. String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
  12. BatchOperator <?> data = new CsvSourceBatchOp()
  13. .setFilePath(url)
  14. .setSchemaStr(schemaStr)
  15. .setFieldDelimiter("\t")
  16. .setIgnoreFirstLine(true)
  17. .setQuoteChar(null);
  18. data = new ShuffleBatchOp().linkFrom(data);
  19. BertTextPairRegressorTrainBatchOp train = new BertTextPairRegressorTrainBatchOp()
  20. .setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality")
  21. .setNumEpochs(0.1)
  22. .setMaxSeqLength(32)
  23. .setNumFineTunedLayers(1)
  24. .setBertModelName("Base-Uncased")
  25. .linkFrom(data);
  26. new AkSinkBatchOp()
  27. .setFilePath("/tmp/bert_text_pair_regressor_model.ak")
  28. .setOverwriteSink(true)
  29. .linkFrom(train);
  30. BatchOperator.execute();
  31. }
  32. }