1 获取可训练模型

首先我们需要获得一个可训练的模型结构,而在MNN中可以用两种方式达到这一目的。

一、将其他框架,如TensorFlow,Pytorch训练得到的模型转成MNN可训练模型,这一过程可使用 MNNConverter 工具实现。典型的应用场景为 1. 使用MNN Finetune,2. 使用MNN进行训练量化。在模型转换过程中建议使用 MNNConverter 的 —forTraining 选项,保留BatchNorm,Dropout等训练过程中会用到的算子。

二、使用MNN从零开始搭建一个模型,并使用MNN进行训练,这可以省去模型转换的步骤,并且也可以十分容易地转换为训练量化模型。在 MNN_ROOT/tools/train/source/models/ 目录中我们提供了Lenet,MobilenetV1,MobilenetV2等使用MNN训练框架搭建的模型。

1.1 将其他框架模型转换为MNN可训练模型

以MobilenetV2的训练量化为例。首先我们需要到下载TensorFlow官方提供的MobilenetV2模型,然后编译 MNNConverter,并执行以下命令进行转换:

  1. ./MNNConvert --modelFile mobilenet_v2_1.0_224_frozen.pb --MNNModel mobilenet_v2_tfpb_train.mnn --framework TF --bizCode AliNNTest --forTraining

注意,上述命令中使用到的 mobilenet_v2_1.0_224_frozen.pb 模型中含有 BatchNorm 算子,没有进行融合,通过在转换时使用 —forTraining 选项,我们保留了BatchNorm算子到转换出来的 mobilenet_v2_tfpb_train.mnn 模型之中。

如果你的模型中没有BN,Dropout等在转MNN推理模型时会被融合掉的算子,那么直接使用MNN推理模型也可以进行训练,不必重新进行转换。

接下来我们仿照 MNN_ROOT/tools/train/source/demo/mobilenetV2Train.cpp 中的示例,读取转换得到的模型,将其转换为MNN可训练模型。关键代码示例如下

  1. // mobilenetV2Train.cpp
  2. // 读取转换得到的MNN模型
  3. auto varMap = Variable::loadMap(argv[1]);
  4. // 指定量化bit数
  5. int bits = 8;
  6. // 获取模型的输入和输出
  7. auto inputOutputs = Variable::getInputAndOutput(varMap);
  8. auto inputs = Variable::mapToSequence(inputOutputs.first);
  9. auto outputs = Variable::mapToSequence(inputOutputs.second);
  10. // 将转换得到的模型转换为可训练模型(将推理模型中的卷积,BatchNorm,Dropout抽取出来,转换成可训练模块)
  11. std::shared_ptr<Module> model(PipelineModule::extract(inputs, outputs, true));
  12. // 将可训练模型转换为训练量化模型,如果不需要进行训练量化,则可不做这一步
  13. ((PipelineModule*)model.get())->toTrainQuant(bits);
  14. // 进入训练环节
  15. MobilenetV2Utils::train(model, 1001, 1, trainImagesFolder, trainImagesTxt, testImagesFolder, testImagesTxt);

1.2 使用MNN从零开始搭建模型

