多标签的数据格式与单标签类似,只是需要加上—multi_label参数。

PAI模型训练和评估

ODPS数据准备

首先可以下载 训练集评估集,其中 train_multilabel_zh.csv , dev_multilabel_zh.csv 是用\t 分隔的 .csv 文件。
我们定义这两个字段为 content_seq,label
我们对各数据创建表,并将相应的数据上传到 ODPS 上:

  1. drop table if exists modelzoo_example_train;
  2. create table modelzoo_example_train(content_seq STRING, label STRING);
  3. tunnel upload train_multilabel_zh.csv modelzoo_example_train -fd '\t';
  4. drop table if exists modelzoo_example_dev;
  5. create table modelzoo_example_dev(content_seq STRING, label STRING);
  6. tunnel upload train_multilabel_zh.csv modelzoo_example_dev -fd '\t';

模型训练

配置一些环境参数:

export train_table=odps://${project_name}/tables/your_train_table_name
export dev_table=odps://${project_name}/tables/your_dev_table_name
export saved_model_dir=oss://path/to/your_model/
export oss_bucket_name=your_bucket_name
export access_key_id=your_access_id
export access_key_secret=your_access_key_secret
export host=your_host

模型训练的PAI命令如下:

pai -name easytexminer
 -project algo_platform_dev
 -Dmode=train
 -DinputTable=${train_table},${dev_table}
 -DfirstSequence=content_seq
 -DlabelName=label
 -DlabelEnumerateValues=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音
 -DsequenceLength=64
 -DappName=text_classify
 -DcheckpointDir=${saved_model_dir}
 -DlearningRate=3e-5
 -DnumEpochs=3
 -DsaveCheckpointSteps=50
 -DbatchSize=32
 -DworkerCount=1
 -DworkerGPU=1
 -DpretrainedModelNameOrPath=hfl/chinese-roberta-wwm-ext
 -DuserDefinedParameters='--multi_label'
 -Dbuckets="oss://${oss_bucket_name}?access_key_id=${access_key_id}&access_key_secret=${access_key_secret}&host=${host}";

本地数据准备和预测

可以下载 训练集评估集,其中 train_multilabel_zh.csv , dev_multilabel_zh.csv 是用\t 分隔的 .csv 文件。

模型训练

  easytexminer \
    --mode=train \
    --worker_gpu=1 \
    --tables=train_multilabel_zh.csv,dev_multilabel_zh.csv \
    --input_schema=content_seq:str:1,label:str:1 \
    --first_sequence=content_seq \
    --label_name=label \
    --label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \
    --pretrained_model_name_or_path=hfl/chinese-roberta-wwm-ext \
    --checkpoint_dir=./multi_label_classification_model \
    --learning_rate=3e-5  \
    --epoch_num=3  \
    --seed=42 \
    --save_checkpoint_steps=50 \
    --sequence_length=128 \
    --micro_batch_size=32 \
    --app_name=text_classify \
    --multi_label

模型评估

  easytexminer \
      --mode=evaluate \
      --worker_gpu=1 \
      --tables=dev_multilabel_zh.csv \
      --input_schema=content_seq:str:1,label:str:1 \
      --first_sequence=content_seq \
      --label_name=label \
      --label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \
      --checkpoint_dir=./multi_label_classification_model \
      --sequence_length=128 \
      --micro_batch_size=32 \
      --app_name=text_classify \
      --multi_label

模型预测

 easytexminer \
    --mode=predict \
    --worker_gpu=1 \
    --tables=dev_multilabel_zh.csv \
    --outputs=dev_multilabel_zh.pred.csv \
    --input_schema=content_seq:str:1,label:str:1 \
    --output_schema=predictions,probabilities,logits,output \
    --append_cols=label \
    --first_sequence=content_seq \
    --checkpoint_path=./multi_label_classification_model \
    --micro_batch_size=32 \
    --sequence_length=128 \
    --app_name=text_classify \
    --multi_label