```cpp
struct PipelineInfo {
    /** op */
    const Op* op;
    /** input tensors */
    std::vector<Tensor*> inputs;
    /** output tensors */
    std::vector<Tensor*> outputs;
};

struct ScheduleInfo {
    /** pipelines with backend info; the compute ops are grouped into pipelines in order */
    std::vector<std::pair<Backend::Info, std::vector<PipelineInfo>>> pipelineInfo;
    /** input tensors map */
    std::map<std::string, Tensor*> inputTensors;
    /** output tensors map */
    std::map<std::string, Tensor*> outputTensor;
    /** all tensors, each paired with its reference count */
    std::vector<std::pair<int, std::shared_ptr<Tensor>>> allTensors;
    /** whether the input shapes are valid for resize */
    bool validForResize;
};
```
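Before diving into how `ScheduleInfo` is built, it helps to see where it comes from in practice. Below is a minimal usage sketch of MNN's public API: `Interpreter::createSession` is what invokes `Schedule::schedule` internally with the supplied `ScheduleConfig`. The model path `model.mnn` and the tensor name `conv1_out` are hypothetical placeholders.

```cpp
#include <memory>
#include <MNN/Interpreter.hpp>

int main() {
    // createFromFile parses the serialized model; "model.mnn" is a hypothetical path
    std::shared_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile("model.mnn"));
    MNN::ScheduleConfig config;
    config.type        = MNN_FORWARD_CPU; // becomes Backend::Info::type via _getApprociateType
    config.numThread   = 4;               // copied into Backend::Info::numThread
    config.saveTensors = {"conv1_out"};   // hypothetical name; kept as an extra output tensor
    // createSession triggers Schedule::schedule with this config under the hood
    auto session = net->createSession(config);
    return nullptr == session ? 1 : 0;
}
```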
```cpp
Schedule::ScheduleInfo Schedule::schedule(const Net* net, const std::vector<ScheduleConfig>& configs) {
    std::vector<std::shared_ptr<Tensor>> allTensors;
    ScheduleInfo schedule;
    if (nullptr == net->oplists()) {
        MNN_PRINT("Error net for schedule\n");
        return schedule;
    }
    bool valid              = _setUpTensorInfo(allTensors, net); // initialize tensor info
    schedule.validForResize = valid;

    // construct the pipelines and their corresponding backends
    std::vector<std::pair<Backend::Info, std::vector<PipelineInfo>>> result;
    for (auto& config : configs) {
        Backend::Info compute;
        compute.type      = _getApprociateType(config, net, allTensors, valid); // pick a suitable forward type
        compute.numThread = config.numThread;
        compute.user      = config.backendConfig;
        auto oplists      = _scheduleUnit(net, config, allTensors); // drops the Input ops
        result.emplace_back(std::make_pair(compute, std::move(oplists)));
    }
    schedule.pipelineInfo = std::move(result);

    // get all used ops' outputs, drop unused ops, keep op order; Input ops are always inserted
    std::set<const Op*> oplists;
    {
        for (std::pair<Backend::Info, std::vector<PipelineInfo>>& pipeline : schedule.pipelineInfo) {
            for (auto& info : pipeline.second) {
                oplists.insert(info.op);
            }
        }
    }
    std::set<int> outputIndexes;
    std::set<int> inputIndexes;
    for (auto op : oplists) {
        if (nullptr != op->outputIndexes()) {
            auto data = op->outputIndexes()->data();
            for (int j = 0; j < op->outputIndexes()->size(); ++j) {
                outputIndexes.insert(data[j]);
            }
        }
        if (nullptr != op->inputIndexes()) {
            auto data = op->inputIndexes()->data();
            for (int j = 0; j < op->inputIndexes()->size(); ++j) {
                inputIndexes.insert(data[j]);
            }
        }
        MNN_ASSERT(OpType_Input != op->type());
    }
    // derive the graph's input and output nodes
    std::set<int> inputIndexDiff;
    std::set<int> outputIndexesDiff;
    std::set_difference(outputIndexes.begin(), outputIndexes.end(), inputIndexes.begin(), inputIndexes.end(),
                        std::inserter(outputIndexesDiff, outputIndexesDiff.begin()));
    std::set_difference(inputIndexes.begin(), inputIndexes.end(), outputIndexes.begin(), outputIndexes.end(),
                        std::inserter(inputIndexDiff, inputIndexDiff.begin()));
    std::unordered_map<std::string, int> tensorNameIndexMap;
    for (int i = 0; i < net->tensorName()->size(); ++i) {
        tensorNameIndexMap[net->tensorName()->Get(i)->str()] = i;
    }
    for (auto& config : configs) {
        // saveTensors is empty by default; callers that need intermediate results pass the
        // tensor names here, and those tensors are then treated as output tensors as well
        for (const auto& name : config.saveTensors) {
            if (tensorNameIndexMap.count(name)) {
                outputIndexesDiff.insert(tensorNameIndexMap[name]);
            } else {
                MNN_PRINT("Bad outputname: %s\n", name.c_str());
            }
        }
    }
    // collect the model's own output tensors
    if (net->outputName()) {
        for (int i = 0; i < net->outputName()->size(); ++i) {
            std::string name = net->outputName()->Get(i)->str();
            if (tensorNameIndexMap.count(name)) {
                outputIndexesDiff.insert(tensorNameIndexMap[name]);
            }
        }
    }
    // fill in the final inputs, outputs, and the full tensor list
    for (auto index : inputIndexDiff) {
        schedule.inputTensors.insert(
            std::make_pair(net->tensorName()->GetAsString(index)->c_str(), allTensors[index].get()));
        TensorUtils::getDescribe(allTensors[index].get())->usage = TensorUsage::INPUT;
    }
    for (auto index : outputIndexesDiff) {
        schedule.outputTensor.insert(
            std::make_pair(net->tensorName()->GetAsString(index)->c_str(), allTensors[index].get()));
    }
    for (auto& t : allTensors) {
        schedule.allTensors.emplace_back(std::make_pair(0, std::move(t)));
    }
    for (int i = 0; i < net->oplists()->size(); ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if (nullptr != op->inputIndexes()) {
            auto data = op->inputIndexes()->data();
            for (int j = 0; j < op->inputIndexes()->size(); ++j) {
                auto index = data[j];
                schedule.allTensors[index].first += 1; // count how many ops consume this tensor
            }
        }
    }
    for (auto outputIndex : outputIndexesDiff) {
        TensorUtils::getDescribe(schedule.allTensors[outputIndex].second.get())->usage = TensorUsage::OUTPUT;
        schedule.allTensors[outputIndex].first += 1; // one extra reference for being an output tensor
    }
    return schedule;
}
```
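The heart of the input/output derivation above is the pair of `std::set_difference` calls: a tensor that some op consumes but no op produces must be a graph input, and a tensor that is produced but never consumed must be a graph output. The following self-contained toy sketch (not MNN code) reproduces exactly that logic on a two-op graph:

```cpp
#include <algorithm>
#include <iterator>
#include <set>
#include <iostream>

int main() {
    // toy graph: op A reads tensor 0 and writes tensor 1; op B reads 1 and writes 2
    std::set<int> inputIndexes  = {0, 1};
    std::set<int> outputIndexes = {1, 2};

    std::set<int> graphInputs, graphOutputs;
    // consumed but never produced -> graph inputs ({0})
    std::set_difference(inputIndexes.begin(), inputIndexes.end(),
                        outputIndexes.begin(), outputIndexes.end(),
                        std::inserter(graphInputs, graphInputs.begin()));
    // produced but never consumed -> graph outputs ({2})
    std::set_difference(outputIndexes.begin(), outputIndexes.end(),
                        inputIndexes.begin(), inputIndexes.end(),
                        std::inserter(graphOutputs, graphOutputs.begin()));

    for (int i : graphInputs)  std::cout << "input tensor "  << i << "\n"; // prints 0
    for (int i : graphOutputs) std::cout << "output tensor " << i << "\n"; // prints 2
}
```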
```cpp
static bool _setUpTensorInfo(std::vector<std::shared_ptr<Tensor>>& allTensors, const Net* net) {
    bool valid    = true;
    auto& tensors = allTensors;
    tensors.resize(net->tensorName()->size());
    for (int i = 0; i < tensors.size(); ++i) { // create the tensor objects one by one
        tensors[i].reset(new Tensor(4)); // NCHW, TODO
        tensors[i]->setType(DataType_DT_FLOAT);
    }
    // Set Input Tensor; if the type of input is not the same as ExtraTensorDescribe, use the input parameter
    for (int opIndex = 0; opIndex < net->oplists()->size(); ++opIndex) { // walk the ops to find Input nodes
        auto op = net->oplists()->GetAs<Op>(opIndex);
        if (OpType_Input == op->type()) {
            MNN_ASSERT(nullptr != op->outputIndexes());
            auto index      = op->outputIndexes()->data()[0];
            auto tensor     = tensors[index].get();
            auto& tb        = tensor->buffer();
            auto inputParam = op->main_as_Input();
            if (auto idims = inputParam->dims()) {
                for (int i = 0; i < idims->size(); ++i) {
                    tb.dim[i].min = 0;
                    int extent = idims->data()[i];
                    // dim 0 is the batch (when the input batch is -1, set it to 1; other dims are left alone)
                    if (i == 0 && extent == -1) {
                        extent = 1;
                    }
                    if (extent < 0) { // any other negative extent means the shape is unknown: mark invalid
                        valid = false;
                    }
                    tb.dim[i].extent = extent; // the extent recorded on the buffer
                }
                tb.dimensions = idims->size();
            } else {
                tb.dimensions = 0;
            }
            tensor->setType(inputParam->dtype());
            TensorUtils::getDescribe(tensor)->dimensionFormat = inputParam->dformat();
        }
    }
    return valid;
}
```
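The dimension handling is the subtle part: a batch extent of -1 is silently coerced to 1, while any other negative extent only clears `valid` (which ends up in `validForResize`), so the caller has to supply real shapes later, e.g. via `Interpreter::resizeTensor`. A standalone sketch of the same rule, with a hypothetical helper `normalizeDims` that is not part of MNN:

```cpp
#include <cstdio>
#include <vector>

// Mirrors the rule above: dim 0 (batch) of -1 becomes 1;
// any other negative extent marks the whole shape invalid.
static bool normalizeDims(std::vector<int>& dims) {
    bool valid = true;
    for (size_t i = 0; i < dims.size(); ++i) {
        if (i == 0 && dims[i] == -1) {
            dims[i] = 1;
        }
        if (dims[i] < 0) {
            valid = false;
        }
    }
    return valid;
}

int main() {
    std::vector<int> a = {-1, 3, 224, 224}; // dynamic batch only
    bool okA = normalizeDims(a);
    printf("valid=%d batch=%d\n", okA, a[0]); // valid=1 batch=1

    std::vector<int> b = {1, 3, -1, -1};     // unknown H/W: a later resize is required
    bool okB = normalizeDims(b);
    printf("valid=%d\n", okB);               // valid=0
}
```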
```cpp
static MNNForwardType _getApprociateType(const ScheduleConfig& config, const Net* net,
                                         const std::vector<std::shared_ptr<Tensor>>& allTensors,
                                         bool inputShapeValid) {
    MNNForwardType type = config.type;
    if (MNN_FORWARD_AUTO == config.type) {
        // search for an available backend, excluding MNN_FORWARD_CPU (index 0)
        for (int i = 1; i < MNN_FORWARD_ALL; ++i) {
            // check whether a creator is registered for this backend type
            if (MNNGetExtraBackendCreator((MNNForwardType)i) != nullptr) {
                type = (MNNForwardType)i;
                break;
            }
        }
    }
    // look up the backend creator for the chosen type; backend creation itself is analyzed later
    auto creator = MNNGetExtraBackendCreator(type);
    if (nullptr == creator) {
        MNN_PRINT("Can't Find type=%d backend, use %d instead\n", type, config.backupType);
        type = config.backupType;
    }
    return type;
}
```
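Note the fallback chain: with `MNN_FORWARD_AUTO`, the first non-CPU backend with a registered creator wins; and whenever the chosen type has no creator, the function warns and falls back to `config.backupType`. A minimal config sketch that exercises this path (the helper `makeConfig` is hypothetical):

```cpp
#include <MNN/Interpreter.hpp>

// prefer OpenCL, but let _getApprociateType fall back to CPU
// if no OpenCL backend creator was compiled in / registered
MNN::ScheduleConfig makeConfig() {
    MNN::ScheduleConfig config;
    config.type       = MNN_FORWARD_OPENCL; // preferred backend
    config.backupType = MNN_FORWARD_CPU;    // chosen when no OpenCL creator exists
    return config;
}
```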