多标签的数据格式与单标签类似,只是需要加上—multi_label参数。
PAI模型训练和评估
ODPS数据准备
首先可以下载 训练集 和 评估集,其中 train_multilabel_zh.csv , dev_multilabel_zh.csv 是用\t 分隔的 .csv 文件。
我们定义这两个字段为 content_seq,label。
我们对各数据创建表,并将相应的数据上传到 ODPS 上:
drop table if exists modelzoo_example_train;create table modelzoo_example_train(content_seq STRING, label STRING);tunnel upload train_multilabel_zh.csv modelzoo_example_train -fd '\t';drop table if exists modelzoo_example_dev;create table modelzoo_example_dev(content_seq STRING, label STRING);tunnel upload train_multilabel_zh.csv modelzoo_example_dev -fd '\t';
模型训练
配置一些环境参数:
export train_table=odps://${project_name}/tables/your_train_table_name
export dev_table=odps://${project_name}/tables/your_dev_table_name
export saved_model_dir=oss://path/to/your_model/
export oss_bucket_name=your_bucket_name
export access_key_id=your_access_id
export access_key_secret=your_access_key_secret
export host=your_host
模型训练的PAI命令如下:
pai -name easytexminer
-project algo_platform_dev
-Dmode=train
-DinputTable=${train_table},${dev_table}
-DfirstSequence=content_seq
-DlabelName=label
-DlabelEnumerateValues=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音
-DsequenceLength=64
-DappName=text_classify
-DcheckpointDir=${saved_model_dir}
-DlearningRate=3e-5
-DnumEpochs=3
-DsaveCheckpointSteps=50
-DbatchSize=32
-DworkerCount=1
-DworkerGPU=1
-DpretrainedModelNameOrPath=hfl/chinese-roberta-wwm-ext
-DuserDefinedParameters='--multi_label'
-Dbuckets="oss://${oss_bucket_name}?access_key_id=${access_key_id}&access_key_secret=${access_key_secret}&host=${host}";
本地数据准备和预测
可以下载 训练集 和 评估集,其中 train_multilabel_zh.csv , dev_multilabel_zh.csv 是用\t 分隔的 .csv 文件。
模型训练
easytexminer \
--mode=train \
--worker_gpu=1 \
--tables=train_multilabel_zh.csv,dev_multilabel_zh.csv \
--input_schema=content_seq:str:1,label:str:1 \
--first_sequence=content_seq \
--label_name=label \
--label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \
--pretrained_model_name_or_path=hfl/chinese-roberta-wwm-ext \
--checkpoint_dir=./multi_label_classification_model \
--learning_rate=3e-5 \
--epoch_num=3 \
--seed=42 \
--save_checkpoint_steps=50 \
--sequence_length=128 \
--micro_batch_size=32 \
--app_name=text_classify \
--multi_label
模型评估
easytexminer \
--mode=evaluate \
--worker_gpu=1 \
--tables=dev_multilabel_zh.csv \
--input_schema=content_seq:str:1,label:str:1 \
--first_sequence=content_seq \
--label_name=label \
--label_enumerate_values=体积大小,外观,制热范围,制热效果,衣物烘干,味道,产品功耗,滑轮提手,声音 \
--checkpoint_dir=./multi_label_classification_model \
--sequence_length=128 \
--micro_batch_size=32 \
--app_name=text_classify \
--multi_label
模型预测
easytexminer \
--mode=predict \
--worker_gpu=1 \
--tables=dev_multilabel_zh.csv \
--outputs=dev_multilabel_zh.pred.csv \
--input_schema=content_seq:str:1,label:str:1 \
--output_schema=predictions,probabilities,logits,output \
--append_cols=label \
--first_sequence=content_seq \
--checkpoint_path=./multi_label_classification_model \
--micro_batch_size=32 \
--sequence_length=128 \
--app_name=text_classify \
--multi_label
