大模型在小样本数据上取得了不错的效果,但在很多实际场景中数据量不足的问题仍然制约着大模型的应用,如何提高预训练模型在小样本场景的泛化性还是个挑战。其次,大模型参数量太大导致训练和推理速度慢,严重影响到了需要较高QPS的线上场景,部署成本非常高,如何快速蒸馏出小模型也是个挑战。EasyNLP推出小样本学习功能,帮助用户在小样本场景快速训练一个效果好的模型落地。同时,EasyNLP支持知识蒸馏技术,可以将大模型压缩到小的高效的模型上线。
下面我们给出一个示例,将一个大的预训练模型(hfl/macbert-large-zh)在小样本场景上落地,并且蒸馏到小的模型上。如下图所示,一个大模型(3亿参数)在一个小样本场景上原始的Accuracy为83.8%,通过小样本学习可以提升7%,达到90.6%。同时,如果用一个小模型(3百万参数)跑这个场景的话,效果仅有54.4%,可以把效果提升到75%(提升约21%),inference的时间相比大模型提升了约80倍。

模型 参数量 Dev Set指标(Accuracy) Batch Inference时间
标准Finetune hfl/macbert-large-zh 325 Million 83.75% 3.22ms/sample
(batch_size=8)
标准Finetune alibaba-pai/pai-bert-tiny-zh 3 Million 54.38% 0.04ms/sample
(batch_size=64)
知识蒸馏Finetune alibaba-pai/pai-bert-tiny-zh 3 Million 75.21% 0.04ms/sample
(batch_size=64)
小样本Finetune hfl/macbert-large-zh 325 Million 90.63% 3.21ms/sample
(batch_size=8)

详细代码示例如下。

代码示例

数据准备

  1. wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/landing_plm/train.csv
  2. wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/landing_plm/dev.csv

小样本学习测试脚本

  1. easynlp \
  2. --app_name=text_classify \
  3. --mode=train \
  4. --worker_count=1 \
  5. --worker_gpu=1 \
  6. --tables=train.csv,dev.csv \
  7. --input_schema=text:str:1,label:str:1 \
  8. --first_sequence=text \
  9. --label_name=label \
  10. --label_enumerate_values=Positive,Negative \
  11. --checkpoint_dir=./fewshot_model/ \
  12. --learning_rate=1e-5 \
  13. --epoch_num=5 \
  14. --random_seed=42 \
  15. --save_checkpoint_steps=100 \
  16. --sequence_length=512 \
  17. --micro_batch_size=8 \
  18. --user_defined_parameters="
  19. pretrain_model_name_or_path=hfl/macbert-large-zh
  20. enable_fewshot=True
  21. label_desc=好,差
  22. type=pet_fewshot
  23. pattern=text,是一条商品,label,评。
  24. "

知识蒸馏测试脚本

  1. # train teacher
  2. easynlp \
  3. --app_name=text_classify \
  4. --mode=train \
  5. --worker_count=1 \
  6. --worker_gpu=1 \
  7. --tables=train.csv,dev.csv \
  8. --input_schema=text:str:1,label:str:1 \
  9. --first_sequence=text \
  10. --label_name=label \
  11. --label_enumerate_values=Positive,Negative \
  12. --checkpoint_dir=./teacher_model/ \
  13. --learning_rate=1e-5 \
  14. --epoch_num=5 \
  15. --random_seed=42 \
  16. --save_checkpoint_steps=100 \
  17. --sequence_length=128 \
  18. --micro_batch_size=8 \
  19. --user_defined_parameters="
  20. pretrain_model_name_or_path=hfl/macbert-large-zh
  21. "
  22. # data augmentation
  23. easynlp \
  24. --app_name=data_augmentation \
  25. --worker_count=1 \
  26. --worker_gpu=1 \
  27. --mode=predict \
  28. --tables=train.csv \
  29. --input_schema=text:str:1,label:str:1 \
  30. --first_sequence=text \
  31. --label_name=label \
  32. --outputs=aug.csv \
  33. --output_schema=augmented_data \
  34. --checkpoint_dir=_ \
  35. --sequence_length=128 \
  36. --micro_batch_size=8 \
  37. --user_defined_parameters="
  38. pretrain_model_name_or_path=hfl/macbert-large-zh
  39. type=mlm_da
  40. expansion_rate=10
  41. mask_proportion=0.25
  42. remove_blanks=True
  43. "
  44. # forward teacher logits
  45. easynlp \
  46. --mode=predict \
  47. --worker_count=1 \
  48. --worker_gpu=1 \
  49. --tables=aug.csv \
  50. --outputs=logits.csv \
  51. --input_schema=text:str:1,label:str:1 \
  52. --output_schema=logits \
  53. --first_sequence=text \
  54. --checkpoint_path=./teacher_model/ \
  55. --micro_batch_size=8 \
  56. --sequence_length=128 \
  57. --app_name=text_classify
  58. # train student w/ KD
  59. easynlp \
  60. --app_name=text_classify \
  61. --mode=train \
  62. --worker_count=1 \
  63. --worker_gpu=1 \
  64. --tables=aug.csv,dev.csv \
  65. --input_schema=text:str:1,label:str:1,logits:float:2 \
  66. --first_sequence=text \
  67. --label_name=label \
  68. --label_enumerate_values=Positive,Negative \
  69. --checkpoint_dir=./student_model/ \
  70. --learning_rate=1e-4 \
  71. --epoch_num=5 \
  72. --random_seed=42 \
  73. --save_checkpoint_steps=100 \
  74. --sequence_length=128 \
  75. --micro_batch_size=8 \
  76. --user_defined_parameters="
  77. pretrain_model_name_or_path=alibaba-pai/pai-bert-tiny-zh
  78. enable_distillation=True
  79. type=vanilla_kd
  80. logits_name=logits
  81. logits_saved_path=logits.csv
  82. temperature=1
  83. alpha=0.5
  84. "
  85. # train student w/o. KD
  86. easynlp \
  87. --app_name=text_classify \
  88. --mode=train \
  89. --worker_count=1 \
  90. --worker_gpu=1 \
  91. --tables=train.csv,dev.csv \
  92. --input_schema=text:str:1,label:str:1 \
  93. --first_sequence=text \
  94. --label_name=label \
  95. --label_enumerate_values=Positive,Negative \
  96. --checkpoint_dir=./small_model_2/ \
  97. --learning_rate=1e-4 \
  98. --epoch_num=5 \
  99. --random_seed=42 \
  100. --save_checkpoint_steps=100 \
  101. --sequence_length=128 \
  102. --micro_batch_size=8 \
  103. --user_defined_parameters="
  104. pretrain_model_name_or_path=alibaba-pai/pai-bert-tiny-zh
  105. "