在这个教程中,我们所做的任务为MRPC数据上的文本匹配,因为在BERT中,文本匹配本质是一个双句文本分类,因此下面代码用文本分类的方法进行任务。本次教程依托PAI-PyTorch上开发的EasyTexMiner,用户仅需要配置好相关命令参数,改动少量代码就可以在PAI上跑BERT文本分类任务。
1. 准备工作
下载MRPC的 训练集 和 验证集 ,上传到ODPS,字段如下:
2. 安装odpscmd
odpscmd的客户端安装教程如下:https://help.aliyun.com/document_detail/27971.html
一般有三种odpscmd使用方式:
使用方式1
直接在odpscmd的客户端里操作,可以在terminal里输入odpscmd,启动之后会有如下界面:
$/odps_clt_release_64/bin/odpscmd__ _____ ___/ /___ ___ ____ __ _ ___/ // _ \/ _ // _ \ (_-</ __// ' \/ _ /\___/\_,_// .__//___/\__//_/_/_/\_,_//_/Aliyun ODPS Command Line ToolVersion 0.35.4@Copyright 2019 Alibaba Cloud Computing Co., Ltd. All rights reserved.Connecting to your project...Connected!odps@ project_name > read project_name.table_name 10;
可以在这个客户端里输入odps命令,比方说:read project_name.table_name 10;
其中project_name换成你的项目名字,table_name换成odps表名。
使用方式2
可以用odpscmd的客户端直接执行odps命令,示例如下:
$/odps_clt_release_64/bin/odpscmd -e "read project_name.table_name 10;"
使用方式3
可以用odpscmd的客户端直接执行odps命令文件,假设文件为odps_cmd.txt, 示例如下:
$/odps_clt_release_64/bin/odpscmd -f odps_cmd.txt
3. 自定义代码
我们首先导入需要的模块:
import torch.nn as nn
from easytexminer import model_zoo, modules, losses
from easytexminer.core import Evaluator
from easytexminer.core.trainer import Trainer
from easytexminer.data import BertClassificationDataset
from easytexminer.utils import config, init_running_envs, get_dir_name, distributed_call_main
2.1 定义模型
自定义的模型需要继承layers.BaseModel,并继承from_pretrained的方法,能够从oss上初始化相应的参数,一个文本分类的模型如下:
class BertTextClassify(modules.nn.BaseModel):
""" BERT Classification/Regression Teacher """
def __init__(self, config, **kwargs):
super(BertTextClassify, self).__init__(config)
self.model_name = "text_classify_bert"
self.bert = model_zoo.BertModel(config)
num_labels = kwargs.pop("num_labels", None)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.apply(self.init_model_weights)
def forward(self, inputs):
sequence_output, att_output, pooled_output = \
self.bert(inputs["input_ids"],
inputs["segment_ids"],
inputs["input_mask"],
output_all_encoded_layers=True, output_att=True)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return {
"logits": logits
}
def compute_loss(self, model_outputs, inputs):
logits = model_outputs["logits"]
label_ids = inputs["label_ids"]
return {
"loss": losses.cross_entropy(logits, label_ids)
}
- 需要定义forward函数:
输入为inputs:数据预处理后的结果,对于BertClassificationDataset数据集,dict中包含了
- input_ids: torch.Tensor(batch_size, sequence_length)
- input_mask: torch.Tensor(batch_size, sequence_length)
- segment_ids: torch.Tensor(batch_size, sequence_length)
输出dict 需要包含:
- logits:分类层输出的logits
- 需要定义compute_loss函数,这里的loss用户可以用任意的PyTorch相关的loss
输入为:
- model_outputs:forward后出来的dict
- inputs:与forward的inputs一致
输出dict 需要包含:
- logits:最后输出的loss
2.2 定义训练/评估
def main_fn(gpu, cfg, *args, **kwargs):
# Prepare seed / logging / gpu environment
init_running_envs(gpu, cfg)
print("User Defined Example.")
vocab_file = get_dir_name(cfg.pretrain_model_name_or_path) + "/vocab.txt" \
if cfg.pretrain_model_name_or_path else get_dir_name(cfg.checkpoint_dir) + "/vocab.txt"
pretrain_model_name_or_path = cfg.pretrain_model_name_or_path if cfg.pretrain_model_name_or_path \
else cfg.checkpoint_dir
valid_dataset = BertClassificationDataset(
model_type="text_classify_bert",
data_file=cfg.tables.split(",")[-1],
vocab_file=vocab_file,
max_seq_length=cfg.sequence_length,
input_schema=cfg.input_schema,
first_sequence=cfg.first_sequence,
second_sequence=cfg.second_sequence,
label_name=cfg.label_name,
label_enumerate_values=cfg.label_enumerate_values,
is_training=False)
model = BertTextClassify.from_pretrained(
pretrained_model_name_or_path=pretrain_model_name_or_path,
num_labels=len(valid_dataset.label_enumerate_values))
if cfg.mode == "train":
# Build Data Loader
train_dataset = BertClassificationDataset(model_type="text_classify_bert",
data_file=cfg.tables.split(",")[0],
vocab_file=vocab_file,
max_seq_length=cfg.sequence_length,
input_schema=cfg.input_schema,
first_sequence=cfg.first_sequence,
second_sequence=cfg.second_sequence,
label_name=cfg.label_name,
label_enumerate_values=cfg.label_enumerate_values,
is_training=True)
# Training
trainer = Trainer(model=model, train_dataset=train_dataset, valid_dataset=valid_dataset, cfg=cfg)
trainer.train()
elif cfg.mode == "evaluate":
evaluator = Evaluator(metrics=valid_dataset.eval_metrics)
evaluator.evaluate(model=model, valid_dataset=valid_dataset, eval_batch_size=cfg.eval_batch_size)
if __name__ == "__main__":
parser = config.add_basic_argument()
cfg = parser.parse_args()
distributed_call_main(main_fn=main_fn, cfg=cfg)
这里用到的config的配置可以参考 EasyTexMiner完整命令详解 ,本教程包括了:
- pretrain_model_name_or_path:预训练BERT路径,可以为空,为空即为随机初始化
- checkpoint_dir:待训练模型路径
BertClassificationDataset是EasyTexMiner中专为文本分类抽象的数据集,用户也可以参考这个链接自定义自己的数据集。构建该数据集所需要的参数为:
- tables:训练表、评估表
- input_schema:参考EasyTransfer 中inputSchema的配置,读ODPS表时可为空
- first_sequence: 待分类文本序列1所在列名
- second_sequence:待分类文本序列2所在列名,可为空
- label_name:标签列所在列名
- label_enumerate_values:标签的枚举值,用逗号分隔
- sequence_length:所需分类的序列长度
2.3 PAI 上提交任务
- 训练 ```bash export proj=pai_exp_dev export oss_buckets=’oss://your_bucket?access_key_id=xxx&access_key_secret=xxx&host=xxx’ export train_table=odps://${proj}/tables/easytexminer_mrpc_train export dev_table=odps://${proj}/tables/easytexminer_mrpc_dev export pretrain_model_ckpt=oss://path/to/bert_base_uncased/bert_model.ckpt export model_dir=oss://path/to/tmp_model/
Steo 3:
training
tar -zcvf quick_start.tar.gz main.py ossutil cp -f quick_start.tar.gz ${model_dir} rm -f quick_start.tar.gz
command=” pai -name easytexminer -project algo_platform_dev -Dscript=${model_dir}quick_start.tar.gz -DentryFile=main.py -Dmode=train -DinputTable=${train_table},${dev_table} -DuserDefinedParameters=’ —first_sequence=sent1 —second_sequence=sent2 —label_name=quality —label_enumerate_values=0,1 —pretrain_model_name_or_path=${pretrain_model_ckpt} —checkpoint_dir=${model_dir} —learning_rate=3e-5 —epoch_num=3 —logging_steps=100 —save_checkpoint_steps=50 —sequence_length=128 —train_batch_size=32 ‘ -DworkerCount=1 -DworkerGPU=1 -Dbuckets=’${oss_buckets}’
“
echo “${command}” odpscmd -e “${command}” echo “finish…” rm -f easytexminer_quick_start.tar.gz
2.评估
```bash
pai -name easytexminer
-project algo_platform_dev
-Dscript=${model_dir}quick_start.tar.gz
-DentryFile=main.py
-Dmode=evaluate
-DinputTable=${dev_table}
-DuserDefinedParameters='
--first_sequence=sent1
--second_sequence=sent2
--label_name=quality
--label_enumerate_values=0,1
--checkpoint_dir=${model_dir}
--sequence_length=128
--eval_batch_size=32
'
-DworkerCount=1
-DworkerGPU=1
-Dbuckets='${oss_buckets}'
3.预测
pai -name easytexminer
-project algo_platform_dev
-Dscript=${model_dir}quick_start.tar.gz
-DentryFile=main.py
-Dmode=predict
-DinputTable=oss://easytransfer-new/161093/GLUE/MRPC/dev.tsv
-DinputSchema=quality:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1
-DoutputTable=${model_dir}dev.pred.tsv
-DoutputSchema=pred,prob,logit,output
-DappendCols=quality
-DuserDefinedParameters='
--first_sequence=sent1
--second_sequence=sent2
--checkpoint_dir=${model_dir}
--sequence_length=128
--batch_size=32
'
-DworkerCount=1
-DworkerGPU=1
-Dbuckets='${oss_buckets}'
