知识蒸馏简介

随着BERT等预训练语言模型在各项任务上都取得了STOA效果,BERT这类模型已经成为 NLP 深度迁移学习管道中的重要组成部分。但 BERT 并不是完美无瑕的,这类模型仍然存在以下两个问题:

  1. 模型参数量太大:BERT-base 模型能够包含一亿个参数,较大的 BERT-large 甚至包含 3.4 亿个参数。显然,很难将这种规模的模型部署到资源有限的环境(例如移动设备或嵌入式系统)当中。
  2. 训练/推理速度慢:在基于 Pod 配置的 4 个 Cloud TPUs(总共 16 个 TPU 芯片)上对 BERT-base 进行训练,或者在 16 个 Cloud TPU(总共 64 个 TPU 芯片)上对 BERT-large 进行训练,每次预训练都需要至少 4 天的时间才能完成。而BERT的推理速度更是严重影响到了需要较高QPS的线上场景,部署成本非常高。

而这个问题,不仅仅是在NLP领域存在,计算机视觉也同样存在,通常来讲有以下三种解决方案:

  1. 架构改进:将原有的架构改进为更小/更快的架构,例如,将 RNN 替换为 Transformer 或 CNN,ALBERT替代BERT等;使用需要较少计算的层等。当然也可以采用其他优化,例如从学习率和策略、预热步数,较大的批处理大小等;
  2. 模型压缩:通常使用量化和修剪来完成,从而能够在架构不变(或者大部分架构不变)的情况下减少计算总量;
  3. 知识蒸馏:训练一个小的模型,使得其在相应任务行为上能够逼近大的模型的效果,如DistillBERT,BERT-PKD,TinyBERT等

image.png
PAI-NLP团队,在知识蒸馏上有大量相关的积累,如我们提出的AdaBERT,利用可微神经网络架构搜索来自动将BERT蒸馏成适应不同特定任务的小型模型。在多个NLP任务上的结果表明这些任务适应型压缩模型在保证表现相当的同时,推理速度比BERT快10~30倍,同时参数缩小至BERT的十分之一。本框架重点研究,以及提供简单易用的工具,来进行BERT的知识蒸馏。

EasyDistill能够快速接入EasyTransfer-AppZoo/ google-bert / hugginface-transformer的模型,进行多种方法,多种维度的蒸馏,总体框架如下:
image.png

使用说明

支持的任务

当前知识蒸馏训练范式仅支持部分下游任务,包括:

  1. 文本分类 (Text Classification)
  2. 单塔文本匹配 (Single-Tower Text Match)

其他下游任务的支持有待后续更新。

主要流程

EasyTexMiner 知识蒸馏的主要流程为:

  1. 准备所需的数据集,并提前处理为 EasyTexMiner 支持的 tsv 格式(以制表符\t分隔的值表)。
  2. 选定大规模的预训练模型作为 Teacher Model,并依照其所属的下游任务进行 fine-tuning。详情参考本文档对应章节。
  3. 导出训练好的 Teacher Model 的 logits 到文件。
  4. 根据需求选定小规模的预训练模型作为 Student Model,并依照知识蒸馏范式进行 fine-tuning。
  5. 得到目标模型。

用户接口

在常规的下游任务 fine-tuning 命令的基础上,使用知识蒸馏范式需要在 input_schema 的末尾追加 logits 条目,并在 user_defined_parameters 中显式地启用,以键值对的形式传入所需的参数:

参数名 类型 可选值 描述
enable_distillation bool True/False 是否启用知识蒸馏
type str vanilla_kd(更多类型有待增加) 知识蒸馏的类型
logits_name str 应与 input_schema 中一致 Logits 字段在输入模式中的名称
logits_saved_path str tsv 文件相对/绝对路径 Teacher Model 导出的 logits 文件的路径
temperature float 大于等于 1,一般不超过 10 知识蒸馏的温度
alpha float [0, 1],一般不大于 0.5 Teacher Knowledge 在训练过程中的占比

