算法简介
BERT 文本匹配采用BERT类模型训练模型,输入两个句子,输出是否匹配。BERT文本匹配本质上是一个双句分类的任务,可以复用文本分类的配置,在输入中做个调整,输入两个句子即可。模型如下所示:
可视化配置参数
【输入桩配置】
| 输入桩(从左到右) | 限制数据类型 | 建议上游组件 | 是否必选 |
|---|---|---|---|
| 训练数据 | odps | 读数据表odps | 是 |
| 测试数据 | odps | 读数据表odps | 是 |
【右侧参数表单】

字段设置:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 第一文本列选择 | 第一个文本序列在输入格式中对应的列名 | string类型 | 必选 |
| 第二文本列选择 | 第二个文本序列在输入格式中对应的列名 | string类型 | 必选 |
| 标签列选择 | 标签对应的列名 | string类型 | 必选 |
| 标签枚举值 | 需要枚举出所有标签,一般为0,1 | string类型 | 必选 |
| 模型存储路径 | 模型checkpoint的存储路径 | string类型 | 必选 |


参数设置:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 模型选择 | 预训练模型名 | string | 可选,默认为’text_classify_bert’,此外还支持非bert模型: text_classify_cnn, text_classify_dgcnn |
| 优化器类型 | 优化器选择 | string | 可选,默认为’adam’ |
| batchSize | 特征提取批大小 | int | 可选,默认为256 |
| sequenceLength | 序列整体最大长度 | int | 可选,默认为128,范围为1~512 |
| numEpochs | 训练的轮次 | int | 可选,默认为2 |
| 学习率 | 优化器的学习率 | float | 可选,默认为1e-5 |
| 模型额外参数 | 额外的参数,比方说修改预训练模型等 | string | 可选,可以修改预训练模型,比方说:pretrain_model_name_or_path=pai-bert-base-zh, 其他模型详见:https://yuque.antfin-inc.com/pai/transfer-learning/uugdk2 |
执行调优:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 指定Worker数 | worker的数量 | int | 可选,默认为3个Worker |
| 指定Worker的GPU卡数 | 每个worker的GPU卡数 | int | 可选,标识是否使用GPU。默认是2张卡 |
| 指定Worker的CPU卡数 | 每个worker的CPU核数 | int | 可选,标识是否使用GPU。默认是4张卡。 |
| 分布式策略 | 定义分布式策略 | MirroredStrategy 或者: ExascaleStrategy |
必须,单机单卡或者单机多卡选 MirroredStrategy 多机多卡选 ExascaleStrategy |
【输出桩配置】
| 输出桩 | 限制数据类型 | 建议下游组件 | 是否必选 |
|---|---|---|---|
| 结果数据 | oss | 写oss数据 | 否 |
PAI命令及说明
1. PAI命令
pai -name easy_transfer_app_ext-Dmode=train-DmodelName=text_match_bert-DinputTable=odps://${your_project}/tables/${train},odps://${your_project}/tables/${dev}-DfirstSequence=query1-DsecondSequence=query2-DlabelName=is_same_question-DlabelEnumerateValues=0,1-DsequenceLength=64-DcheckpointDir=oss://${your_bucket}/${your_path}-DbatchSize=32-DnumEpochs=1-DoptimizerType=adam-DlearningRate=2e-5-DuserDefinedParameters=' pretrain_model_name_or_path=pai-bert-base-zh'-Dbuckets=oss://${your_bucket}/-Darn=${your_role_arn}-DossHost=${your_host}
2. 参数说明
| 参数名称 | 是否必选 | 参数描述 | 类型 | 默认值 |
|---|---|---|---|---|
| mode | 必选 | 模式,包括三种: - train(训练) - evaluate(评测) - predict(预测) |
STRING | 无 |
| modelName | 必选 | 模型名字,和应用一一对应,包括: - text_classify_bert(文本分类) - text_match_bert(文本匹配) - sequence_labeling_bert(序列标注) |
STRING | 无 |
| inputTable | 必选 | 输入odps表名 | STRING | 无 |
| firstSequence | 必选 | 文本序列在输入表中对应的列名 | STRING | 无 |
| secondSequence | 必选 | 文本序列在输入表中对应的列名 | STRING | 无 |
| labelName | 必选 | 分类标签对应的列名 | STRING | 无 |
| labelEnumerateValues | 必选 | 需要枚举出所有标签值 | STRING | 无 |
| sequenceLength | 必选 | 序列整体最大长度 | BIGINT | 无 |
| checkpointDir | 必选 | 模型checkpoint的存储路径,比方说: oss://easynlp-sh/text_match/ |
STRING | 无 |
| batchSize | 必选 | 批大小 | BIGINT | 无 |
| numEpochs | 必选 | 训练的轮次 | BIGINT | 无 |
| optimizerType | 必选 | 优化器,例如adam | STRING | 无 |
| learningRate | 必选 | 优化器的学习率,例如3e-5 | 无 | |
| userDefinedParameters | 必选 | 额外的参数,比方说修改预训练模型: “ pretrain_model_name_or_path=pai-bert-base-zh” |
STRING | 无 |
| buckets | 必选 | 需要鉴权的oss bucket,和 checkpointDir对应,比方说oss://easynlp-sh/ |
STRING | 无 |
| arn | 必选 | 用户的arn配置 | STRING | 无 |
| ossHost | 必选 | 用户的bucket对应的oss host | STRING | 无 |
3、输出结果
运行的开始之后就可以打开logview看运行的进度,从stderr可以看到运行的状态:
这里的PAI命令里有个-DcheckpointDir参数,填写的就是输出模型的checkpointDir。运行完之后可以通过oss console来看到checkpointDir里存储的信息,示例输出结果如下:
包括如下内容:
- 模型中间结果:avg_loss是训练的loss,eval是评测的结果,variables是模型参数,其他的为模型的checkpoint和meta信息
- 部署的模型:deployment放的是可以部署的模型,可以直接对接PAI EAS的服务。详见:https://help.aliyun.com/document_detail/113696.html
支持计算资源
【MaxCompute】
具体示例
首先可以下载 训练集 和 评估集,其中 train.csv , dev.csv 是用\t 分隔的 .csv 文件。
我们定义这五个字段为 label,sid1,sid2,sent1,sent2。
我们对各数据创建表,并将相应的数据上传到 ODPS 上:
drop table if exists modelzoo_example_train;
create table modelzoo_example_train(label STRING, sid1 STRING, sid2 STRING, sent1 STRING,sent2 STRING);
tunnel upload train.tsv modelzoo_example_train -fd '\t';
drop table if exists modelzoo_example_dev;
create table modelzoo_example_dev(label STRING, sid1 STRING, sid2 STRING, sent1 STRING,sent2 STRING);
tunnel upload dev.tsv modelzoo_example_dev -fd '\t';
参考以上可视化配置参数。运行组件即可获得结果。

