GEEP模型推理加速简介

BERT是一种典型的预训练语言模型,由于其在下游自然语言处理(NLP)任务中的显著改进而受到广泛关注。复杂的结构和大量的参数给BERT带来了竞争性能,但也导致模型推理速度较慢。为了提高BERT推理的速度,FastBERT在知识蒸馏和早期退出技术的基础上实现了自适应推理,推理精度降低到可以接受的程度。然而,许多因素限制了FastBERT分类器的性能,如教师分类器知识不够,批大小的缩水和子分类器的冗余计算。为了克服这些局限性,我们提出了一种基于GPU-Efficient Exit Prediction (GEEP)的BERT推理方法。在12个公开的中英文自然语言处理数据集上的实验结果证明了该方法的有效性。

GEEP算法概述&训练流程

GEEP算法

GEEP利用Shared Exit Loss将FastBERT的训练过程从两步简化为一步,通过向教师分类器提供不同的Transformer输出,使教师分类器知识更丰富,从而指导出更好的学生分类器以实现Early Exit。
image.png
加速效果如下,通过调节geep_threshold参数(0~1之间,0不加速,1最大加速),大约可以获得1~12倍之间的加速,加速比越大,准确率下降也越大,但一般不超过15%。用户可根据自己的数据调试geep_threshold参数,取得加速与效果之间的平衡。
image.png

GEEP算法的微调

GEEP模型在fine-tune阶段使用与BERT相同,目前仅支持文本分类任务:

  1. export CUDA_VISIBLE_DEVICES=$1
  2. if [ ! -f ./train.tsv ]; then
  3. wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/classification/train.tsv
  4. fi
  5. if [ ! -f ./dev.tsv ]; then
  6. wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/classification/dev.tsv
  7. fi
  8. mode=$2
  9. if [ "$mode" = "train" ]; then
  10. easynlp \
  11. --mode $mode \
  12. --worker_gpu=1 \
  13. --tables=train.tsv,dev.tsv \
  14. --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  15. --first_sequence=sent1 \
  16. --second_sequence=sent2 \
  17. --label_name=label \
  18. --label_enumerate_values=0,1 \
  19. --checkpoint_dir=./classification_model \
  20. --learning_rate=3e-5 \
  21. --epoch_num=10 \
  22. --random_seed=42 \
  23. --save_checkpoint_steps=50 \
  24. --sequence_length=128 \
  25. --micro_batch_size=32 \
  26. --app_name=geep_classify \
  27. --user_defined_parameters='
  28. geep_exit_num=8
  29. pretrain_model_name_or_path=geep-base-uncased
  30. '
  31. elif [ "$mode" = "evaluate" ]; then
  32. easynlp \
  33. --mode=$mode \
  34. --worker_gpu=1 \
  35. --tables=dev.tsv \
  36. --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  37. --first_sequence=sent1 \
  38. --second_sequence=sent2 \
  39. --label_name=label \
  40. --label_enumerate_values=0,1 \
  41. --checkpoint_dir=./classification_model \
  42. --sequence_length=128 \
  43. --micro_batch_size=32 \
  44. --app_name=geep_classify \
  45. --user_defined_parameters='
  46. geep_threshold=0.3
  47. '
  48. elif [ "$mode" = "predict" ]; then
  49. easynlp \
  50. --mode=$mode \
  51. --worker_gpu=1 \
  52. --tables=dev.tsv \
  53. --outputs=dev.pred.tsv \
  54. --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  55. --output_schema=predictions,probabilities,logits,output \
  56. --append_cols=label \
  57. --first_sequence=sent1 \
  58. --second_sequence=sent2 \
  59. --checkpoint_path=./classification_model \
  60. --micro_batch_size=32 \
  61. --sequence_length=128 \
  62. --app_name=geep_classify \
  63. --user_defined_parameters='
  64. geep_threshold=0.3
  65. '
  66. fi