蒸馏训练原理
蒸馏(Distillation)总体思想是将一个模型所学到的知识蒸馏转移到另外一个模型上,就像教师教学生一样,因此前一个模型常被称为教师模型,后面一个模型常被称为学生模型。如果学生模型比教师模型小,那么蒸馏也成为一种模型压缩方法。Hinton在2015年提出了蒸馏这一思想,具体做法可以参考论文:
Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” arXiv preprint arXiv:1503.02531 (2015).
MNN蒸馏训练示例
以MobilenetV2的蒸馏训练量化为例,我们来看一下MNN中怎样做蒸馏训练,相关代码在 MNN_ROOT/tools/train/source/demo/distillTrainQuant.cpp中。
根据蒸馏算法,我们需要取出输入到模型Softmax节点的logits,然后加上温度参数,最后计算蒸馏loss进行训练
注意,此demo中需要MobilenetV2的MNN模型
// distillTrainQuant.cpp......// 读取教师MNN模型auto varMap = Variable::loadMap(argv[1]);if (varMap.empty()) {MNN_ERROR("Can not load model %s\n", argv[1]);return 0;}......// 获取教师模型的输入输出节点auto inputOutputs = Variable::getInputAndOutput(varMap);auto inputs = Variable::mapToSequence(inputOutputs.first);MNN_ASSERT(inputs.size() == 1);// 教师模型的输入节点及其名称auto input = inputs[0];std::string inputName = input->name();auto inputInfo = input->getInfo();MNN_ASSERT(nullptr != inputInfo && inputInfo->order == NC4HW4);// 教师模型的输出节点及其名称auto outputs = Variable::mapToSequence(inputOutputs.second);std::string originOutputName = outputs[0]->name();// 教师模型输入到Softmax之前的节点,即logitsstd::string nodeBeforeSoftmax = "MobilenetV2/Predictions/Reshape";auto lastVar = varMap[nodeBeforeSoftmax];std::map<std::string, VARP> outputVarPair;outputVarPair[nodeBeforeSoftmax] = lastVar;// 抽取出从输入一直到logits输出的模型部分auto logitsOutput = Variable::mapToSequence(outputVarPair);{auto exe = Executor::getGlobalExecutor();BackendConfig config;exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);}// 将原始模型输入到logits部分转换为float可训练模型std::shared_ptr<Module> model(PipelineModule::extract(inputs, logitsOutput, true));// 将上述模型转换为训练量化模型PipelineModule::turnQuantize(model.get(), bits);// 原始模型不会进行训练,只会进行前向推理std::shared_ptr<Module> originModel(PipelineModule::extract(inputs, logitsOutput, false));// 进入训练环节_train(originModel, model, inputName, originOutputName);
OK,上面演示了如何获取logits输出,并将模型转换成训练量化模型,下面看一下训练工程中实现量化的关键部分代码
// 训练中的一次前向计算过程// 将输入数据转换成MNN内部的NC4HW4格式auto nc4hw4example = _Convert(example, NC4HW4);// 教师模型前向,得到教师模型的logits输出auto teacherLogits = origin->forward(nc4hw4example);// 学生模型前向,得到学生模型的logits输出auto studentLogits = optmized->forward(nc4hw4example);// 计算label的One-Hot向量auto labels = trainData[0].second[0];const int addToLabel = 1;auto newTarget = _OneHot(_Cast<int32_t>(_Squeeze(labels + _Scalar<int32_t>(addToLabel), {})),_Scalar<int>(1001), _Scalar<float>(1.0f),_Scalar<float>(0.0f));// 使用教师模型和学生模型的logits,以及真实label,来计算蒸馏loss// 温度T = 20,softTargets loss系数 0.9VARP loss = _DistillLoss(studentLogits, teacherLogits, newTarget, 20, 0.9);
下面来看一下,蒸馏loss是如何计算的,代码在 MNN_ROOT/tools/train/source/optimizer/Loss.cpp中
// Loss.cppExpress::VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets, const float temperature, const float alpha) {auto info = teacherLogits->getInfo();if (info->order == NC4HW4) {teacherLogits = _Convert(teacherLogits, NCHW);studentLogits = _Convert(studentLogits, NCHW);}MNN_ASSERT(studentLogits->getInfo()->dim.size() == 2);MNN_ASSERT(studentLogits->getInfo()->dim == teacherLogits->getInfo()->dim);MNN_ASSERT(studentLogits->getInfo()->dim == oneHotTargets->getInfo()->dim);MNN_ASSERT(alpha >= 0 && alpha <= 1);// 计算考虑温度之后的教师模型softTargets,学生模型的预测输出auto softTargets = _Softmax(teacherLogits * _Scalar(1 / temperature));auto studentPredict = _Softmax(studentLogits * _Scalar(1 / temperature));// 计算softTargets产生的lossauto loss1 = _Scalar(temperature * temperature) * _KLDivergence(studentPredict, softTargets);// 计算label产生的lossauto loss2 = _CrossEntropy(_Softmax(studentLogits), oneHotTargets);// 总的loss为两者加权之和auto loss = _Scalar(alpha) * loss1 + _Scalar(1 - alpha) * loss2;return loss;}
