跨任务知识蒸馏简介

预训练语言模型的蒸馏往往只关注单一领域的知识,学生模型也只能从对应领域的教师模型中获取知识。知识蒸馏可以让学生模型从多个来自不同领域的教师或跨领域的教师中获取知识,进而帮助目标领域的学生模型训练。但这种方式可能会传递一些来自其他领域的非迁移性知识,这些知识与当前领域无关从而造成模型下降。跨任务知识蒸馏通过元学习的方法获取多个领域的可迁移性知识,提高教师模型在跨领域知识上的泛化性能以提高学生模型的性能。

跨任务知识蒸馏Meta-KD算法过程介绍

Meta-KD算法与现有跨任务知识蒸馏不同,借鉴了元学习的思想,首先在多个不同领域数据集上训练一个meta-teacher,获取多个领域的可迁移性知识。在这个meta-teacher的基础上,模型再蒸馏到基于特定任务的学生模型上,取得更好的效果。Meta-KD算法的算法思想如下图所示:
image.png
在算法实现中,首先基于不同领域的训练数据,训练meta-teacher。由于不同领域数据的可迁移性不同,我们对每个数据都采用基于Class Centroid的方法计算权重(即为下图的Prototype Score),表示这个数据对于其他各个领域的可迁移性。一般而言,领域特性越小的数据,权重越大。Meta-teacher在领域数据上进行带权重的混合训练。当meta-teacher训练完毕后,我们将这一模型蒸馏到某个特定领域的数据上,充分考虑了多种损失函数的组合。此外,由于meta-teacher不一定在所有领域数据上都具有良好的表现,在蒸馏过程中我们采用了domain-expertise weight衡量meta-teacher对于当前样本预测正确的置信度。Domain-expertise weight较高的样本在蒸馏过程中拥有更高的权重。
image.png
Meta-KD算法的细节可以参考论文Meta-KD: A Meta Knowledge Distillation Framework for Language Model Compression across Domains (ACL-IJCNLP 2021)[链接]

完整流程示例

环境准备

完整代码位于EasyNLP/examples/knowledge_distillation/metakd
下载示例数据集并划分:

  1. cd data
  2. if [ ! -f ./SENTI/dev.tsv ];then
  3. wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/datasets/domain_data/domain_sentiment_data.tar.gz
  4. tar -zxvf domain_sentiment_data.tar.gz
  5. fi
  6. cd ..
  7. if [ ! -f data/SENTI/dev.tsv ];then
  8. python generate_senti_data.py
  9. fi

预处理示例数据集

产生训练所需meta-weight并统一测试集格式:

  1. if [ ! -f data/SENTI/train.embeddings.tsv ];then
  2. python extract_embeddings.py \
  3. --bert_path ~/.easynlp/modelzoo/bert-base-uncased \
  4. --input data/SENTI/train.tsv \
  5. --output data/SENTI/train.embeddings.tsv \
  6. --task_name senti --gpu 7
  7. fi
  8. if [ ! -f data/SENTI/train_with_weights.tsv ];then
  9. python generate_meta_weights.py \
  10. data/SENTI/train.embeddings.tsv \
  11. data/SENTI/train_with_weights.tsv \
  12. books,dvd,electronics,kitchen
  13. fi
  14. if [ ! -f data/SENTI/dev_standard.tsv ];then
  15. python generate_dev_file.py \
  16. --input data/SENTI/dev.tsv \
  17. --output data/SENTI/dev_standard.tsv
  18. fi

训练meta-teacher

训练时需要指定use_sample_weight和use_domain_loss为Ture并设定domain_loss_weight的值。

  1. model=bert-base-uncased
  2. DISTRIBUTED_ARGS="--nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 6009"
  3. python -m torch.distributed.launch $DISTRIBUTED_ARGS meta_teacher_train.py \
  4. --mode train \
  5. --tables=data/SENTI/train_with_weights.tsv,data/SENTI/dev_standard.tsv \
  6. --input_schema=guid:str:1,text_a:str:1,text_b:str:1,label:str:1,domain:str:1,weight:str:1 \
  7. --first_sequence=text_a \
  8. --second_sequence=text_b \
  9. --label_name=label \
  10. --label_enumerate_values=positive,negative \
  11. --checkpoint_dir=./tmp/meta_teacher/ \
  12. --learning_rate=3e-5 \
  13. --epoch_num=1 \
  14. --random_seed=42 \
  15. --logging_steps=20 \
  16. --save_checkpoint_steps=50 \
  17. --sequence_length=128 \
  18. --micro_batch_size=16 \
  19. --app_name=text_classify \
  20. --user_defined_parameters="
  21. pretrain_model_name_or_path=$model
  22. use_sample_weights=True
  23. use_domain_loss=True
  24. domain_loss_weight=0.5
  25. "

