RoBERT文本分类/匹配采用RoBERTa类训练模型,其流程与BERT相同,具体示意图参照BERT文本分类/匹配。
在easynlp命令中选择text_classify,指定模型为RoBERTa即可调用模型。
本地数据准备和预测
首先可以下载 训练集 和 评估集,其中 train.csv , dev.csv 是用\t 分隔的 .csv 文件。
wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/easytexminer/tutorials/classification/train.tsvwget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/easytexminer/tutorials/classification/dev.tsv
模型训练
easynlp \--mode=train \--worker_gpu=1 \--tables=train.tsv,dev.tsv \--input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \--first_sequence=sent1 \--second_sequence=sent2 \--label_name=label \--label_enumerate_values=0,1 \--checkpoint_dir=./roberta_classification_model \--learning_rate=3e-5 \--epoch_num=3 \--random_seed=42 \--save_checkpoint_steps=50 \--sequence_length=128 \--micro_batch_size=32 \--app_name=text_classify \--user_defined_parameters="pretrain_model_name_or_path=roberta-base-en"
模型预测
easynlp \--mode=predict \--worker_gpu=1 \--tables=dev.tsv \--outputs=dev.pred.tsv \--input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \--output_schema=predictions,probabilities,logits,output \--append_cols=label \--first_sequence=sent1 \--second_sequence=sent2 \--checkpoint_path=./roberta_classification_model \--micro_batch_size=32 \--sequence_length=128 \--app_name=text_classify
参数说明:
- input_schema:数据集每列的名称
- output_schema:需要输出的结果类型,默认有四种:predictions(预测结果),probabilities(预测的概率),logits(预测的logits,即softmax之前的值),output(输出值)
- append_cols:需要append的输入数据的column,多个column可以用逗号分隔
