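TextCNN text classification uses a deep learning model based on a convolutional neural network (CNN). The model outputs a class label, as shown below:
[Figure: example of TextCNN classification output]
To use this model, set --model_name='text_classify_cnn' in the easytexminer command.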

Model Parameters

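In addition to the common AppZoo parameters (see the full PAI command documentation), you also need to provide the following parameters to initialize the model.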

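Parameter | Description | Type | Default
tokenizer | Tokenizer vocabulary: either the path to a vocabulary file, or 'en' / 'zh' to use the BERT vocabulary for that language | string | None (must be set by the user)
conv_dim | Output dimension of each convolutional layer | int | 100
kernel_sizes | Comma-separated kernel sizes, one per convolutional layer | string | 1,2,3,4 (four convolutional layers with kernel sizes 1, 2, 3, and 4)
linear_hidden_size | Size of the fully connected layer that follows the convolutional layers | int | 512
embed_size | Dimension of the input word embeddings | int | 300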
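To make these hyperparameters concrete, here is a minimal NumPy sketch of a classic TextCNN forward pass: word embeddings (embed_size), one convolution per entry of kernel_sizes with conv_dim output features each, ReLU, max-over-time pooling, a hidden fully connected layer (linear_hidden_size), and an output layer. It assumes the standard TextCNN layout; the exact layers, activations, and initialization inside easytexminer's text_classify_cnn are not documented here, and the helper names parse_kernel_sizes, init_params, and forward are hypothetical.

```python
import numpy as np


def parse_kernel_sizes(spec: str) -> list[int]:
    """Turn a kernel_sizes string such as "1,2,3,4" into [1, 2, 3, 4]."""
    return [int(k) for k in spec.split(",") if k.strip()]


def init_params(vocab_size, embed_size=300, conv_dim=100, kernel_sizes="1,2,3,4",
                linear_hidden_size=512, num_labels=2, seed=0):
    """Randomly initialize a TextCNN; defaults mirror the parameter table above."""
    rng = np.random.default_rng(seed)
    params = {"embedding": rng.normal(0.0, 0.1, (vocab_size, embed_size)), "convs": []}
    for k in parse_kernel_sizes(kernel_sizes):
        # One filter bank per kernel size: a window of k embeddings -> conv_dim features.
        w = rng.normal(0.0, 0.1, (k, embed_size, conv_dim))
        params["convs"].append((k, w, np.zeros(conv_dim)))
    n_features = conv_dim * len(params["convs"])
    params["hidden"] = (rng.normal(0.0, 0.1, (n_features, linear_hidden_size)),
                        np.zeros(linear_hidden_size))
    params["out"] = (rng.normal(0.0, 0.1, (linear_hidden_size, num_labels)),
                     np.zeros(num_labels))
    return params


def forward(params, token_ids):
    """Map a 1-D array of token ids (seq_len,) to unnormalized logits (num_labels,)."""
    x = params["embedding"][token_ids]                    # (seq_len, embed_size)
    pooled = []
    for k, w, b in params["convs"]:
        if x.shape[0] < k:
            raise ValueError(f"sequence shorter than kernel size {k}")
        # Stack every window of k consecutive embeddings: (n_windows, k, embed_size).
        windows = np.stack([x[i:i + k] for i in range(x.shape[0] - k + 1)])
        feats = np.einsum("nke,kec->nc", windows, w) + b  # (n_windows, conv_dim)
        feats = np.maximum(feats, 0.0)                    # ReLU
        pooled.append(feats.max(axis=0))                  # max-over-time -> (conv_dim,)
    h = np.concatenate(pooled)                            # (conv_dim * len(kernel_sizes),)
    w1, b1 = params["hidden"]
    h = np.maximum(h @ w1 + b1, 0.0)                      # (linear_hidden_size,)
    w2, b2 = params["out"]
    return h @ w2 + b2                                    # (num_labels,)


# Usage: a 128-token input, matching --sequence_length=128 in the training command.
params = init_params(vocab_size=1000)
logits = forward(params, np.arange(128) % 1000)
print(logits.shape)  # (2,)
```

With the defaults, the pooled feature vector has conv_dim × len(kernel_sizes) = 100 × 4 = 400 entries, which is why the hidden layer's weight matrix is 400 × 512. num_labels=2 corresponds to --label_enumerate_values=0,1.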

Data Preparation

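Data preparation is the same as for the BERT text classification task. Prepare the training and evaluation sets as tab-separated (\t) files with no header row. Sample rows: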

1	702876	702977	Amrozi accused his brother, whom he called "the witness", of deliberately distorting his evidence.	Referring to him as only "the witness", Amrozi accused his brother of deliberately distorting his evidence.
0	2108705	2108831	Yucaipa owned Dominick's before selling the chain to Safeway in 1998 for $2.5 billion.	Yucaipa bought Dominick's in 1995 for $693 million and sold it to Safeway for $1.8 billion in 1998.
1	1330381	1330521	They had published an advertisement on the Internet on June 10, offering the cargo for sale, he added.	On June 10, the ship's owners had published an advertisement on the Internet, offering the explosives for sale.
0	3344667	3344648	Around 0335 GMT, Tab shares were up 19 cents, or 4.4%, at A$4.56, having earlier set a record high of A$4.57.	Tab shares jumped 20 cents, or 4.6%, to set a record closing high at A$4.57.

The fields are defined as label, sid1, sid2, sent1, and sent2.
Save the datasets locally and pass their paths to the --tables parameter in the commands below.
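Before training, it can help to sanity-check a dataset by parsing it with Python's standard csv module using the five field names above. The helper read_rows below is hypothetical and not part of easytexminer; it only checks that every line has exactly five tab-separated columns.

```python
import csv
import io

# Column order defined above; the files have no header row.
FIELDS = ["label", "sid1", "sid2", "sent1", "sent2"]


def read_rows(f):
    """Yield one dict per line of a headerless, tab-separated file object."""
    # QUOTE_NONE: sentences contain literal double quotes ("the witness"),
    # which must not be treated as CSV quoting.
    reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
    for lineno, row in enumerate(reader, start=1):
        if len(row) != len(FIELDS):
            raise ValueError(f"line {lineno}: expected {len(FIELDS)} columns, got {len(row)}")
        yield dict(zip(FIELDS, row))


# Usage with the first sample row (a real file would be opened with
# open("train.tsv", newline="", encoding="utf-8")).
sample = "\t".join([
    "1", "702876", "702977",
    'Amrozi accused his brother, whom he called "the witness", of deliberately distorting his evidence.',
    'Referring to him as only "the witness", Amrozi accused his brother of deliberately distorting his evidence.',
]) + "\n"
rows = list(read_rows(io.StringIO(sample)))
print(rows[0]["label"], rows[0]["sid1"])  # 1 702876
```

A row that fails the check usually means a stray tab inside a sentence or a missing column, either of which would misalign the fields declared in --input_schema.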

Model Training

Note that the TextCNN model currently cannot be exported as a TensorFlow SavedModel, so the parameters must include:
--export_tf_checkpoint_type='none'

easytexminer \
  --mode=train \
  --model_name='text_classify_cnn' \
  --tables=train.tsv,dev.tsv \
  --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  --first_sequence=sent1 \
  --second_sequence=sent2 \
  --label_name=label \
  --label_enumerate_values=0,1 \
  --tokenizer=vocab.txt \
  --checkpoint_dir=./cnn_model/ \
  --learning_rate=1e-5 \
  --epoch_num=30 \
  --logging_steps=100 \
  --sequence_length=128 \
  --train_batch_size=32 \
  --conv_dim=100 \
  --kernel_sizes=1,2,3,4 \
  --linear_hidden_size=512 \
  --embed_size=300 \
  --export_tf_checkpoint_type='none'
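A common source of failed runs is a flag that refers to a column not declared in --input_schema. The sketch below is a hypothetical pre-flight check (parse_schema and check_train_flags are not easytexminer functions). It assumes each --input_schema entry has the form name:type:length, as the entries in the command above are written (e.g. label:str:1), and it encodes the SavedModel restriction stated above.

```python
def parse_schema(schema: str) -> dict[str, tuple[str, int]]:
    """Parse "name:type:length,..." into {name: (type, length)} (assumed format)."""
    columns = {}
    for entry in schema.split(","):
        name, dtype, length = entry.split(":")
        columns[name] = (dtype, int(length))
    return columns


def check_train_flags(flags: dict[str, str]) -> list[str]:
    """Return a list of problems found in a dict of training flags (empty if none)."""
    problems = []
    columns = parse_schema(flags["input_schema"])
    # The columns used as model inputs and labels must be declared in the schema.
    for key in ("first_sequence", "second_sequence", "label_name"):
        if key in flags and flags[key] not in columns:
            problems.append(f"--{key}={flags[key]} is not a column in --input_schema")
    # TextCNN cannot be exported as a TensorFlow SavedModel (see the note above).
    if flags.get("export_tf_checkpoint_type") != "none":
        problems.append("set --export_tf_checkpoint_type='none' for TextCNN")
    kernel_sizes = [int(k) for k in flags.get("kernel_sizes", "1,2,3,4").split(",")]
    if any(k <= 0 for k in kernel_sizes):
        problems.append("--kernel_sizes must be positive integers")
    return problems


# Usage: flag values from the training command above, after shell quote removal.
flags = {
    "input_schema": "label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1",
    "first_sequence": "sent1",
    "second_sequence": "sent2",
    "label_name": "label",
    "kernel_sizes": "1,2,3,4",
    "export_tf_checkpoint_type": "none",
}
print(check_train_flags(flags))  # []
```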

Model Evaluation

easytexminer \
  --mode=evaluate \
  --model_name='text_classify_cnn' \
  --tables=dev.tsv \
  --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  --first_sequence=sent1 \
  --second_sequence=sent2 \
  --label_name=label \
  --label_enumerate_values=0,1 \
  --checkpoint_path=./cnn_model/

Model Prediction

easytexminer \
  --mode=predict \
  --model_name='text_classify_cnn' \
  --tables=test.tsv \
  --outputs=test.pred.tsv \
  --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  --output_schema=predictions,probabilities,logits,output \
  --append_cols=label \
  --first_sequence=sent1 \
  --second_sequence=sent2 \
  --checkpoint_path=./cnn_model/ \
  --batch_size=32
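The prediction command requests the predictions, probabilities, logits, and output columns via --output_schema and copies the true label into the output via --append_cols=label. For a classifier, the usual relationship is probabilities = softmax(logits) and prediction = the label whose probability is highest, taken in the order of --label_enumerate_values. The sketch below assumes easytexminer follows this convention; the exact column layout of test.pred.tsv is not documented here, and softmax, predict_label, and accuracy are hypothetical helpers.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, label_values):
    """Return (label with the highest probability, probabilities)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return label_values[best], probs


def accuracy(pairs):
    """pairs: iterable of (predicted_label, true_label) strings."""
    pairs = list(pairs)
    return sum(p == t for p, t in pairs) / len(pairs) if pairs else 0.0


# Usage: label order follows --label_enumerate_values=0,1.
label_values = ["0", "1"]
pred, probs = predict_label([0.2, 1.5], label_values)
print(pred)  # 1
print(accuracy([("1", "1"), ("0", "1"), ("0", "0"), ("1", "1")]))  # 0.75
```

Because the label column is appended to each prediction row, pairing the predicted label with the appended label, as accuracy does, gives a quick check on test.tsv without rerunning evaluation.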