Distillation training basics

The general idea of distillation is to transfer the knowledge learned by one model to another, much as a teacher teaches a student. The first model is therefore called the teacher model and the second the student model. When the student model is smaller than the teacher model, distillation also serves as a model compression method. Hinton et al. proposed distillation in 2015; for details of the method, see the paper:
Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” arXiv preprint arXiv:1503.02531 (2015).

MNN distillation training example

Let's use quantization-aware distillation training of MobilenetV2 as an example of how to do distillation training in MNN. The relevant code is in distillTrainQuant.cpp under ``MNN_ROOT/tools/train/source/demo/``. In this demo, the teacher is the original float MobilenetV2 and the student is a quantization-aware copy of the same network.
Following the distillation algorithm, we take the logits that feed into the model's Softmax node, divide them by the temperature, and compute the distillation loss used for training.
Note that this demo requires MobilenetV2 as an MNN model.
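Dividing the logits by the temperature before Softmax is what softens the teacher's output. The following minimal standalone sketch uses plain `std::vector` rather than MNN's `_Softmax`, and the function name `softmaxWithTemperature` is hypothetical. It shows the effect of the temperature:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax of (logits / T) for T > 0 (hypothetical helper, not an MNN API).
// T = 1 gives the ordinary softmax; a larger T gives a softer distribution.
std::vector<float> softmaxWithTemperature(const std::vector<float>& logits, float temperature) {
    assert(temperature > 0.0f && !logits.empty());
    // Subtract the largest scaled logit for numerical stability;
    // this does not change the result.
    const float maxScaled = *std::max_element(logits.begin(), logits.end()) / temperature;
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] / temperature - maxScaled);
        sum += probs[i];
    }
    for (float& p : probs) {
        p /= sum;
    }
    return probs;
}
```

Take the logits {5, 2, -1} as an example. At T = 1, about 95% of the probability goes to the first class. At T = 20 the probabilities become roughly 0.38, 0.33 and 0.28, which reveals how the teacher ranks the wrong classes. The student learns from this extra information.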

  // distillTrainQuant.cpp
  ......
  // Read the teacher MNN model.
  auto varMap = Variable::loadMap(argv[1]);
  if (varMap.empty()) {
      MNN_ERROR("Can not load model %s\n", argv[1]);
      return 0;
  }
  ......
  // Get the inputs and outputs of the teacher model.
  auto inputOutputs = Variable::getInputAndOutput(varMap);
  auto inputs = Variable::mapToSequence(inputOutputs.first);
  MNN_ASSERT(inputs.size() == 1);
  // Input node of the teacher model; it must use the NC4HW4 layout.
  auto input = inputs[0];
  std::string inputName = input->name();
  auto inputInfo = input->getInfo();
  MNN_ASSERT(nullptr != inputInfo && inputInfo->order == NC4HW4);
  // Output node of the teacher model.
  auto outputs = Variable::mapToSequence(inputOutputs.second);
  std::string originOutputName = outputs[0]->name();
  // The node that feeds Softmax in the teacher model, i.e. the logits.
  std::string nodeBeforeSoftmax = "MobilenetV2/Predictions/Reshape";
  auto lastVar = varMap[nodeBeforeSoftmax];
  std::map<std::string, VARP> outputVarPair;
  outputVarPair[nodeBeforeSoftmax] = lastVar;
  // Use the logits as the output, so the extracted graph runs from the input node to the logits.
  auto logitsOutput = Variable::mapToSequence(outputVarPair);
  {
      // Run on the CPU backend with 4 threads.
      auto exe = Executor::getGlobalExecutor();
      BackendConfig config;
      exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
  }
  // Convert the original model (from the input to the logits) into a trainable float model: the student.
  std::shared_ptr<Module> model(PipelineModule::extract(inputs, logitsOutput, true));
  // Turn the student into a quantization-aware training model (bits is set in the elided code).
  PipelineModule::turnQuantize(model.get(), bits);
  // The original model (the teacher) is not trained; it only runs forward inference.
  std::shared_ptr<Module> originModel(PipelineModule::extract(inputs, logitsOutput, false));
  // Start training.
  _train(originModel, model, inputName, originOutputName);
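
`PipelineModule::turnQuantize` makes the student quantization-aware. Quantization-aware training generally keeps the weights in float but rounds values to the low-bit grid in the forward pass, so the network learns to tolerate the rounding. The sketch below shows a generic symmetric, per-tensor quantize-dequantize ("fake quantization") step. It is an illustrative assumption, not MNN's internal implementation, and `fakeQuantize` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Symmetric per-tensor fake quantization (hypothetical helper, not MNN code):
// map each value onto a signed `bits`-bit integer grid and straight back to float.
std::vector<float> fakeQuantize(const std::vector<float>& x, int bits) {
    assert(bits >= 2 && bits <= 16);
    // Largest representable integer, e.g. 127 for 8 bits.
    const float qmax = static_cast<float>((1 << (bits - 1)) - 1);
    float maxAbs = 0.0f;
    for (float v : x) {
        maxAbs = std::max(maxAbs, std::abs(v));
    }
    if (maxAbs == 0.0f) {
        return x;  // All zeros: nothing to quantize.
    }
    const float scale = maxAbs / qmax;  // float value of one integer step
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        // Round to the nearest grid point, clamp to [-qmax, qmax], then dequantize.
        float q = std::round(x[i] / scale);
        q = std::min(std::max(q, -qmax), qmax);
        out[i] = q * scale;
    }
    return out;
}
```

With 8 bits, each value moves by at most half a step (`scale / 2`). Rounding has no useful gradient, so real quantization-aware training also needs a backward rule such as the straight-through estimator. This forward-only sketch omits it.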

The code above shows how to obtain the logits output and convert the student into a quantization-aware training model. Next, let's look at the key parts of the training loop, where both models run and the distillation loss is computed.

  // One forward pass during training.
  // Convert the input data to NC4HW4, the layout MNN uses internally.
  auto nc4hw4example = _Convert(example, NC4HW4);
  // Forward pass of the teacher model: get the teacher's logits.
  auto teacherLogits = origin->forward(nc4hw4example);
  // Forward pass of the student model: get the student's logits.
  auto studentLogits = optmized->forward(nc4hw4example);
  // Compute the one-hot vector of the label. Labels are shifted by 1 to fit the
  // 1001-class output, whose index 0 is the extra "background" class of this MobilenetV2 model.
  auto labels = trainData[0].second[0];
  const int addToLabel = 1;
  auto newTarget = _OneHot(_Cast<int32_t>(_Squeeze(labels + _Scalar<int32_t>(addToLabel), {})),
                           _Scalar<int>(1001), _Scalar<float>(1.0f),
                           _Scalar<float>(0.0f));
  // Compute the loss from the teacher logits, the student logits and the true labels.
  // Temperature T = 20, weight of the soft-target loss alpha = 0.9.
  VARP loss = _DistillLoss(studentLogits, teacherLogits, newTarget, 20, 0.9);
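
Both models consume `nc4hw4example`, which is the batch converted by `_Convert(example, NC4HW4)`. NC4HW4 packs channels in groups of 4, giving the shape [N, ceil(C/4), H, W, 4], so that vectorized kernels can process 4 channels at once. This plain C++ sketch shows the index mapping from NCHW. It is for illustration only and the function name is hypothetical; inside MNN you would simply call `_Convert`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Repack an NCHW float tensor into NC4HW4 (hypothetical helper, not an MNN API).
// Channels are grouped in blocks of 4; the last block is zero-padded.
std::vector<float> nchwToNC4HW4(const std::vector<float>& src, int n, int c, int h, int w) {
    assert(static_cast<std::size_t>(n) * c * h * w == src.size());
    const int c4 = (c + 3) / 4;  // number of 4-channel blocks
    std::vector<float> dst(static_cast<std::size_t>(n) * c4 * h * w * 4, 0.0f);
    for (int b = 0; b < n; ++b) {
        for (int ch = 0; ch < c; ++ch) {
            const int block = ch / 4;  // which group of 4 channels
            const int lane = ch % 4;   // position inside the group
            for (int y = 0; y < h; ++y) {
                for (int x = 0; x < w; ++x) {
                    const std::size_t srcIdx =
                        ((static_cast<std::size_t>(b) * c + ch) * h + y) * w + x;
                    const std::size_t dstIdx =
                        (((static_cast<std::size_t>(b) * c4 + block) * h + y) * w + x) * 4 + lane;
                    dst[dstIdx] = src[srcIdx];
                }
            }
        }
    }
    return dst;
}
```

For a 1x3x1x2 input with channel values {0, 1}, {2, 3} and {4, 5}, the packed result is {0, 2, 4, 0, 1, 3, 5, 0}. Each pixel stores its channels next to each other, and the fourth slot is padding.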

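The label handling above is an ordinary one-hot encoding with an index offset. Assuming the dataset labels run from 0 to 999, they land on indices 1 to 1000 of the 1001-wide target. The helper below is a hypothetical plain C++ sketch of that mapping, not the MNN `_OneHot` API.

```cpp
#include <cassert>
#include <vector>

// One-hot rows for a batch of labels, each shifted by `offset` first
// (hypothetical helper mirroring _OneHot(labels + addToLabel, 1001, 1.0f, 0.0f)).
std::vector<std::vector<float>> oneHotWithOffset(const std::vector<int>& labels, int depth, int offset) {
    std::vector<std::vector<float>> rows;
    rows.reserve(labels.size());
    for (int label : labels) {
        const int index = label + offset;
        assert(index >= 0 && index < depth);
        std::vector<float> row(depth, 0.0f);  // "off" value everywhere
        row[index] = 1.0f;                    // "on" value at the shifted label
        rows.push_back(row);
    }
    return rows;
}
```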
Let's look at how the distillation loss is calculated. The code is in ``MNN_ROOT/tools/train/source/optimizer/Loss.cpp``.
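
In formula form, `_DistillLoss` computes the loss described by Hinton et al. The notation here is ours: $z$ are the student logits, $v$ the teacher logits, $y$ the one-hot label, $T$ the temperature and $\alpha$ the weight of the soft-target term. The KL term is written in the usual teacher-to-student direction. Which argument order MNN's `_KLDivergence` assumes is not shown in this code.

```latex
p_i = \frac{\exp(v_i / T)}{\sum_j \exp(v_j / T)}, \quad
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}, \quad
s_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}

\mathcal{L} = \alpha \, T^2 \sum_i p_i \log \frac{p_i}{q_i}
            + (1 - \alpha) \Big( -\sum_i y_i \log s_i \Big)
```

The $T^2$ factor compensates for the fact that the soft-target gradients shrink as $1/T^2$. As Hinton et al. point out, this keeps the two terms on a comparable scale when the temperature changes.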

  // Loss.cpp
  Express::VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets, const float temperature, const float alpha) {
      auto info = teacherLogits->getInfo();
      // Convert NC4HW4 logits back to NCHW before computing the loss.
      if (info->order == NC4HW4) {
          teacherLogits = _Convert(teacherLogits, NCHW);
          studentLogits = _Convert(studentLogits, NCHW);
      }
      // The logits must be 2-D, and all three inputs must have the same shape.
      MNN_ASSERT(studentLogits->getInfo()->dim.size() == 2);
      MNN_ASSERT(studentLogits->getInfo()->dim == teacherLogits->getInfo()->dim);
      MNN_ASSERT(studentLogits->getInfo()->dim == oneHotTargets->getInfo()->dim);
      MNN_ASSERT(alpha >= 0 && alpha <= 1);
      // Soft targets of the teacher model at temperature T.
      auto softTargets = _Softmax(teacherLogits * _Scalar(1 / temperature));
      // Predictions of the student model at temperature T.
      auto studentPredict = _Softmax(studentLogits * _Scalar(1 / temperature));
      // Soft-target loss, scaled by T^2.
      auto loss1 = _Scalar(temperature * temperature) * _KLDivergence(studentPredict, softTargets);
      // Hard-label loss: cross entropy of the student's ordinary (T = 1) predictions.
      auto loss2 = _CrossEntropy(_Softmax(studentLogits), oneHotTargets);
      // The total loss is a weighted sum of the two losses.
      auto loss = _Scalar(alpha) * loss1 + _Scalar(1 - alpha) * loss2;
      return loss;
  }
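
To check the arithmetic outside MNN, here is a single-sample re-implementation in plain C++. The names are hypothetical, it works in double precision, and it does no batching. It follows the structure of `_DistillLoss`, but it writes the KL term in the teacher-to-student direction of Hinton et al., which is an assumption about the convention of `_KLDivergence`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax of (logits / T).
static std::vector<double> softmaxT(const std::vector<double>& logits, double T) {
    const double maxScaled = *std::max_element(logits.begin(), logits.end()) / T;
    std::vector<double> p(logits.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        p[i] = std::exp(logits[i] / T - maxScaled);
        sum += p[i];
    }
    for (double& v : p) v /= sum;
    return p;
}

// Single-sample distillation loss (hypothetical sketch, not MNN code):
// alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(oneHot, student_at_T=1).
double distillLoss(const std::vector<double>& studentLogits,
                   const std::vector<double>& teacherLogits,
                   const std::vector<double>& oneHotTarget,
                   double temperature, double alpha) {
    assert(!studentLogits.empty());
    assert(studentLogits.size() == teacherLogits.size());
    assert(studentLogits.size() == oneHotTarget.size());
    assert(temperature > 0.0 && alpha >= 0.0 && alpha <= 1.0);

    const auto softTargets = softmaxT(teacherLogits, temperature);     // teacher at T
    const auto studentPredict = softmaxT(studentLogits, temperature);  // student at T
    const auto studentHard = softmaxT(studentLogits, 1.0);             // student at T = 1

    double kl = 0.0;  // soft-target term
    double ce = 0.0;  // hard-label term
    for (std::size_t i = 0; i < studentLogits.size(); ++i) {
        if (softTargets[i] > 0.0) {
            kl += softTargets[i] * std::log(softTargets[i] / studentPredict[i]);
        }
        if (oneHotTarget[i] > 0.0) {
            ce -= oneHotTarget[i] * std::log(studentHard[i]);
        }
    }
    return alpha * temperature * temperature * kl + (1.0 - alpha) * ce;
}
```

Two sanity checks follow directly from the formula. When the student's logits equal the teacher's, the KL term vanishes and only `(1 - alpha)` times the cross entropy remains. With `alpha = 0`, the teacher is ignored entirely.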