参考:
介绍:【经典简读】知识蒸馏(Knowledge Distillation) 经典之作
编码:知识蒸馏(Knowledge Distillation)的Pytorch实现以及分析
论文:论文《Distilling the Knowledge in a Neural Network》by Hinton, 2015

知识蒸馏的本质是模型压缩。通过结构复杂、计算量大但性能优异的teacher net,指导结构简单、计算量较小的student net,来提升student net的性能。因为现在主流的网络其实存在冗余,浅层网络其实也能模仿深层网络的行为,达到相当的精度,只是如果直接通过现有数据集进行训练,很难达到这样的效果(说明相对来说,网络已经够深,但是训练数据量还是不够大)。因此使用知识蒸馏,teacher net指导student net来达到这样的效果。

知识蒸馏 Knowledge Distillation - 图1
图像来源:Knowledge Distillation

主要步骤:

  • 训练一个结构复杂但性能优异的 teacher net(传统方式训练即可,eg.使用交叉熵损失函数)
  • 使用 teacher net 指导训练 student net。但是在 student net 的训练时使用的loss函数做如下改动: ```python

    student net训练过程

criterion1 = nn.CrossEntropyLoss() # 交叉熵损失函数 criterion2 = nn.KLDivLoss() # KL散度损失函数 …… for inputs, labels in train_dataloader: ……

  1. # Hard Label 硬标签指的是样本的真实标签值,即 labels
  2. # 使用传统的方法,即与 Hard Label 对应的交叉熵损失函数作为评估指标,得到loss1
  3. outputs = student_net(inputs)
  4. loss1 = criterion1(outputs, labels)
  5. # Soft Label 软标签指的是 teacher_net 的预测值,再做softmax(多分类)或sigmoid(二分类),
  6. # 得到软化的概率分别,即 Soft Label 软标签。
  7. # 引入温度参数 T ,teacher net 和 student net使用相同的温度参数,都通过softmax
  8. # 即把 teacher net 当作拟合目标
  9. # 将 Soft Label 对应的KL散度损失函数作为评估指标,得到loss2
  10. T = 2 # 温度
  11. outputs_S = F.log_softmax(outputs/T, dim=1) # student_net的输出做softmax
  12. soft_target = teacher_net(inputs)
  13. output_T = F.softmax(soft_target/T, dim=1) # teacher_net的输出做softmax
  14. loss2 = criterion2(outputs_S, outputs_T)* T * T
  15. # 引入参数 alpha, 将 Hard Label 和 Soft Label 对应的交叉熵损失取加权平均,得到最终的loss
  16. # alpha 代表 Soft Label 对应的KL散度损失函数的权重系数
  17. # 权重系数 alpha 越大,则表明迁移学习对 teacher net 贡献的依赖越大
  18. # 在 student net 训练初期,使用较大的 alpha 有助于让 student net 更轻松地鉴别简单样本
  19. # 但在 student net 训练后期,应适当减小 alpha ,
  20. # 即减小 Soft Label 的比重,让 Hard Label ,即真实标签帮助鉴别困难样本
  21. alpha = 0.95
  22. loss = loss1 * (1-alpha) + loss2 * alpha

```

个人想法:

  • 知识蒸馏不仅能够用于模型压缩,训练出更小更快、又能保证网络性能的网络;
  • 而且也可用于提升模型性能,因为压缩成较小的模型,可能去掉了大模型的冗余,防止过拟合,所以可能使得 student net 取得比 teacher net 更好的分类性能。