蒸馏训练原理

蒸馏(Distillation)总体思想是将一个模型所学到的知识蒸馏转移到另外一个模型上,就像教师教学生一样,因此前一个模型常被称为教师模型,后面一个模型常被称为学生模型。如果学生模型比教师模型小,那么蒸馏也成为一种模型压缩方法。Hinton在2015年提出了蒸馏这一思想,具体做法可以参考论文:
Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” arXiv preprint arXiv:1503.02531 (2015).

MNN蒸馏训练示例

以MobilenetV2的蒸馏训练量化为例,我们来看一下MNN中怎样做蒸馏训练,相关代码在 MNN_ROOT/tools/train/source/demo/distillTrainQuant.cpp中。
根据蒸馏算法,我们需要取出输入到模型Softmax节点的logits,然后加上温度参数,最后计算蒸馏loss进行训练
注意,此demo中需要MobilenetV2的MNN模型

  1. // distillTrainQuant.cpp
  2. ......
  3. // 读取教师MNN模型
  4. auto varMap = Variable::loadMap(argv[1]);
  5. if (varMap.empty()) {
  6. MNN_ERROR("Can not load model %s\n", argv[1]);
  7. return 0;
  8. }
  9. ......
  10. // 获取教师模型的输入输出节点
  11. auto inputOutputs = Variable::getInputAndOutput(varMap);
  12. auto inputs = Variable::mapToSequence(inputOutputs.first);
  13. MNN_ASSERT(inputs.size() == 1);
  14. // 教师模型的输入节点及其名称
  15. auto input = inputs[0];
  16. std::string inputName = input->name();
  17. auto inputInfo = input->getInfo();
  18. MNN_ASSERT(nullptr != inputInfo && inputInfo->order == NC4HW4);
  19. // 教师模型的输出节点及其名称
  20. auto outputs = Variable::mapToSequence(inputOutputs.second);
  21. std::string originOutputName = outputs[0]->name();
  22. // 教师模型输入到Softmax之前的节点,即logits
  23. std::string nodeBeforeSoftmax = "MobilenetV2/Predictions/Reshape";
  24. auto lastVar = varMap[nodeBeforeSoftmax];
  25. std::map<std::string, VARP> outputVarPair;
  26. outputVarPair[nodeBeforeSoftmax] = lastVar;
  27. // 抽取出从输入一直到logits输出的模型部分
  28. auto logitsOutput = Variable::mapToSequence(outputVarPair);
  29. {
  30. auto exe = Executor::getGlobalExecutor();
  31. BackendConfig config;
  32. exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
  33. }
  34. // 将原始模型输入到logits部分转换为float可训练模型
  35. std::shared_ptr<Module> model(PipelineModule::extract(inputs, logitsOutput, true));
  36. // 将上述模型转换为训练量化模型
  37. PipelineModule::turnQuantize(model.get(), bits);
  38. // 原始模型不会进行训练,只会进行前向推理
  39. std::shared_ptr<Module> originModel(PipelineModule::extract(inputs, logitsOutput, false));
  40. // 进入训练环节
  41. _train(originModel, model, inputName, originOutputName);

OK,上面演示了如何获取logits输出,并将模型转换成训练量化模型,下面看一下训练工程中实现量化的关键部分代码

  1. // 训练中的一次前向计算过程
  2. // 将输入数据转换成MNN内部的NC4HW4格式
  3. auto nc4hw4example = _Convert(example, NC4HW4);
  4. // 教师模型前向,得到教师模型的logits输出
  5. auto teacherLogits = origin->forward(nc4hw4example);
  6. // 学生模型前向,得到学生模型的logits输出
  7. auto studentLogits = optmized->forward(nc4hw4example);
  8. // 计算label的One-Hot向量
  9. auto labels = trainData[0].second[0];
  10. const int addToLabel = 1;
  11. auto newTarget = _OneHot(_Cast<int32_t>(_Squeeze(labels + _Scalar<int32_t>(addToLabel), {})),
  12. _Scalar<int>(1001), _Scalar<float>(1.0f),
  13. _Scalar<float>(0.0f));
  14. // 使用教师模型和学生模型的logits,以及真实label,来计算蒸馏loss
  15. // 温度T = 20,softTargets loss系数 0.9
  16. VARP loss = _DistillLoss(studentLogits, teacherLogits, newTarget, 20, 0.9);

下面来看一下,蒸馏loss是如何计算的,代码在 MNN_ROOT/tools/train/source/optimizer/Loss.cpp中

  1. // Loss.cpp
  2. Express::VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets, const float temperature, const float alpha) {
  3. auto info = teacherLogits->getInfo();
  4. if (info->order == NC4HW4) {
  5. teacherLogits = _Convert(teacherLogits, NCHW);
  6. studentLogits = _Convert(studentLogits, NCHW);
  7. }
  8. MNN_ASSERT(studentLogits->getInfo()->dim.size() == 2);
  9. MNN_ASSERT(studentLogits->getInfo()->dim == teacherLogits->getInfo()->dim);
  10. MNN_ASSERT(studentLogits->getInfo()->dim == oneHotTargets->getInfo()->dim);
  11. MNN_ASSERT(alpha >= 0 && alpha <= 1);
  12. // 计算考虑温度之后的教师模型softTargets,学生模型的预测输出
  13. auto softTargets = _Softmax(teacherLogits * _Scalar(1 / temperature));
  14. auto studentPredict = _Softmax(studentLogits * _Scalar(1 / temperature));
  15. // 计算softTargets产生的loss
  16. auto loss1 = _Scalar(temperature * temperature) * _KLDivergence(studentPredict, softTargets);
  17. // 计算label产生的loss
  18. auto loss2 = _CrossEntropy(_Softmax(studentLogits), oneHotTargets);
  19. // 总的loss为两者加权之和
  20. auto loss = _Scalar(alpha) * loss1 + _Scalar(1 - alpha) * loss2;
  21. return loss;
  22. }