在这个教程中,我们所做的任务为MRPC数据上的文本匹配,因为在BERT中,文本匹配本质是一个双句文本分类,因此下面代码用文本分类的方法进行任务。本次教程依托PAI-PyTorch上开发的EasyTexMiner,用户仅需要配置好相关命令参数,改动少量代码就可以在PAI上跑BERT文本分类任务。

1. 准备工作

下载MRPC的 训练集验证集 ,上传到ODPS,字段如下:
image.png

2. 安装odpscmd

odpscmd的客户端安装教程如下:https://help.aliyun.com/document_detail/27971.html
一般有三种odpscmd使用方式:

使用方式1

直接在odpscmd的客户端里操作,可以在terminal里输入odpscmd,启动之后会有如下界面:

  1. $/odps_clt_release_64/bin/odpscmd
  2. __ __
  3. ___ ___/ /___ ___ ____ __ _ ___/ /
  4. / _ \/ _ // _ \ (_-</ __// ' \/ _ /
  5. \___/\_,_// .__//___/\__//_/_/_/\_,_/
  6. /_/
  7. Aliyun ODPS Command Line Tool
  8. Version 0.35.4
  9. @Copyright 2019 Alibaba Cloud Computing Co., Ltd. All rights reserved.
  10. Connecting to your project...
  11. Connected!
  12. odps@ project_name > read project_name.table_name 10;

可以在这个客户端里输入odps命令,比方说:read project_name.table_name 10;
其中project_name换成你的项目名字,table_name换成odps表名。

使用方式2

可以用odpscmd的客户端直接执行odps命令,示例如下:

$/odps_clt_release_64/bin/odpscmd -e "read project_name.table_name 10;"

使用方式3

可以用odpscmd的客户端直接执行odps命令文件,假设文件为odps_cmd.txt, 示例如下:

$/odps_clt_release_64/bin/odpscmd -f odps_cmd.txt

3. 自定义代码

我们首先导入需要的模块:

import torch.nn as nn
from easytexminer import model_zoo, modules, losses
from easytexminer.core import Evaluator
from easytexminer.core.trainer import Trainer
from easytexminer.data import BertClassificationDataset
from easytexminer.utils import config, init_running_envs, get_dir_name, distributed_call_main

2.1 定义模型

自定义的模型需要继承layers.BaseModel,并继承from_pretrained的方法,能够从oss上初始化相应的参数,一个文本分类的模型如下:

class BertTextClassify(modules.nn.BaseModel):
    """ BERT Classification/Regression Teacher """

    def __init__(self, config, **kwargs):
        super(BertTextClassify, self).__init__(config)
        self.model_name = "text_classify_bert"
        self.bert = model_zoo.BertModel(config)
        num_labels = kwargs.pop("num_labels", None)

        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.apply(self.init_model_weights)

    def forward(self, inputs):
        sequence_output, att_output, pooled_output = \
            self.bert(inputs["input_ids"],
                      inputs["segment_ids"],
                      inputs["input_mask"],
                      output_all_encoded_layers=True, output_att=True)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return {
            "logits": logits
        }

    def compute_loss(self, model_outputs, inputs):
        logits = model_outputs["logits"]
        label_ids = inputs["label_ids"]
        return {
            "loss": losses.cross_entropy(logits, label_ids)
        }
  1. 需要定义forward函数

输入为inputs:数据预处理后的结果,对于BertClassificationDataset数据集,dict中包含了

  • input_ids: torch.Tensor(batch_size, sequence_length)
  • input_mask: torch.Tensor(batch_size, sequence_length)
  • segment_ids: torch.Tensor(batch_size, sequence_length)

输出dict 需要包含:

  • logits:分类层输出的logits
  1. 需要定义compute_loss函数,这里的loss用户可以用任意的PyTorch相关的loss

输入为:

  • model_outputsforward后出来的dict
  • inputs与forward的inputs一致

输出dict 需要包含:

  • logits:最后输出的loss

2.2 定义训练/评估

