RoBERT文本分类/匹配采用RoBERTa类训练模型,其流程与BERT相同,具体示意图参照BERT文本分类/匹配。

easynlp命令中选择text_classify,指定模型为RoBERTa即可调用模型。

本地数据准备和预测

首先可以下载 训练集评估集,其中 train.csv , dev.csv 是用\t 分隔的 .csv 文件。

  1. wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/easytexminer/tutorials/classification/train.tsv
  2. wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/easytexminer/tutorials/classification/dev.tsv

模型训练

  1. easynlp \
  2. --mode=train \
  3. --worker_gpu=1 \
  4. --tables=train.tsv,dev.tsv \
  5. --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  6. --first_sequence=sent1 \
  7. --second_sequence=sent2 \
  8. --label_name=label \
  9. --label_enumerate_values=0,1 \
  10. --checkpoint_dir=./roberta_classification_model \
  11. --learning_rate=3e-5 \
  12. --epoch_num=3 \
  13. --random_seed=42 \
  14. --save_checkpoint_steps=50 \
  15. --sequence_length=128 \
  16. --micro_batch_size=32 \
  17. --app_name=text_classify \
  18. --user_defined_parameters="
  19. pretrain_model_name_or_path=roberta-base-en
  20. "

模型预测

  1. easynlp \
  2. --mode=predict \
  3. --worker_gpu=1 \
  4. --tables=dev.tsv \
  5. --outputs=dev.pred.tsv \
  6. --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  7. --output_schema=predictions,probabilities,logits,output \
  8. --append_cols=label \
  9. --first_sequence=sent1 \
  10. --second_sequence=sent2 \
  11. --checkpoint_path=./roberta_classification_model \
  12. --micro_batch_size=32 \
  13. --sequence_length=128 \
  14. --app_name=text_classify

参数说明:

  • input_schema:数据集每列的名称
  • output_schema:需要输出的结果类型,默认有四种:predictions(预测结果),probabilities(预测的概率),logits(预测的logits,即softmax之前的值),output(输出值)
  • append_cols:需要append的输入数据的column,多个column可以用逗号分隔