BERT序列标注

BERT序列标注是把序列中每个token的标注都视为一个独立的多分类问题,如下图所示
![BERT序列标注示意图](image.png)
我们采用了Google原论文中的序列标注方法,即把word-piece切分后的第一个subtoken作为token级分类器的输入,用其计算loss并反向传播;其余subtokens的输出均会被mask掉,不参与loss计算

We use the representation of the first sub-token as the input to the token-level classifier over the NER label set


调用命令

  #!/usr/bin/env bash
  # EasyNLP sequence-labeling (NER) tutorial driver.
  # Usage: <script> <gpu_ids> <train|evaluate|predict>
  #   $1 - value for CUDA_VISIBLE_DEVICES (e.g. "0" or "0,1")
  #   $2 - run mode
  set -euo pipefail

  export CUDA_VISIBLE_DEVICES="$1"
  mode="$2"

  # Fetch the tutorial datasets once; files already present are kept as-is.
  # Under `set -e` a failed download aborts instead of being silently ignored.
  base_url=http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/sequence_labeling
  for f in train.csv dev.csv test.csv; do
    if [ ! -f "./$f" ]; then
      wget "$base_url/$f"
    fi
  done

  case "$mode" in
    train)
      easynlp \
        --mode=train \
        --worker_gpu=1 \
        --tables=train.csv,dev.csv \
        --input_schema=content:str:1,label:str:1 \
        --first_sequence=content \
        --label_name=label \
        --label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O \
        --checkpoint_dir=./labeling_model \
        --learning_rate=1e-4 \
        --epoch_num=1 \
        --logging_steps=100 \
        --save_checkpoint_steps=100 \
        --sequence_length=128 \
        --micro_batch_size=64 \
        --app_name=sequence_labeling \
        --user_defined_parameters='
        pretrain_model_name_or_path=hfl/chinese-roberta-wwm-ext
        '
      ;;
    evaluate)
      easynlp \
        --mode=evaluate \
        --worker_gpu=1 \
        --tables=dev.csv \
        --input_schema=content:str:1,label:str:1 \
        --first_sequence=content \
        --label_name=label \
        --label_enumerate_values=B-LOC,B-ORG,B-PER,I-LOC,I-ORG,I-PER,O \
        --checkpoint_path=./labeling_model \
        --sequence_length=128 \
        --micro_batch_size=32 \
        --app_name=sequence_labeling
      ;;
    predict)
      easynlp \
        --mode=predict \
        --worker_gpu=1 \
        --tables=test.csv \
        --outputs=test.pred.csv \
        --input_schema=content:str:1,label:str:1 \
        --first_sequence=content \
        --sequence_length=128 \
        --output_schema=output \
        --append_cols=label \
        --checkpoint_path=./labeling_model \
        --micro_batch_size=32 \
        --app_name=sequence_labeling
      ;;
    *)
      # Originally an unknown mode fell through silently; fail loudly instead.
      printf 'unknown mode: %s (expected train|evaluate|predict)\n' "$mode" >&2
      exit 2
      ;;
  esac