以Lenet为例,我们来看一下,如何使用MNN从零搭建一个模型。MNN提供了丰富的算子可供使用,下面的例子就不详细展开。值得注意的是Pooling输出为NC4HW4格式,需要转换到 NCHW 格式才能进入全连接层进行计算。

  1. class MNN_PUBLIC Lenet : public Module {
  2. public:
  3. Lenet();
  4. virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override;
  5. std::shared_ptr<Module> conv1;
  6. std::shared_ptr<Module> conv2;
  7. std::shared_ptr<Module> ip1;
  8. std::shared_ptr<Module> ip2;
  9. std::shared_ptr<Module> dropout;
  10. };
  11. // 初始化
  12. Lenet::Lenet() {
  13. NN::ConvOption convOption;
  14. convOption.kernelSize = {5, 5};
  15. convOption.channel = {1, 20};
  16. conv1.reset(NN::Conv(convOption));
  17. convOption.reset();
  18. convOption.kernelSize = {5, 5};
  19. convOption.channel = {20, 50};
  20. conv2.reset(NN::Conv(convOption));
  21. ip1.reset(NN::Linear(800, 500));
  22. ip2.reset(NN::Linear(500, 10));
  23. dropout.reset(NN::Dropout(0.5));
  24. // 必须要进行register的参数才会进行更新
  25. registerModel({conv1, conv2, ip1, ip2, dropout});
  26. }
  27. // 前向计算
  28. std::vector<Express::VARP> Lenet::onForward(const std::vector<Express::VARP>& inputs) {
  29. using namespace Express;
  30. VARP x = inputs[0];
  31. x = conv1->forward(x);
  32. x = _MaxPool(x, {2, 2}, {2, 2});
  33. x = conv2->forward(x);
  34. x = _MaxPool(x, {2, 2}, {2, 2});
  35. // Pooling输出为NC4HW4格式,需要转换到NCHW才能进入全连接层进行计算
  36. x = _Convert(x, NCHW);
  37. x = _Reshape(x, {0, -1});
  38. x = ip1->forward(x);
  39. x = _Relu(x);
  40. x = dropout->forward(x);
  41. x = ip2->forward(x);
  42. x = _Softmax(x, 1);
  43. return {x};
  44. }

2 实现数据集接口

这部分在MNN文档中 加载训练数据 部分有详细描述。

3 训练并保存模型