def main_fn(gpu, cfg, *args, **kwargs):
    # Prepare seed / logging / gpu environment
    init_running_envs(gpu, cfg)
    print("User Defined Example.")

    vocab_file = get_dir_name(cfg.pretrain_model_name_or_path) + "/vocab.txt" \
        if cfg.pretrain_model_name_or_path else get_dir_name(cfg.checkpoint_dir) + "/vocab.txt"
    pretrain_model_name_or_path = cfg.pretrain_model_name_or_path if cfg.pretrain_model_name_or_path \
        else cfg.checkpoint_dir

    valid_dataset = BertClassificationDataset(
        model_type="text_classify_bert",
        data_file=cfg.tables.split(",")[-1],
        vocab_file=vocab_file,
        max_seq_length=cfg.sequence_length,
        input_schema=cfg.input_schema,
        first_sequence=cfg.first_sequence,
        second_sequence=cfg.second_sequence,
        label_name=cfg.label_name,
        label_enumerate_values=cfg.label_enumerate_values,
        is_training=False)

    model = BertTextClassify.from_pretrained(
        pretrained_model_name_or_path=pretrain_model_name_or_path,
        num_labels=len(valid_dataset.label_enumerate_values))

    if cfg.mode == "train":
        # Build Data Loader
        train_dataset = BertClassificationDataset(model_type="text_classify_bert",
                                                  data_file=cfg.tables.split(",")[0],
                                                  vocab_file=vocab_file,
                                                  max_seq_length=cfg.sequence_length,
                                                  input_schema=cfg.input_schema,
                                                  first_sequence=cfg.first_sequence,
                                                  second_sequence=cfg.second_sequence,
                                                  label_name=cfg.label_name,
                                                  label_enumerate_values=cfg.label_enumerate_values,
                                                  is_training=True)
        # Training
        trainer = Trainer(model=model, train_dataset=train_dataset, valid_dataset=valid_dataset, cfg=cfg)
        trainer.train()
    elif cfg.mode == "evaluate":
        evaluator = Evaluator(metrics=valid_dataset.eval_metrics)
        evaluator.evaluate(model=model, valid_dataset=valid_dataset, eval_batch_size=cfg.eval_batch_size)


if __name__ == "__main__":
    parser = config.add_basic_argument()
    cfg = parser.parse_args()

    distributed_call_main(main_fn=main_fn, cfg=cfg)

这里用到的config的配置可以参考 EasyTexMiner完整命令详解 ,本教程包括了:

  • pretrain_model_name_or_path:预训练BERT路径,可以为空,为空即为随机初始化
  • checkpoint_dir:待训练模型路径

BertClassificationDataset是EasyTexMiner中专为文本分类抽象的数据集,用户也可以参考这个链接自定义自己的数据集。构建该数据集所需要的参数为:

  • tables:训练表、评估表
  • input_schema:参考EasyTransfer 中inputSchema的配置,读ODPS表时可为空
  • first_sequence: 待分类文本序列1所在列名
  • second_sequence:待分类文本序列2所在列名,可为空
  • label_name:标签列所在列名
  • label_enumerate_values:标签的枚举值,用逗号分隔
  • sequence_length:所需分类的序列长度

2.3 PAI 上提交任务

  1. 训练 ```bash export proj=pai_exp_dev export oss_buckets=’oss://your_bucket?access_key_id=xxx&access_key_secret=xxx&host=xxx’ export train_table=odps://${proj}/tables/easytexminer_mrpc_train export dev_table=odps://${proj}/tables/easytexminer_mrpc_dev export pretrain_model_ckpt=oss://path/to/bert_base_uncased/bert_model.ckpt export model_dir=oss://path/to/tmp_model/

Steo 3:

training

tar -zcvf quick_start.tar.gz main.py ossutil cp -f quick_start.tar.gz ${model_dir} rm -f quick_start.tar.gz

command=” pai -name easytexminer -project algo_platform_dev -Dscript=${model_dir}quick_start.tar.gz -DentryFile=main.py -Dmode=train -DinputTable=${train_table},${dev_table} -DuserDefinedParameters=’ —first_sequence=sent1 —second_sequence=sent2 —label_name=quality —label_enumerate_values=0,1 —pretrain_model_name_or_path=${pretrain_model_ckpt} —checkpoint_dir=${model_dir} —learning_rate=3e-5 —epoch_num=3 —logging_steps=100 —save_checkpoint_steps=50 —sequence_length=128 —train_batch_size=32 ‘ -DworkerCount=1 -DworkerGPU=1 -Dbuckets=’${oss_buckets}’

echo “${command}” odpscmd -e “${command}” echo “finish…” rm -f easytexminer_quick_start.tar.gz

2.评估
```bash
pai -name easytexminer
 -project algo_platform_dev
 -Dscript=${model_dir}quick_start.tar.gz
 -DentryFile=main.py
 -Dmode=evaluate
 -DinputTable=${dev_table}
 -DuserDefinedParameters='
    --first_sequence=sent1
    --second_sequence=sent2
    --label_name=quality
    --label_enumerate_values=0,1
    --checkpoint_dir=${model_dir}
    --sequence_length=128
    --eval_batch_size=32
 '
 -DworkerCount=1
 -DworkerGPU=1
 -Dbuckets='${oss_buckets}'

3.预测

pai -name easytexminer
 -project algo_platform_dev
 -Dscript=${model_dir}quick_start.tar.gz
 -DentryFile=main.py
 -Dmode=predict
 -DinputTable=oss://easytransfer-new/161093/GLUE/MRPC/dev.tsv
 -DinputSchema=quality:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1
 -DoutputTable=${model_dir}dev.pred.tsv
 -DoutputSchema=pred,prob,logit,output
 -DappendCols=quality
 -DuserDefinedParameters='
    --first_sequence=sent1
    --second_sequence=sent2
    --checkpoint_dir=${model_dir}
    --sequence_length=128
    --batch_size=32
 '
 -DworkerCount=1
 -DworkerGPU=1
 -Dbuckets='${oss_buckets}'