具体的 CLI 命令示例如下:

  1. # SST-2 文本分类 知识蒸馏样例
  2. easytexminer \
  3. --app_name=text_classify \
  4. --mode=train \
  5. --worker_count=1 \
  6. --worker_gpu=1 \
  7. --tables=train.tsv,dev.tsv \
  8. --input_schema=sent:str:1,label:str:1,logits:float:2 \
  9. --first_sequence=sent \
  10. --label_name=label \
  11. --label_enumerate_values=0,1 \
  12. --checkpoint_dir=./results/small_sst2_student \
  13. --learning_rate=3e-5 \
  14. --epoch_num=1 \
  15. --random_seed=42 \
  16. --save_checkpoint_steps=200 \
  17. --sequence_length=128 \
  18. --micro_batch_size=32 \
  19. --user_defined_parameters="
  20. pretrain_model_name_or_path=${STUDENT_MODEL}
  21. enable_distillation=True
  22. type=vanilla_kd
  23. logits_name=logits
  24. logits_saved_path=${LOGITS_PATH}
  25. temperature=5
  26. alpha=0.2
  27. "

完整流程示例

本节以英文双句文本分类任务(MRPC)为例,给出完整的知识蒸馏流程命令示例。
可在此下载训练集验证集

为了快速测试,样例中使用了尽量精简的超参数设置(#epoch、batch size 等),需要根据实际场景调整。

定义所需环境变量

  1. # GPU device settings
  2. export WORKER_COUNT=1
  3. export WORKER_GPU=1
  4. # Models to be used
  5. export TEACHER_MODEL=bert-large-uncased
  6. export STUDENT_MODEL=bert-small-uncased
  7. # Path to save the fine-tuned models
  8. export TEACHER_CKPT=results/large-sst2-teacher
  9. export STUDENT_CKPT=results/small-sst2-student
  10. # Path to save the teacher logits
  11. export LOGITS_PATH=results/large-sst2-teacher/logits.tsv

Teacher Fine-tuning

  1. easytexminer \
  2. --app_name=text_classify \
  3. --mode=train \
  4. --worker_count=1 \
  5. --worker_gpu=1 \
  6. --tables=train.tsv,dev.tsv \
  7. --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
  8. --first_sequence=sent1 \
  9. --second_sequence=sent2 \
  10. --label_name=label \
  11. --label_enumerate_values=0,1 \
  12. --checkpoint_dir=${TEACHER_CKPT} \
  13. --learning_rate=3e-5 \
  14. --epoch_num=1 \
  15. --random_seed=42 \
  16. --save_checkpoint_steps=100 \
  17. --sequence_length=128 \
  18. --micro_batch_size=32 \
  19. --user_defined_parameters="pretrain_model_name_or_path=${TEACHER_MODEL}"

导出 Teacher Logits

通过 predict 模式导出 teacher model 对训练集的 logits。

easytexminer \
    --app_name=text_classify \
    --mode=predict \
    --worker_count=1 \
    --worker_gpu=1 \
    --tables=train.tsv \
    --outputs=${LOGITS_PATH} \
    --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
    --output_schema=logits \
    --first_sequence=sent1 \
    --second_sequence=sent2 \
    --checkpoint_path=${TEACHER_CKPT} \
    --micro_batch_size=32 \
    --sequence_length=128

Student 知识蒸馏

注意在 input_schema 中追加 logits 字段,类型为 float,数量与任务的标签数保持一致。

easytexminer \
    --app_name=text_classify \
    --mode=train \
    --worker_count=1 \
    --worker_gpu=1 \
    --tables=train.tsv,dev.tsv \
    --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1,logits:float:2 \
    --first_sequence=sent1 \
    --second_sequence=sent2 \
    --label_name=label \
    --label_enumerate_values=0,1 \
    --checkpoint_dir=${STUDENT_CKPT} \
    --learning_rate=3e-5 \
    --epoch_num=1 \
    --random_seed=42 \
    --save_checkpoint_steps=200 \
    --sequence_length=128 \
    --micro_batch_size=32 \
    --user_defined_parameters="
        pretrain_model_name_or_path=${STUDENT_MODEL}
        enable_distillation=True
        type=vanilla_kd
        logits_name=logits
        logits_saved_path=${LOGITS_PATH}
        temperature=5
        alpha=0.2
    "

Student 模型预测

easytexminer \
    --app_name=text_classify \
    --mode=predict \
    --worker_gpu=1 \
    --worker_count=1 \
    --tables=dev.tsv \
    --outputs=student_pred.tsv \
    --input_schema=label:str:1,sid1:str:1,sid2:str:1,sent1:str:1,sent2:str:1 \
    --output_schema=predictions \
    --first_sequence=sent1 \
    --second_sequence=sent2 \
    --checkpoint_path=${STUDENT_CKPT} \
    --micro_batch_size=32 \
    --sequence_length=128