struct PipelineInfo {
    /** op */
    const Op* op;
    /** input tensors */
    std::vector<Tensor*> inputs;
    /** output tensors */
    std::vector<Tensor*> outputs;
};

struct ScheduleInfo {
    /** pipelines with backend info: compute nodes are grouped into pipelines in order */
    std::vector<std::pair<Backend::Info, std::vector<PipelineInfo>>> pipelineInfo;
    /** input tensors map */
    std::map<std::string, Tensor*> inputTensors;
    /** output tensors map */
    std::map<std::string, Tensor*> outputTensor;
    /** all tensors, each paired with its reference count */
    std::vector<std::pair<int, std::shared_ptr<Tensor>>> allTensors;
    /** whether the input shapes are valid for resize */
    bool validForResize;
};
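To make the relationships concrete, here is a minimal, self-contained sketch of how a consumer might walk a ScheduleInfo. The stand-in types (Op, Tensor, BackendInfo here are stripped-down mocks, not MNN's real definitions) exist only to make the example compile on its own:

#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for MNN's Op, Tensor, and Backend::Info (illustration only).
struct Op     { std::string name; };
struct Tensor { std::string name; };
struct BackendInfo { int type; int numThread; };

struct PipelineInfo {
    const Op* op;
    std::vector<Tensor*> inputs;
    std::vector<Tensor*> outputs;
};

struct ScheduleInfo {
    std::vector<std::pair<BackendInfo, std::vector<PipelineInfo>>> pipelineInfo;
    std::map<std::string, Tensor*> inputTensors;
    std::map<std::string, Tensor*> outputTensor;
    std::vector<std::pair<int, std::shared_ptr<Tensor>>> allTensors; // (refcount, tensor)
    bool validForResize = false;
};

int main() {
    // One backend running one op: input -> conv -> output.
    auto in  = std::make_shared<Tensor>(Tensor{"input"});
    auto out = std::make_shared<Tensor>(Tensor{"output"});
    Op conv{"conv"};

    PipelineInfo unit;
    unit.op      = &conv;
    unit.inputs  = {in.get()};
    unit.outputs = {out.get()};

    ScheduleInfo s;
    s.pipelineInfo.emplace_back(BackendInfo{0, 4}, std::vector<PipelineInfo>{unit});
    s.inputTensors["input"]  = in.get();
    s.outputTensor["output"] = out.get();
    s.allTensors = {{1, in}, {1, out}}; // each tensor referenced once
    s.validForResize = true;

    // Walk every pipeline in execution order, as the runtime would.
    for (auto& p : s.pipelineInfo) {
        for (auto& info : p.second) {
            printf("backend %d runs op %s\n", p.first.type, info.op->name.c_str());
        }
    }
    return 0;
}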
Schedule::ScheduleInfo Schedule::schedule(const Net* net, const std::vector<ScheduleConfig>& configs) {
    std::vector<std::shared_ptr<Tensor>> allTensors;
    ScheduleInfo schedule;
    if (nullptr == net->oplists()) {
        MNN_PRINT("Error net for schedule\n");
        return schedule;
    }
    bool valid              = _setUpTensorInfo(allTensors, net); // initialize tensor info
    schedule.validForResize = valid;

    // Build the pipelines and their corresponding backends
    std::vector<std::pair<Backend::Info, std::vector<PipelineInfo>>> result;
    for (auto& config : configs) {
        Backend::Info compute;
        compute.type      = _getApprociateType(config, net, allTensors, valid); // pick a suitable forward type
        compute.numThread = config.numThread;
        compute.user      = config.backendConfig;
        auto oplists      = _scheduleUnit(net, config, allTensors); // strip the Input ops
        result.emplace_back(std::make_pair(compute, std::move(oplists)));
    }
    schedule.pipelineInfo = std::move(result);

    // Get all used ops' outputs; drop unused ops without changing op order. Always insert all Input ops.
    std::set<const Op*> oplists;
    {
        for (std::pair<Backend::Info, std::vector<PipelineInfo>>& pipeline : schedule.pipelineInfo) {
            for (auto& info : pipeline.second) {
                oplists.insert(info.op);
            }
        }
    }
    std::set<int> outputIndexes;
    std::set<int> inputIndexes;
    for (auto op : oplists) {
        if (nullptr != op->outputIndexes()) {
            auto data = op->outputIndexes()->data();
            for (int j = 0; j < op->outputIndexes()->size(); ++j) {
                outputIndexes.insert(data[j]);
            }
        }
        if (nullptr != op->inputIndexes()) {
            auto data = op->inputIndexes()->data();
            for (int j = 0; j < op->inputIndexes()->size(); ++j) {
                inputIndexes.insert(data[j]);
            }
        }
        MNN_ASSERT(OpType_Input != op->type());
    }
    // Derive the graph-level inputs and outputs: a tensor produced but never consumed
    // is a graph output; a tensor consumed but never produced is a graph input.
    std::set<int> inputIndexDiff;
    std::set<int> outputIndexesDiff;
    std::set_difference(outputIndexes.begin(), outputIndexes.end(), inputIndexes.begin(), inputIndexes.end(),
                        std::inserter(outputIndexesDiff, outputIndexesDiff.begin()));
    std::set_difference(inputIndexes.begin(), inputIndexes.end(), outputIndexes.begin(), outputIndexes.end(),
                        std::inserter(inputIndexDiff, inputIndexDiff.begin()));
    std::unordered_map<std::string, int> tensorNameIndexMap;
    for (int i = 0; i < net->tensorName()->size(); ++i) {
        tensorNameIndexMap[net->tensorName()->Get(i)->str()] = i;
    }
    for (auto& config : configs) {
        for (const auto& name : config.saveTensors) {
            // saveTensors is empty by default; if the caller wants the results of
            // intermediate tensors, the names are passed in here and must also be
            // treated as output tensors
            if (tensorNameIndexMap.count(name)) {
                outputIndexesDiff.insert(tensorNameIndexMap[name]);
            } else {
                MNN_PRINT("Bad outputname: %s\n", name.c_str());
            }
        }
    }
    // Collect the model's own output tensors
    if (net->outputName()) {
        for (int i = 0; i < net->outputName()->size(); ++i) {
            std::string name = net->outputName()->Get(i)->str();
            if (tensorNameIndexMap.count(name)) {
                outputIndexesDiff.insert(tensorNameIndexMap[name]);
            }
        }
    }
    // Fill in the final inputs, outputs, and the full tensor list
    for (auto index : inputIndexDiff) {
        schedule.inputTensors.insert(
            std::make_pair(net->tensorName()->GetAsString(index)->c_str(), allTensors[index].get()));
        TensorUtils::getDescribe(allTensors[index].get())->usage = TensorUsage::INPUT;
    }
    for (auto index : outputIndexesDiff) {
        schedule.outputTensor.insert(
            std::make_pair(net->tensorName()->GetAsString(index)->c_str(), allTensors[index].get()));
    }
    for (auto& t : allTensors) {
        schedule.allTensors.emplace_back(std::make_pair(0, std::move(t)));
    }
    for (int i = 0; i < net->oplists()->size(); ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if (nullptr != op->inputIndexes()) {
            auto data = op->inputIndexes()->data();
            for (int j = 0; j < op->inputIndexes()->size(); ++j) {
                auto index = data[j];
                schedule.allTensors[index].first += 1; // count how many ops reference this tensor
            }
        }
    }
    for (auto outputIndex : outputIndexesDiff) {
        TensorUtils::getDescribe(schedule.allTensors[outputIndex].second.get())->usage = TensorUsage::OUTPUT;
        schedule.allTensors[outputIndex].first += 1; // output tensors hold one extra reference
    }
    return schedule;
}
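The core trick above is how graph-level inputs and outputs fall out of two std::set_difference calls. A minimal, self-contained sketch with plain tensor indices (the tiny three-op graph below is made up for illustration):

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <set>

int main() {
    // Hypothetical graph: op0: {0} -> {1}, op1: {1} -> {2}, op2: {1} -> {3}
    std::set<int> inputIndexes  = {0, 1};    // tensors consumed by some op
    std::set<int> outputIndexes = {1, 2, 3}; // tensors produced by some op

    std::set<int> graphInputs, graphOutputs;
    // consumed but never produced -> graph inputs
    std::set_difference(inputIndexes.begin(), inputIndexes.end(),
                        outputIndexes.begin(), outputIndexes.end(),
                        std::inserter(graphInputs, graphInputs.begin()));
    // produced but never consumed -> graph outputs
    std::set_difference(outputIndexes.begin(), outputIndexes.end(),
                        inputIndexes.begin(), inputIndexes.end(),
                        std::inserter(graphOutputs, graphOutputs.begin()));

    for (int i : graphInputs)  printf("graph input tensor:  %d\n", i); // prints 0
    for (int i : graphOutputs) printf("graph output tensor: %d\n", i); // prints 2 and 3
    return 0;
}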
static bool _setUpTensorInfo(std::vector<std::shared_ptr<Tensor>>& allTensors, const Net* net) {
    bool valid    = true;
    auto& tensors = allTensors;
    tensors.resize(net->tensorName()->size());
    for (int i = 0; i < tensors.size(); ++i) { // create each tensor object in turn
        tensors[i].reset(new Tensor(4)); // NCHW, TODO
        tensors[i]->setType(DataType_DT_FLOAT);
    }
    // Set input tensors; if the type of an input differs from ExtraTensorDescribe, use the input parameter
    for (int opIndex = 0; opIndex < net->oplists()->size(); ++opIndex) { // walk the ops to find Input nodes
        auto op = net->oplists()->GetAs<Op>(opIndex);
        if (OpType_Input == op->type()) {
            MNN_ASSERT(nullptr != op->outputIndexes());
            auto index      = op->outputIndexes()->data()[0];
            auto tensor     = tensors[index].get();
            auto& tb        = tensor->buffer();
            auto inputParam = op->main_as_Input();
            if (auto idims = inputParam->dims()) {
                for (int i = 0; i < idims->size(); ++i) {
                    tb.dim[i].min = 0;
                    int extent = idims->data()[i];
                    // dim-0 is batch (when the input batch is -1, set it to 1; other dims are left as-is)
                    if (i == 0 && extent == -1) {
                        extent = 1;
                    }
                    if (extent < 0) { // a negative extent on any other dim makes the shape invalid
                        valid = false;
                    }
                    tb.dim[i].extent = extent; // the extent recorded on the buffer
                }
                tb.dimensions = idims->size();
            } else {
                tb.dimensions = 0;
            }
            tensor->setType(inputParam->dtype());
            TensorUtils::getDescribe(tensor)->dimensionFormat = inputParam->dformat();
        }
    }
    return valid;
}
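The dimension handling is the subtle part: a batch extent of -1 is silently normalized to 1, while a negative extent anywhere else marks the whole schedule as invalid for resize. A self-contained sketch of just that rule (normalizeDims is a hypothetical helper for illustration, not MNN API):

#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the rule above: rewrites a -1 batch (dim 0)
// to 1, and returns false if any other dim is negative.
static bool normalizeDims(std::vector<int>& dims) {
    bool valid = true;
    for (size_t i = 0; i < dims.size(); ++i) {
        if (i == 0 && dims[i] == -1) {
            dims[i] = 1; // batch placeholder: default to 1
        }
        if (dims[i] < 0) {
            valid = false; // e.g. a height/width of -1 cannot be defaulted
        }
    }
    return valid;
}

int main() {
    std::vector<int> a = {-1, 3, 224, 224}; // dynamic batch: OK, becomes {1, 3, 224, 224}
    std::vector<int> b = {1, 3, -1, -1};    // dynamic spatial dims: invalid for resize
    bool aValid = normalizeDims(a);
    printf("a valid=%d, batch=%d\n", aValid, a[0]);   // a valid=1, batch=1
    printf("b valid=%d\n", (int)normalizeDims(b));    // b valid=0
    return 0;
}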
static MNNForwardType _getApprociateType(const ScheduleConfig& config, const Net* net,
                                         const std::vector<std::shared_ptr<Tensor>>& allTensors, bool inputShapeValid) {
    MNNForwardType type = config.type;
    if (MNN_FORWARD_AUTO == config.type) {
        // Search for the first registered backend, excluding MNN_FORWARD_CPU
        for (int i = 1; i < MNN_FORWARD_ALL; ++i) {
            if (MNNGetExtraBackendCreator((MNNForwardType)i) != nullptr) {
                type = (MNNForwardType)i;
                break;
            }
        }
    }
    auto creator = MNNGetExtraBackendCreator(type); // look up the backend creator for this type; backend creation itself is analyzed later
    if (nullptr == creator) {
        MNN_PRINT("Can't Find type=%d backend, use %d instead\n", type, config.backupType);
        type = config.backupType;
    }
    return type;
}
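The selection logic is a two-stage fallback: with MNN_FORWARD_AUTO, the first registered non-CPU backend wins; then, if the chosen type has no registered creator, backupType is used. A minimal sketch with a mock creator registry (the enum values and the registry map stand in for MNN's real backend registration; they are not MNN API):

#include <cstdio>
#include <map>

enum ForwardType { FORWARD_CPU = 0, FORWARD_METAL, FORWARD_OPENCL, FORWARD_AUTO, FORWARD_ALL };

// Mock registry: which backend types have a creator registered (illustration only).
static std::map<int, bool> gRegistry = {{FORWARD_CPU, true}, {FORWARD_OPENCL, true}};

static bool hasCreator(int type) { return gRegistry.count(type) && gRegistry[type]; }

static int pickType(int requested, int backup) {
    int type = requested;
    if (requested == FORWARD_AUTO) {
        // Scan every type except CPU (index 0) and take the first registered one.
        for (int i = 1; i < FORWARD_ALL; ++i) {
            if (hasCreator(i)) { type = i; break; }
        }
    }
    if (!hasCreator(type)) {
        printf("Can't find type=%d backend, use %d instead\n", type, backup);
        type = backup; // second-stage fallback
    }
    return type;
}

int main() {
    printf("AUTO  -> %d\n", pickType(FORWARD_AUTO, FORWARD_CPU));  // -> FORWARD_OPENCL (2)
    printf("METAL -> %d\n", pickType(FORWARD_METAL, FORWARD_CPU)); // not registered -> CPU (0)
    return 0;
}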