蒸馏对应领域的学生模型

蒸馏对应两个阶段,第一阶段为拟合教师模型的中间层输出,第二阶段通过蒸馏损失函数训练学生模型。
第一阶段需要指定教师模型的保存路径teacher_model_path, 将distill_stage设置为first。此外,第一阶段蒸馏的checkpoint_dir将作为第二阶段蒸馏的模型输入pretrain_model_name_or_path
第二阶段同样需要制定教师模型的保存路径,将将distill_stage设置为second。同时确保pretrain_model_name_or_path为一阶段的模型保存位置。

  1. model=bert-tiny-uncased
  2. # In domain_sentiment_data, genre is one of ["books", "dvd", "electronics", "kitchen"]
  3. genre=books
  4. cd ${cur_path}
  5. # 1. Distillation pretrain
  6. DISTRIBUTED_ARGS="--nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 6009"
  7. # Pretrained distillation
  8. python -m torch.distributed.launch $DISTRIBUTED_ARGS meta_student_distill.py \
  9. --mode train \
  10. --tables=data/SENTI/train_with_weights.tsv,data/SENTI/dev_standard.tsv \
  11. --input_schema=guid:str:1,text_a:str:1,text_b:str:1,label:str:1,domain:str:1,weight:str:1 \
  12. --first_sequence=text_a \
  13. --second_sequence=text_b \
  14. --label_name=label \
  15. --label_enumerate_values=positive,negative \
  16. --checkpoint_dir=./tmp/$genre/meta_student_pretrain/ \
  17. --learning_rate=3e-5 \
  18. --epoch_num=10 \
  19. --random_seed=42 \
  20. --logging_steps=20 \
  21. --sequence_length=128 \
  22. --micro_batch_size=16 \
  23. --app_name=text_classify \
  24. --user_defined_parameters="
  25. pretrain_model_name_or_path=$model
  26. teacher_model_path=./tmp/meta_teacher/
  27. domain_loss_weight=0.5
  28. distill_stage=first
  29. genre=$genre
  30. T=2
  31. "
  32. # 2. Finetune
  33. pretrained_path="./tmp/$genre/meta_student_pretrain/"
  34. python -m torch.distributed.launch $DISTRIBUTED_ARGS meta_student_distill.py \
  35. --mode train \
  36. --tables=data/SENTI/train_with_weights.tsv,data/SENTI/dev_standard.tsv \
  37. --input_schema=guid:str:1,text_a:str:1,text_b:str:1,label:str:1,domain:str:1,weight:str:1 \
  38. --first_sequence=text_a \
  39. --second_sequence=text_b \
  40. --label_name=label \
  41. --label_enumerate_values=positive,negative \
  42. --checkpoint_dir=./tmp/$genre/meta_student_fintune/ \
  43. --learning_rate=3e-5 \
  44. --epoch_num=10 \
  45. --random_seed=42 \
  46. --logging_steps=20 \
  47. --save_checkpoint_steps=50 \
  48. --sequence_length=128 \
  49. --micro_batch_size=16 \
  50. --app_name=text_classify \
  51. --user_defined_parameters="
  52. pretrain_model_name_or_path=$pretrained_path
  53. teacher_model_path=./tmp/meta_teacher/
  54. domain_loss_weight=0.5
  55. distill_stage=second
  56. genre=$genre
  57. T=2
  58. "
  59. # 3. Evalute
  60. Student_model_path=./tmp/$genre/meta_student_fintune/
  61. python main_evaluate.py \
  62. --mode evaluate \
  63. --tables=data/SENTI/train_with_weights.tsv,data/SENTI/dev_standard.tsv \
  64. --input_schema=guid:str:1,text_a:str:1,text_b:str:1,label:str:1,domain:str:1,weight:str:1 \
  65. --first_sequence=text_a \
  66. --label_name=label \
  67. --label_enumerate_values=positive,negative \
  68. --checkpoint_dir=./tmp/meta_teacher/ \
  69. --learning_rate=3e-5 \
  70. --epoch_num=1 \
  71. --random_seed=42 \
  72. --logging_steps=20 \
  73. --sequence_length=128 \
  74. --micro_batch_size=16 \
  75. --app_name=text_classify \
  76. --user_defined_parameters="pretrain_model_name_or_path=$Student_model_path
  77. genre=$genre"

预测时请确保测试集的格式与训练集文件train_with_weights.tsv一致。