以MNIST模型训练为例,代码在 MNN_ROOT/tools/train/source/demo/MnistUtils.cpp

  1. // MnistUtils.cpp
  2. ......
  3. void MnistUtils::train(std::shared_ptr<Module> model, std::string root) {
  4. {
  5. // Load snapshot
  6. // 模型结构 + 模型参数
  7. auto para = Variable::load("mnist.snapshot.mnn");
  8. model->loadParameters(para);
  9. }
  10. // 配置训练框架参数
  11. auto exe = Executor::getGlobalExecutor();
  12. BackendConfig config;
  13. // 使用CPU,4线程
  14. exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
  15. // SGD求解器
  16. std::shared_ptr<SGD> sgd(new SGD(model));
  17. // SGD求解器参数设置
  18. sgd->setMomentum(0.9f);
  19. sgd->setWeightDecay(0.0005f);
  20. // 创建数据集和DataLoader
  21. auto dataset = MnistDataset::create(root, MnistDataset::Mode::TRAIN);
  22. // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
  23. const size_t batchSize = 64;
  24. const size_t numWorkers = 0;
  25. bool shuffle = true;
  26. auto dataLoader = std::shared_ptr<DataLoader>(dataset.createLoader(batchSize, true, shuffle, numWorkers));
  27. size_t iterations = dataLoader->iterNumber();
  28. auto testDataset = MnistDataset::create(root, MnistDataset::Mode::TEST);
  29. const size_t testBatchSize = 20;
  30. const size_t testNumWorkers = 0;
  31. shuffle = false;
  32. auto testDataLoader = std::shared_ptr<DataLoader>(testDataset.createLoader(testBatchSize, true, shuffle, testNumWorkers));
  33. size_t testIterations = testDataLoader->iterNumber();
  34. // 开始训练
  35. for (int epoch = 0; epoch < 50; ++epoch) {
  36. model->clearCache();
  37. exe->gc(Executor::FULL);
  38. exe->resetProfile();
  39. {
  40. AUTOTIME;
  41. dataLoader->reset();
  42. // 训练阶段需设置isTraining Flag为true
  43. model->setIsTraining(true);
  44. Timer _100Time;
  45. int lastIndex = 0;
  46. int moveBatchSize = 0;
  47. for (int i = 0; i < iterations; i++) {
  48. // AUTOTIME;
  49. // 获得一个batch的数据,包括数据及其label
  50. auto trainData = dataLoader->next();
  51. auto example = trainData[0];
  52. auto cast = _Cast<float>(example.first[0]);
  53. example.first[0] = cast * _Const(1.0f / 255.0f);
  54. moveBatchSize += example.first[0]->getInfo()->dim[0];
  55. // Compute One-Hot
  56. auto newTarget = _OneHot(_Cast<int32_t>(example.second[0]), _Scalar<int>(10), _Scalar<float>(1.0f),
  57. _Scalar<float>(0.0f));
  58. // 前向计算
  59. auto predict = model->forward(example.first[0]);
  60. // 计算loss
  61. auto loss = _CrossEntropy(predict, newTarget);
  62. // 调整学习率
  63. float rate = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75);
  64. sgd->setLearningRate(rate);
  65. if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) {
  66. std::cout << "epoch: " << (epoch);
  67. std::cout << " " << moveBatchSize << " / " << dataLoader->size();
  68. std::cout << " loss: " << loss->readMap<float>()[0];
  69. std::cout << " lr: " << rate;
  70. std::cout << " time: " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) << " iter" << std::endl;
  71. std::cout.flush();
  72. _100Time.reset();
  73. lastIndex = i;
  74. }
  75. // 根据loss反向计算,并更新网络参数
  76. sgd->step(loss);
  77. }
  78. }
  79. // 保存模型参数,便于重新载入训练
  80. Variable::save(model->parameters(), "mnist.snapshot.mnn");
  81. {
  82. model->setIsTraining(false);
  83. auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
  84. forwardInput->setName("data");
  85. auto predict = model->forward(forwardInput);
  86. predict->setName("prob");
  87. // 优化网络结构【可选】
  88. Transformer::turnModelToInfer()->onExecute({predict});
  89. // 保存模型和结构,可脱离Module定义使用
  90. Variable::save({predict}, "temp.mnist.mnn");
  91. }
  92. // 测试模型
  93. int correct = 0;
  94. testDataLoader->reset();
  95. // 测试时,需设置标志位
  96. model->setIsTraining(false);
  97. int moveBatchSize = 0;
  98. for (int i = 0; i < testIterations; i++) {
  99. auto data = testDataLoader->next();
  100. auto example = data[0];
  101. moveBatchSize += example.first[0]->getInfo()->dim[0];
  102. if ((i + 1) % 100 == 0) {
  103. std::cout << "test: " << moveBatchSize << " / " << testDataLoader->size() << std::endl;
  104. }
  105. auto cast = _Cast<float>(example.first[0]);
  106. example.first[0] = cast * _Const(1.0f / 255.0f);
  107. auto predict = model->forward(example.first[0]);
  108. predict = _ArgMax(predict, 1);
  109. auto accu = _Cast<int32_t>(_Equal(predict, _Cast<int32_t>(example.second[0]))).sum({});
  110. correct += accu->readMap<int32_t>()[0];
  111. }
  112. // 计算准确率
  113. auto accu = (float)correct / (float)testDataLoader->size();
  114. std::cout << "epoch: " << epoch << " accuracy: " << accu << std::endl;
  115. exe->dumpProfile();
  116. }
  117. }

4 保存和恢复模型

一、只保存模型参数,不保存模型结构,需要对应的模型结构去加载这些参数
保存:

  1. Variable::save(model->parameters(), "mnist.snapshot.mnn");

恢复:

  1. // 模型结构 + 模型参数
  2. auto para = Variable::load("mnist.snapshot.mnn");
  3. model->loadParameters(para);

二、同时保存模型结构和参数,便于推理
保存:

  1. model->setIsTraining(false);
  2. auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
  3. forwardInput->setName("data");
  4. auto predict = model->forward(forwardInput);
  5. predict->setName("prob");
  6. // 保存输出节点,会连同结构参数一并存储下来
  7. Variable::save({predict}, "temp.mnist.mnn");

恢复(进行推理):

  1. auto varMap = Variable::loadMap("temp.mnist.mnn");
  2. //输入节点名与保存时设定的名字一致,为 data,维度大小与保存时设定的大小一致,为 [1, 1, 28, 28]
  3. float* inputPtr = varMap["data"]->writeMap<float>();
  4. //填充 inputPtr
  5. //输出节点名与保存时设定的名字一致,为 prob
  6. float* outputPtr = varMap["prob"]->readMap<float>();
  7. // 使用 outputPtr 的数据