BERT多标签文本分类采用BERT类模型训练模型,与BERT单标签文本分类不同,对单个样本可能会存在多个标签。多标签的数据格式与单标签类似,调用时只需要加上—multi_label参数。
本地数据准备和预测
可以下载 训练集 和 评估集,其中 train_multilabel_zh.csv , dev_multilabel_zh.csv 是用\t 分隔的 .csv 文件。
模型训练
easynlp \--mode=train \--worker_gpu=1 \--tables=train_multilabel_zh.csv,dev_multilabel_zh.csv \--input_schema=content_seq:str:1,label:str:1 \--first_sequence=content_seq \--label_name=label \--label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \--checkpoint_dir=./multi_label_classification_model \--learning_rate=3e-5 \--epoch_num=3 \--random_seed=42 \--save_checkpoint_steps=50 \--sequence_length=128 \--micro_batch_size=32 \--app_name=text_classify \--user_defined_parameters='pretrain_model_name_or_path=hfl/chinese-roberta-wwm-extmulti_label=True'
模型评估
easynlp \--mode=evaluate \--worker_gpu=1 \--tables=dev_multilabel_zh.csv \--input_schema=content_seq:str:1,label:str:1 \--first_sequence=content_seq \--label_name=label \--label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \--checkpoint_dir=./multi_label_classification_model \--sequence_length=128 \--micro_batch_size=32 \--app_name=text_classify \--user_defined_parameters='multi_label=True'
模型预测
easynlp \--mode=predict \--worker_gpu=1 \--tables=dev_multilabel_zh.csv \--outputs=dev_multilabel_zh.pred.csv \--input_schema=content_seq:str:1,label:str:1 \--output_schema=predictions,probabilities,logits,output \--append_cols=label \--first_sequence=content_seq \--checkpoint_path=./multi_label_classification_model \--micro_batch_size=32 \--sequence_length=128 \--app_name=text_classify \--user_defined_parameters='multi_label=True'
