1 Obtain a trainable model

First, we need to obtain a trainable model. There are two ways to do this in MNN.

First, convert a model trained in another framework, such as TensorFlow or PyTorch, into a trainable model in MNN format. This is done with the MNNConvert tool. Typical application scenarios are: 1. using MNN to finetune a pre-trained model; 2. using MNN for quantization-aware training. You need to pass the --forTraining option to MNNConvert during model conversion to retain the operators used in training, such as BatchNorm and Dropout.

Second, construct a model from scratch with MNN and train it with MNN, which skips the model-conversion step and also makes it easy to switch to quantization-aware training. In the MNN_ROOT/tools/train/source/models/ directory, we provide Lenet, MobilenetV1, MobilenetV2, and other models built with the MNN training framework.

1.1 Convert other framework models to MNN trainable models

Take quantization-aware training of MobilenetV2 as an example. First, download the MobilenetV2 model officially provided by TensorFlow, then compile MNNConvert and run the following command to convert it:

  ./MNNConvert --modelFile mobilenet_v2_1.0_224_frozen.pb --MNNModel mobilenet_v2_tfpb_train.mnn --framework TF --bizCode AliNNTest --forTraining

Note that the model used in this command still contains the BatchNorm operators; they have not been fused. By passing the --forTraining option during conversion, we keep the BatchNorm operators in the converted MNN model.

If your model does not contain operators such as BatchNorm or Dropout that would be fused away during conversion to an MNN model, you do not need to convert it again; in that case you can use the MNN inference model for training directly.

Next, let’s follow the example in MNN_ROOT/tools/train/source/demo/mobilenetV2Train.cpp, which reads the converted model and turns it into an MNN training model. The key code is as follows:

  // mobilenetV2Train.cpp
  // Read the converted MNN model
  auto varMap = Variable::loadMap(argv[1]);
  // Specify the bit width of the quantization
  int bits = 8;
  // Get inputs and outputs
  auto inputOutputs = Variable::getInputAndOutput(varMap);
  auto inputs  = Variable::mapToSequence(inputOutputs.first);
  auto outputs = Variable::mapToSequence(inputOutputs.second);
  // Convert the MNN model into a trainable model
  // (extract Convolution, BatchNorm, Dropout from the inference model and turn them into trainable modules)
  std::shared_ptr<Module> model(PipelineModule::extract(inputs, outputs, true));
  // Convert the trainable model into a quantization-aware-training model.
  // If you don't want to do quantization-aware training, this step can be skipped.
  ((PipelineModule*)model.get())->toTrainQuant(bits);
  // Train the model.
  MobilenetV2Utils::train(model, 1001, 1, trainImagesFolder, trainImagesTxt, testImagesFolder, testImagesTxt);

1.2 Use MNN to construct models from scratch

Taking Lenet as an example, let’s look at how to build a model from scratch with MNN. MNN provides a wide range of operators to use; the following example does not cover them in detail. It is worth noting that the output of the pooling layer is in NC4HW4 format and needs to be converted to NCHW format before entering the fully connected layer for computation.

  class MNN_PUBLIC Lenet : public Module {
  public:
      Lenet();
      virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override;
      std::shared_ptr<Module> conv1;
      std::shared_ptr<Module> conv2;
      std::shared_ptr<Module> ip1;
      std::shared_ptr<Module> ip2;
      std::shared_ptr<Module> dropout;
  };
  // Initialization
  Lenet::Lenet() {
      NN::ConvOption convOption;
      convOption.kernelSize = {5, 5};
      convOption.channel    = {1, 20};
      conv1.reset(NN::Conv(convOption));
      convOption.reset();
      convOption.kernelSize = {5, 5};
      convOption.channel    = {20, 50};
      conv2.reset(NN::Conv(convOption));
      ip1.reset(NN::Linear(800, 500));
      ip2.reset(NN::Linear(500, 10));
      dropout.reset(NN::Dropout(0.5));
      // You must register the parameters for them to be updated in backprop.
      registerModel({conv1, conv2, ip1, ip2, dropout});
  }
  // Forward pass.
  std::vector<Express::VARP> Lenet::onForward(const std::vector<Express::VARP>& inputs) {
      using namespace Express;
      VARP x = inputs[0];
      x = conv1->forward(x);
      x = _MaxPool(x, {2, 2}, {2, 2});
      x = conv2->forward(x);
      x = _MaxPool(x, {2, 2}, {2, 2});
      // The output of the pooling layer is NC4HW4 and needs to be converted to
      // NCHW before computation in the FC layer
      x = _Convert(x, NCHW);
      x = _Reshape(x, {0, -1});
      x = ip1->forward(x);
      x = _Relu(x);
      x = dropout->forward(x);
      x = ip2->forward(x);
      x = _Softmax(x, 1);
      return {x};
  }
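
Once the class is defined, it can be used like any other Module. The following is a minimal, illustrative sketch of a single forward pass; it assumes the same headers and namespaces as the Lenet code above, and the 1x1x28x28 input shape simply mirrors the MNIST input used later in this tutorial:

  // Illustrative usage sketch (not part of the MNN demos).
  std::shared_ptr<Module> net(new Lenet);
  net->setIsTraining(false);
  // Build a 1x1x28x28 input, matching the MNIST input shape used later in this tutorial.
  auto input = _Input({1, 1, 28, 28}, NC4HW4);
  float* inputPtr = input->writeMap<float>();
  // ... fill inputPtr with the 28x28 image pixels ...
  auto prob = net->forward(input);
  const float* probPtr = prob->readMap<float>(); // 10 class probabilities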

2 Implement the data set interface

See Load Training Data for a detailed description.
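
As a quick orientation, here is a minimal, hypothetical dataset sketch. It assumes the Dataset interface used by MnistDataset in the next section (a get(index) method returning an Example, i.e. a pair of input and target VARP vectors, plus a size() method); the class name, members, and include paths below are illustrative, so check the datasets under MNN_ROOT/tools/train/source/ for the exact signatures.

  // Hypothetical in-memory dataset; names and include paths are illustrative.
  #include <vector>
  #include "Dataset.hpp" // MNN training Dataset/Example definitions
  using namespace MNN::Train;
  using namespace MNN::Express;

  class MyDataset : public Dataset {
  public:
      // Return one sample: {inputs}, {targets}
      Example get(size_t index) override {
          auto image = _Const(mImages[index].data(), {1, 28, 28}, NCHW);
          auto label = _Const(&mLabels[index], {}, NCHW, halide_type_of<int32_t>());
          return {{image}, {label}};
      }
      size_t size() override {
          return mLabels.size();
      }

  private:
      std::vector<std::vector<float>> mImages; // filled by the constructor (not shown)
      std::vector<int32_t> mLabels;
  };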

3 Train and save the model

Take MNIST model training as an example; the code is in MNN_ROOT/tools/train/source/demo/MnistUtils.cpp.

  // MnistUtils.cpp
  ......
  void MnistUtils::train(std::shared_ptr<Module> model, std::string root) {
      {
          // Load snapshot
          // Model structure + model params
          auto para = Variable::load("mnist.snapshot.mnn");
          model->loadParameters(para);
      }
      // Configure training framework params.
      auto exe = Executor::getGlobalExecutor();
      BackendConfig config;
      // Use CPU, 4 threads.
      exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
      // SGD optimizer.
      std::shared_ptr<SGD> sgd(new SGD);
      sgd->append(model->parameters());
      // SGD params.
      sgd->setMomentum(0.9f);
      sgd->setWeightDecay(0.0005f);
      // Create the dataset and DataLoader.
      auto dataset = MnistDataset::create(root, MnistDataset::Mode::TRAIN);
      // The stack transform stacks [1, 28, 28] samples into [n, 1, 28, 28] batches.
      const size_t batchSize  = 64;
      const size_t numWorkers = 0;
      bool shuffle            = true;
      auto dataLoader = std::shared_ptr<DataLoader>(dataset.createLoader(batchSize, true, shuffle, numWorkers));
      size_t iterations = dataLoader->iterNumber();
      auto testDataset = MnistDataset::create(root, MnistDataset::Mode::TEST);
      const size_t testBatchSize  = 20;
      const size_t testNumWorkers = 0;
      shuffle                     = false;
      auto testDataLoader = std::shared_ptr<DataLoader>(testDataset.createLoader(testBatchSize, true, shuffle, testNumWorkers));
      size_t testIterations = testDataLoader->iterNumber();
      // Begin training.
      for (int epoch = 0; epoch < 50; ++epoch) {
          model->clearCache();
          exe->gc(Executor::FULL);
          exe->resetProfile();
          {
              AUTOTIME;
              dataLoader->reset();
              // Set the isTraining flag to true during the training phase.
              model->setIsTraining(true);
              Timer _100Time;
              int lastIndex     = 0;
              int moveBatchSize = 0;
              for (int i = 0; i < iterations; i++) {
                  // AUTOTIME;
                  // Obtain the training data and labels for one batch.
                  auto trainData   = dataLoader->next();
                  auto example     = trainData[0];
                  auto cast        = _Cast<float>(example.first[0]);
                  example.first[0] = cast * _Const(1.0f / 255.0f);
                  moveBatchSize += example.first[0]->getInfo()->dim[0];
                  // Compute the one-hot targets.
                  auto newTarget = _OneHot(_Cast<int32_t>(example.second[0]), _Scalar<int>(10), _Scalar<float>(1.0f),
                                           _Scalar<float>(0.0f));
                  // Forward pass.
                  auto predict = model->forward(example.first[0]);
                  // Calculate the loss.
                  auto loss = _CrossEntropy(predict, newTarget);
                  // Adjust the learning rate.
                  float rate = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75);
                  sgd->setLearningRate(rate);
                  if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) {
                      std::cout << "epoch: " << (epoch);
                      std::cout << " " << moveBatchSize << " / " << dataLoader->size();
                      std::cout << " loss: " << loss->readMap<float>()[0];
                      std::cout << " lr: " << rate;
                      std::cout << " time: " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) << " iter" << std::endl;
                      std::cout.flush();
                      _100Time.reset();
                      lastIndex = i;
                  }
                  // Backward pass and parameter update.
                  sgd->step(loss);
              }
          }
          // Save the model params for later reloads.
          Variable::save(model->parameters(), "mnist.snapshot.mnn");
          {
              model->setIsTraining(false);
              auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
              forwardInput->setName("data");
              auto predict = model->forward(forwardInput);
              predict->setName("prob");
              // Optimize the network structure (optional).
              Transformer::turnModelToInfer()->onExecute({predict});
              // Save the model and its structure, which can then be used without the Module definitions.
              Variable::save({predict}, "temp.mnist.mnn");
          }
          // Model test.
          int correct = 0;
          testDataLoader->reset();
          // Set isTraining to false during the model test.
          model->setIsTraining(false);
          int moveBatchSize = 0;
          for (int i = 0; i < testIterations; i++) {
              auto data    = testDataLoader->next();
              auto example = data[0];
              moveBatchSize += example.first[0]->getInfo()->dim[0];
              if ((i + 1) % 100 == 0) {
                  std::cout << "test: " << moveBatchSize << " / " << testDataLoader->size() << std::endl;
              }
              auto cast        = _Cast<float>(example.first[0]);
              example.first[0] = cast * _Const(1.0f / 255.0f);
              auto predict     = model->forward(example.first[0]);
              predict          = _ArgMax(predict, 1);
              auto accu        = _Cast<int32_t>(_Equal(predict, _Cast<int32_t>(example.second[0]))).sum({});
              correct += accu->readMap<int32_t>()[0];
          }
          // Calculate the accuracy.
          auto accu = (float)correct / (float)testDataLoader->size();
          std::cout << "epoch: " << epoch << " accuracy: " << accu << std::endl;
          exe->dumpProfile();
      }
  }

4 Save and restore the model

There are two ways to save the model:

First, save only the model parameters, not the model structure. These parameters must later be loaded back into the corresponding model structure.
Save:

  Variable::save(model->parameters(), "mnist.snapshot.mnn");

Restore:

  // Load the saved model params into the existing model structure
  auto para = Variable::load("mnist.snapshot.mnn");
  model->loadParameters(para);

Second, save the model structure and parameters at the same time for easy inference.
Save:

  model->setIsTraining(false);
  auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
  forwardInput->setName("data");
  auto predict = model->forward(forwardInput);
  predict->setName("prob");
  // Save the output node; this saves the structure + params.
  Variable::save({predict}, "temp.mnist.mnn");

Restore (Inference):

  auto varMap = Variable::loadMap("temp.mnist.mnn");
  // The input node name is the same as the one defined during save, i.e. 'data'.
  // The dimensions are also the same as those defined during save, i.e. [1, 1, 28, 28].
  float* inputPtr = varMap["data"]->writeMap<float>();
  // Fill inputPtr with input data.
  // The output node name is the same as the one defined during save, i.e. 'prob'.
  float* outputPtr = varMap["prob"]->readMap<float>();
  // Use the data in outputPtr.