#include <cstring>
#include <fstream>
#include <memory>
#include <sstream>

// Headers from the quantization tool's sources.
#include "MNN_generated.h"
#include "calibration.hpp"
#include "logkit.h"

int main(int argc, const char* argv[]) {
    if (argc < 4) {
        DLOG(INFO) << "Usage: ./quantized.out src.mnn dst.mnn preTreatConfig.json\n";
        return 0;
    }
    const char* modelFile      = argv[1];
    const char* dstFile        = argv[2];
    const char* preTreatConfig = argv[3];
    DLOG(INFO) << ">>> modelFile: " << modelFile;
    DLOG(INFO) << ">>> preTreatConfig: " << preTreatConfig;
    DLOG(INFO) << ">>> dstFile: " << dstFile;

    // Read the MNN model from disk and unpack it into a mutable NetT.
    std::unique_ptr<MNN::NetT> netT;
    {
        std::ifstream input(modelFile, std::ios::binary);
        std::ostringstream outputOs;
        outputOs << input.rdbuf();
        netT = MNN::UnPackNet(outputOs.str().c_str());
    }

    // Temporarily pack the net back into a flat buffer for inference.
    flatbuffers::FlatBufferBuilder builder(1024);
    auto offset = MNN::Net::Pack(builder, netT.get());
    builder.Finish(offset);
    int size      = builder.GetSize();
    auto ocontent = builder.GetBufferPointer();

    // Build two copies of the model buffer: one drives inference to collect
    // activation statistics, the other becomes the quantized network.
    // (Array form unique_ptr<uint8_t[]> so delete[] matches new uint8_t[size].)
    std::unique_ptr<uint8_t[]> modelForInference(new uint8_t[size]);
    memcpy(modelForInference.get(), ocontent, size);
    std::unique_ptr<uint8_t[]> modelOriginal(new uint8_t[size]);
    memcpy(modelOriginal.get(), ocontent, size);

    netT.reset();
    netT = MNN::UnPackNet(modelOriginal.get());

    // Quantize the model's weights: the calibration object runs inference on
    // calibration data and rewrites netT in place.
    DLOG(INFO) << "Calibrate the feature and quantize model...";
    std::shared_ptr<Calibration> calibration(
        new Calibration(netT.get(), modelForInference.get(), size, preTreatConfig));
    calibration->runQuantizeModel();
    DLOG(INFO) << "Quantize model done!";

    // Serialize the quantized net; ForceDefaults writes fields even when they
    // equal their default values.
    flatbuffers::FlatBufferBuilder builderOutput(1024);
    builderOutput.ForceDefaults(true);
    auto len = MNN::Net::Pack(builderOutput, netT.get());
    builderOutput.Finish(len);
    {
        std::ofstream output(dstFile, std::ios::binary);
        output.write((const char*)builderOutput.GetBufferPointer(), builderOutput.GetSize());
    }
    return 0;
}
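
After the tool has written dst.mnn, the result can be loaded back with MNN's regular inference API as a quick sanity check. This is a minimal sketch, not part of quantized.cpp; it assumes a single-input, single-output model, recent MNN header paths, and an illustrative model path:

    #include <MNN/Interpreter.hpp>
    #include <MNN/Tensor.hpp>

    #include <iostream>
    #include <memory>

    int main() {
        // Load the quantized model written by quantized.out ("dst.mnn" is illustrative).
        std::shared_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile("dst.mnn"));
        if (!net) {
            std::cerr << "failed to load dst.mnn" << std::endl;
            return 1;
        }
        MNN::ScheduleConfig config;
        config.type = MNN_FORWARD_CPU;
        auto session = net->createSession(config);

        // With a single input, nullptr selects the default input tensor.
        auto input = net->getSessionInput(session, nullptr);
        std::shared_ptr<MNN::Tensor> inputHost(
            MNN::Tensor::createHostTensorFromDevice(input, false));
        // ... fill inputHost->host<float>() with preprocessed data here ...
        input->copyFromHostTensor(inputHost.get());

        net->runSession(session);

        auto output = net->getSessionOutput(session, nullptr);
        std::shared_ptr<MNN::Tensor> outputHost(
            MNN::Tensor::createHostTensorFromDevice(output, true));
        std::cout << "output elements: " << outputHost->elementSize() << std::endl;
        return 0;
    }

Offline-quantized MNN models keep float inputs and outputs (quantize/dequantize happens inside the network), so the calling code stays the same as for the float model.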

The quantized.cpp file implements the quantization. The flow is as follows:

• Read the model content through flatbuffers and build two model buffers: one drives inference to collect activation statistics, the other is turned into the quantized network (a standalone sketch of this read/pack round-trip follows the list).
• Invoke the Calibration class to perform the quantization and produce the quantized network.
• Save the quantized model.
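
The read/modify/pack round-trip in the first bullet is the generic flatbuffers pattern the whole tool is built on. Below is a minimal sketch of just that pattern, assuming only the generated MNN_generated.h helpers already used above; loadNet and packNet are illustrative names, and std::vector replaces the raw new uint8_t[size] so the buffer carries its own size:

    #include <cstdint>
    #include <fstream>
    #include <memory>
    #include <sstream>
    #include <vector>

    #include "MNN_generated.h"

    // Deserialize a .mnn file into a mutable NetT that can be edited in place.
    std::unique_ptr<MNN::NetT> loadNet(const char* path) {
        std::ifstream input(path, std::ios::binary);
        std::ostringstream os;
        os << input.rdbuf();
        return MNN::UnPackNet(os.str().c_str());
    }

    // Serialize a (possibly modified) NetT back into a flat model buffer.
    std::vector<uint8_t> packNet(const MNN::NetT* net) {
        flatbuffers::FlatBufferBuilder builder(1024);
        builder.Finish(MNN::Net::Pack(builder, net));
        return std::vector<uint8_t>(builder.GetBufferPointer(),
                                    builder.GetBufferPointer() + builder.GetSize());
    }

Any transformation of the network, quantization included, can then be expressed as: loadNet, mutate the NetT, packNet, write the bytes out.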