算法简介

BERT 文本匹配采用BERT类模型训练模型,输入两个句子,输出是否匹配。BERT文本匹配本质上是一个双句分类的任务,可以复用文本分类的配置,在输入中做个调整,输入两个句子即可。模型如下所示:
image.png

可视化配置参数

【输入桩配置】

输入桩(从左到右) 限制数据类型 建议上游组件 是否必选
训练数据 odps 读数据表odps
测试数据 odps 读数据表odps

【右侧参数表单】

image.png
字段设置:

参数名称 参数描述 取值类型 必选,默认值
第一文本列选择 第一个文本序列在输入格式中对应的列名 string类型 必选
第二文本列选择 第二个文本序列在输入格式中对应的列名 string类型 必选
标签列选择 标签对应的列名 string类型 必选
标签枚举值 需要枚举出所有标签,一般为0,1 string类型 必选
模型存储路径 模型checkpoint的存储路径 string类型 必选

image.pngimage.png
参数设置:

参数名称 参数描述 取值类型 必选,默认值
模型选择 预训练模型名 string 可选,默认为’text_classify_bert’,此外还支持非bert模型: text_classify_cnn, text_classify_dgcnn
优化器类型 优化器选择 string 可选,默认为’adam’
batchSize 特征提取批大小 int 可选,默认为256
sequenceLength 序列整体最大长度 int 可选,默认为128,范围为1~512
numEpochs 训练的轮次 int 可选,默认为2
学习率 优化器的学习率 float 可选,默认为1e-5
模型额外参数 额外的参数,比方说修改预训练模型等 string 可选,可以修改预训练模型,比方说:pretrain_model_name_or_path=pai-bert-base-zh,
其他模型详见:https://yuque.antfin-inc.com/pai/transfer-learning/uugdk2

执行调优:

参数名称 参数描述 取值类型 必选,默认值
指定Worker数 worker的数量 int 可选,默认为3个Worker
指定Worker的GPU卡数 每个worker的GPU卡数 int 可选,标识是否使用GPU。默认是2张卡
指定Worker的CPU卡数 每个worker的CPU核数 int 可选,标识是否使用GPU。默认是4张卡。
分布式策略 定义分布式策略 MirroredStrategy
或者:
ExascaleStrategy
必须,单机单卡或者单机多卡选
MirroredStrategy
多机多卡选
ExascaleStrategy

【输出桩配置】

输出桩 限制数据类型 建议下游组件 是否必选
结果数据 oss 写oss数据

PAI命令及说明

1. PAI命令

  1. pai -name easy_transfer_app_ext
  2. -Dmode=train
  3. -DmodelName=text_match_bert
  4. -DinputTable=odps://${your_project}/tables/${train},odps://${your_project}/tables/${dev}
  5. -DfirstSequence=query1
  6. -DsecondSequence=query2
  7. -DlabelName=is_same_question
  8. -DlabelEnumerateValues=0,1
  9. -DsequenceLength=64
  10. -DcheckpointDir=oss://${your_bucket}/${your_path}
  11. -DbatchSize=32
  12. -DnumEpochs=1
  13. -DoptimizerType=adam
  14. -DlearningRate=2e-5
  15. -DuserDefinedParameters=' pretrain_model_name_or_path=pai-bert-base-zh'
  16. -Dbuckets=oss://${your_bucket}/
  17. -Darn=${your_role_arn}
  18. -DossHost=${your_host}

2. 参数说明

参数名称 是否必选 参数描述 类型 默认值
mode 必选 模式,包括三种:
- train(训练)
- evaluate(评测)
- predict(预测)
STRING
modelName 必选 模型名字,和应用一一对应,包括:
- text_classify_bert(文本分类)
- text_match_bert(文本匹配)
- sequence_labeling_bert(序列标注)
STRING
inputTable 必选 输入odps表名 STRING
firstSequence 必选 文本序列在输入表中对应的列名 STRING
secondSequence 必选 文本序列在输入表中对应的列名 STRING
labelName 必选 分类标签对应的列名 STRING
labelEnumerateValues 必选 需要枚举出所有标签值 STRING
sequenceLength 必选 序列整体最大长度 BIGINT
checkpointDir 必选 模型checkpoint的存储路径,比方说:
oss://easynlp-sh/text_match/
STRING
batchSize 必选 批大小 BIGINT
numEpochs 必选 训练的轮次 BIGINT
optimizerType 必选 优化器,例如adam STRING
learningRate 必选 优化器的学习率,例如3e-5
userDefinedParameters 必选 额外的参数,比方说修改预训练模型:
“ pretrain_model_name_or_path=pai-bert-base-zh”
STRING
buckets 必选 需要鉴权的oss bucket,和
checkpointDir对应,比方说oss://easynlp-sh/
STRING
arn 必选 用户的arn配置 STRING
ossHost 必选 用户的bucket对应的oss host STRING

3、输出结果

运行的开始之后就可以打开logview看运行的进度,从stderr可以看到运行的状态:
image.png
这里的PAI命令里有个-DcheckpointDir参数,填写的就是输出模型的checkpointDir。运行完之后可以通过oss console来看到checkpointDir里存储的信息,示例输出结果如下:
image.png
包括如下内容:

  • 模型中间结果:avg_loss是训练的loss,eval是评测的结果,variables是模型参数,其他的为模型的checkpoint和meta信息
  • 部署的模型:deployment放的是可以部署的模型,可以直接对接PAI EAS的服务。详见:https://help.aliyun.com/document_detail/113696.html

    支持计算资源


    【MaxCompute】

具体示例

首先可以下载 训练集评估集,其中 train.csv , dev.csv 是用\t 分隔的 .csv 文件。
我们定义这五个字段为 label,sid1,sid2,sent1,sent2
我们对各数据创建表,并将相应的数据上传到 ODPS 上:

drop table if exists modelzoo_example_train;
create table modelzoo_example_train(label STRING, sid1 STRING, sid2 STRING, sent1 STRING,sent2 STRING);
tunnel upload train.tsv modelzoo_example_train -fd '\t';

drop table if exists modelzoo_example_dev;
create table modelzoo_example_dev(label STRING, sid1 STRING, sid2 STRING, sent1 STRING,sent2 STRING);
tunnel upload dev.tsv modelzoo_example_dev -fd '\t';

参考以上可视化配置参数。运行组件即可获得结果。

image.png