In this tutorial the task is text matching on the MRPC dataset. With BERT, text matching is essentially sentence-pair classification, so the code below treats it as a text classification task. The tutorial is built on EasyNLP, developed on top of PAI-PyTorch: after setting a handful of command-line parameters and changing a small amount of code, you can run a BERT text classification job on PAI.
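Concretely, each example is one tab-separated row holding a label (1 = paraphrase, 0 = not a paraphrase), two sentence IDs, and the two sentences to compare. The row below is a made-up illustration; its columns correspond, in order, to the label, sid1, sid2, sent1, and sent2 fields of the input_schema used in section 3 (`<TAB>` marks a tab character):

  1<TAB>10001<TAB>10002<TAB>The company posted record profits .<TAB>The firm reported record earnings .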

1. Code Walkthrough

Constructing the dataset

  train_dataset = ClassificationDataset(
      pretrained_model_name_or_path=args.pretrained_model_name_or_path,
      data_file=args.tables,
      max_seq_length=args.sequence_length,
      input_schema=args.input_schema,
      first_sequence=args.first_sequence,
      second_sequence=args.second_sequence,  # second sentence of the pair; needed for sentence-pair matching
      label_name=args.label_name,
      label_enumerate_values=args.label_enumerate_values,
      is_training=True)

Constructing the application

  model = SequenceClassification(
      pretrained_model_name_or_path=args.pretrained_model_name_or_path)

Training with the Trainer

  Trainer(model=model, train_dataset=train_dataset).train()
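Putting the three pieces together gives a minimal training entry point. The sketch below follows the layout of EasyNLP's quick-start examples; the import paths and the initialize_easynlp / get_args helpers are assumptions based on those examples and may differ between EasyNLP versions, so check examples/quick_start/ in the repository for the exact names.

  # main.py -- minimal sketch of a user-defined EasyNLP training script.
  # NOTE: the import paths and argument helpers below are assumptions based on
  # EasyNLP's quick-start examples; verify them against your installed version.
  from easynlp.appzoo import ClassificationDataset, SequenceClassification
  from easynlp.core import Trainer
  from easynlp.utils import initialize_easynlp, get_args

  if __name__ == "__main__":
      initialize_easynlp()   # set up (distributed) training and parse CLI flags
      args = get_args()      # the --tables, --input_schema, ... flags from section 3

      train_dataset = ClassificationDataset(
          pretrained_model_name_or_path=args.pretrained_model_name_or_path,
          data_file=args.tables,   # the tutorial passes train and dev files as a comma-separated list
          max_seq_length=args.sequence_length,
          input_schema=args.input_schema,
          first_sequence=args.first_sequence,
          second_sequence=args.second_sequence,
          label_name=args.label_name,
          label_enumerate_values=args.label_enumerate_values,
          is_training=True)

      model = SequenceClassification(
          pretrained_model_name_or_path=args.pretrained_model_name_or_path)

      Trainer(model=model, train_dataset=train_dataset).train()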

2. Running the Script

  cd EasyNLP/examples/quick_start/
  sh run_user_defined_local.sh

3. How the Script Works

Downloading the data

  export CUDA_VISIBLE_DEVICES=0   # use a single GPU
  # Local training example
  # cur_path=/tmp/EasyNLP
  cur_path=/home/admin/workspace/EasyNLP/
  cd ${cur_path}
  # Download the train/dev splits into ./tmp/ if they are not already there
  if [ ! -f ./tmp/train.tsv ]; then
    wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/classification/train.tsv
    wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/classification/dev.tsv
    mkdir tmp/
    mv *.tsv tmp/
  fi
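After the download, you can optionally sanity-check the files before training. The snippet below is not part of the tutorial scripts; it assumes the TSV files are tab-separated with no header row and that the columns follow the input_schema used in the training command.

  # Optional sanity check (not part of the tutorial scripts).
  # Assumes tab-separated files without a header, columns matching the input_schema.
  import pandas as pd

  cols = ["label", "sid1", "sid2", "sent1", "sent2"]
  train = pd.read_csv("tmp/train.tsv", sep="\t", header=None, names=cols, quoting=3)
  print(train.shape)                      # number of rows and columns
  print(train.head(3))                    # first few sentence pairs with labels
  print(train["label"].value_counts())    # label distribution (1 = paraphrase, 0 = not)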

Running the training script

The pretrained model used here is bert-small-uncased.

  DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 6009"
  python -m torch.distributed.launch $DISTRIBUTED_ARGS \
    examples/self_defined_examples/main.py \
    --mode train \
    --tables=tmp/train.tsv,tmp/dev.tsv \
    --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
    --first_sequence=sent1 \
    --second_sequence=sent2 \
    --label_name=label \
    --label_enumerate_values=0,1 \
    --checkpoint_dir=./tmp/classification_model/ \
    --learning_rate=3e-5 \
    --epoch_num=3 \
    --random_seed=42 \
    --logging_steps=1 \
    --save_checkpoint_steps=50 \
    --sequence_length=128 \
    --micro_batch_size=10 \
    --app_name=text_classify \
    --use_amp \
    --user_defined_parameters='pretrain_model_name_or_path=bert-small-uncased'
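To train on several GPUs of the same machine, only the GPU list and the process count need to change; torch.distributed.launch starts one process per GPU, and the rest of the command stays the same. For example, with two GPUs:

  export CUDA_VISIBLE_DEVICES=0,1
  DISTRIBUTED_ARGS="--nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 6009"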