BERT多标签文本分类采用BERT类模型训练模型,与BERT单标签文本分类不同,对单个样本可能会存在多个标签。多标签的数据格式与单标签类似,调用时只需要加上—multi_label参数。

本地数据准备和预测

可以下载 训练集评估集,其中 train_multilabel_zh.csv , dev_multilabel_zh.csv 是用\t 分隔的 .csv 文件。

模型训练

  1. easynlp \
  2. --mode=train \
  3. --worker_gpu=1 \
  4. --tables=train_multilabel_zh.csv,dev_multilabel_zh.csv \
  5. --input_schema=content_seq:str:1,label:str:1 \
  6. --first_sequence=content_seq \
  7. --label_name=label \
  8. --label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \
  9. --checkpoint_dir=./multi_label_classification_model \
  10. --learning_rate=3e-5 \
  11. --epoch_num=3 \
  12. --random_seed=42 \
  13. --save_checkpoint_steps=50 \
  14. --sequence_length=128 \
  15. --micro_batch_size=32 \
  16. --app_name=text_classify \
  17. --user_defined_parameters='
  18. pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext
  19. multi_label=True
  20. '

模型评估

  1. easynlp \
  2. --mode=evaluate \
  3. --worker_gpu=1 \
  4. --tables=dev_multilabel_zh.csv \
  5. --input_schema=content_seq:str:1,label:str:1 \
  6. --first_sequence=content_seq \
  7. --label_name=label \
  8. --label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \
  9. --checkpoint_dir=./multi_label_classification_model \
  10. --sequence_length=128 \
  11. --micro_batch_size=32 \
  12. --app_name=text_classify \
  13. --user_defined_parameters='
  14. multi_label=True
  15. '

模型预测

  1. easynlp \
  2. --mode=predict \
  3. --worker_gpu=1 \
  4. --tables=dev_multilabel_zh.csv \
  5. --outputs=dev_multilabel_zh.pred.csv \
  6. --input_schema=content_seq:str:1,label:str:1 \
  7. --output_schema=predictions,probabilities,logits,output \
  8. --append_cols=label \
  9. --first_sequence=content_seq \
  10. --checkpoint_path=./multi_label_classification_model \
  11. --micro_batch_size=32 \
  12. --sequence_length=128 \
  13. --app_name=text_classify \
  14. --user_defined_parameters='
  15. multi_label=True
  16. '