struct PipelineInfo {
    /** op */
    const Op* op;
    /** input tensors */
    std::vector<Tensor*> inputs;
    /** output tensors */
    std::vector<Tensor*> outputs;
};
struct ScheduleInfo {
    /** pipelines with backend info; compute ops are grouped into pipelines in execution order */
    std::vector<std::pair<Backend::Info, std::vector<PipelineInfo>>> pipelineInfo;
    /** input tensors map */
    std::map<std::string, Tensor*> inputTensors;
    /** output tensors map */
    std::map<std::string, Tensor*> outputTensor;
    /** all tensors; .first counts how many times each tensor is referenced */
    std::vector<std::pair<int, std::shared_ptr<Tensor>>> allTensors;
    /** whether the input shapes are valid for resize */
    bool validForResize;
};
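// A minimal sketch (not part of the MNN source) of how a caller might walk the
// resulting ScheduleInfo -- every name below comes from the two structs above:
//
//   void dumpSchedule(const Schedule::ScheduleInfo& info) {
//       for (const auto& pipeline : info.pipelineInfo) {
//           // pipeline.first is the Backend::Info, pipeline.second the op list
//           for (const auto& unit : pipeline.second) {
//               MNN_PRINT("op with %d inputs, %d outputs\n",
//                         (int)unit.inputs.size(), (int)unit.outputs.size());
//           }
//       }
//   }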
Schedule::ScheduleInfo Schedule::schedule(const Net *net, const std::vector<ScheduleConfig> &configs)
{
    std::vector<std::shared_ptr<Tensor>> allTensors;
    ScheduleInfo schedule;
    if (nullptr == net->oplists())
    {
        MNN_PRINT("Error net for schedule\n");
        return schedule;
    }
    bool valid = _setUpTensorInfo(allTensors, net); // initialize tensor objects and shapes
    schedule.validForResize = valid;
    // build one pipeline (plus its backend info) per ScheduleConfig
    std::vector<std::pair<Backend::Info, std::vector<PipelineInfo>>> result;
    for (auto &config : configs)
    {
        Backend::Info compute;
        compute.type      = _getApprociateType(config, net, allTensors, valid); // pick a suitable forward type
        compute.numThread = config.numThread;
        compute.user      = config.backendConfig;
        auto oplists = _scheduleUnit(net, config, allTensors); // build the op list; Input ops are stripped here
        result.emplace_back(std::make_pair(compute, std::move(oplists)));
    }
    schedule.pipelineInfo = std::move(result);
    // collect every used op's outputs and drop unused ops; op order is unchanged, Input ops are always inserted
    std::set<const Op *> oplists;
    {
        for (auto &pipeline : schedule.pipelineInfo)
        {
            for (auto &info : pipeline.second)
            {
                oplists.insert(info.op);
            }
        }
    }
    std::set<int> outputIndexes;
    std::set<int> inputIndexes;
    // collect every tensor index produced and consumed by the scheduled ops
    for (auto op : oplists)
    {
        if (nullptr != op->outputIndexes())
        {
            auto data = op->outputIndexes()->data();
            for (int j = 0; j < op->outputIndexes()->size(); ++j)
            {
                outputIndexes.insert(data[j]);
            }
        }
        if (nullptr != op->inputIndexes())
        {
            auto data = op->inputIndexes()->data();
            for (int j = 0; j < op->inputIndexes()->size(); ++j)
            {
                inputIndexes.insert(data[j]);
            }
        }
        MNN_ASSERT(OpType_Input != op->type()); // Input ops were already stripped in _scheduleUnit
    }
    // get the graph-level inputs and outputs via set difference
    std::set<int> inputIndexDiff;
    std::set<int> outputIndexesDiff;
    std::set_difference(outputIndexes.begin(), outputIndexes.end(), inputIndexes.begin(), inputIndexes.end(),
                        std::inserter(outputIndexesDiff, outputIndexesDiff.begin()));
    std::set_difference(inputIndexes.begin(), inputIndexes.end(), outputIndexes.begin(), outputIndexes.end(),
                        std::inserter(inputIndexDiff, inputIndexDiff.begin()));
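    // Worked example (illustrative, not from the source): with two ops
    //   A: inputs {0} -> outputs {1}   and   B: inputs {1} -> outputs {2},
    // inputIndexes = {0, 1} and outputIndexes = {1, 2}. Tensor 1 is produced
    // and consumed internally, so it drops out of both differences:
    //   outputIndexesDiff = outputIndexes \ inputIndexes = {2}  (graph outputs)
    //   inputIndexDiff    = inputIndexes \ outputIndexes = {0}  (graph inputs)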
    std::unordered_map<std::string, int> tensorNameIndexMap;
    for (int i = 0; i < net->tensorName()->size(); ++i)
    {
        tensorNameIndexMap[net->tensorName()->Get(i)->str()] = i;
    }
    for (auto &config : configs)
    {
        for (const auto &name : config.saveTensors)
        {
            // saveTensors is empty by default; callers list intermediate tensors here when
            // they need their results, and those tensors are then also treated as outputs
            if (tensorNameIndexMap.count(name))
            {
                outputIndexesDiff.insert(tensorNameIndexMap[name]);
            }
            else
            {
                MNN_PRINT("Bad outputname: %s\n", name.c_str());
            }
        }
    }
    // also pull out the model's own declared output tensors
    if (net->outputName())
    {
        for (int i = 0; i < net->outputName()->size(); ++i)
        {
            std::string name = net->outputName()->Get(i)->str();
            if (tensorNameIndexMap.count(name))
            {
                outputIndexesDiff.insert(tensorNameIndexMap[name]);
            }
        }
    }
    // fill in the final inputs, outputs, and the full tensor list
    for (auto index : inputIndexDiff)
    {
        schedule.inputTensors.insert(
            std::make_pair(net->tensorName()->GetAsString(index)->c_str(), allTensors[index].get()));
        TensorUtils::getDescribe(allTensors[index].get())->usage = TensorUsage::INPUT;
    }
    for (auto index : outputIndexesDiff)
    {
        schedule.outputTensor.insert(
            std::make_pair(net->tensorName()->GetAsString(index)->c_str(), allTensors[index].get()));
    }
    for (auto &t : allTensors)
    {
        schedule.allTensors.emplace_back(std::make_pair(0, std::move(t))); // reference count starts at 0
    }
    for (int i = 0; i < net->oplists()->size(); ++i)
    {
        auto op = net->oplists()->GetAs<Op>(i);
        if (nullptr != op->inputIndexes())
        {
            auto data = op->inputIndexes()->data();
            for (int j = 0; j < op->inputIndexes()->size(); ++j)
            {
                auto index = data[j];
                schedule.allTensors[index].first += 1; // one reference per op consuming this tensor
            }
        }
    }
    for (auto outputIndex : outputIndexesDiff)
    {
        TensorUtils::getDescribe(schedule.allTensors[outputIndex].second.get())->usage = TensorUsage::OUTPUT;
        schedule.allTensors[outputIndex].first += 1; // one extra reference for being an output tensor
    }
    return schedule;
}
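// A minimal caller sketch (hypothetical, not from the source; it assumes the
// flatbuffer Net was loaded elsewhere, e.g. through the generated GetNet()
// root accessor, and that "conv1_output" is a tensor name in the model):
//
//   const Net* net = GetNet(modelBuffer);
//   ScheduleConfig config;
//   config.type        = MNN_FORWARD_AUTO;    // let _getApprociateType choose
//   config.numThread   = 4;
//   config.saveTensors = {"conv1_output"};    // hypothetical intermediate tensor
//   auto info = Schedule::schedule(net, {config});
//   // info.inputTensors / info.outputTensor are now keyed by tensor name,
//   // and info.validForResize reflects whether all input shapes were known.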
static bool _setUpTensorInfo(std::vector<std::shared_ptr<Tensor>> &allTensors, const Net *net)
{
    bool valid = true;
    auto &tensors = allTensors;
    tensors.resize(net->tensorName()->size());
    for (int i = 0; i < tensors.size(); ++i) // create one tensor object per tensor name
    {
        tensors[i].reset(new Tensor(4)); // NCHW, TODO
        tensors[i]->setType(DataType_DT_FLOAT);
    }
    // Set input tensors; if the input's type is not the same as ExtraTensorDescribe, use the input parameter
    for (int opIndex = 0; opIndex < net->oplists()->size(); ++opIndex) // scan the ops for Input nodes
    {
        auto op = net->oplists()->GetAs<Op>(opIndex);
        if (OpType_Input == op->type())
        {
            MNN_ASSERT(nullptr != op->outputIndexes());
            auto index  = op->outputIndexes()->data()[0];
            auto tensor = tensors[index].get();
            auto &tb    = tensor->buffer();
            auto inputParam = op->main_as_Input();
            if (auto idims = inputParam->dims())
            {
                for (int i = 0; i < idims->size(); ++i)
                {
                    tb.dim[i].min = 0;
                    int extent = idims->data()[i];
                    // dim-0 is batch (when the input batch is -1, set it to 1; other dims are not patched)
                    if (i == 0 && extent == -1)
                    {
                        extent = 1;
                    }
                    if (extent < 0) // a negative extent elsewhere means the shape is unknown; mark invalid
                    {
                        valid = false;
                    }
                    tb.dim[i].extent = extent; // record the extent on the buffer
                }
                tb.dimensions = idims->size();
            }
            else
            {
                tb.dimensions = 0;
            }
            tensor->setType(inputParam->dtype());
            TensorUtils::getDescribe(tensor)->dimensionFormat = inputParam->dformat();
        }
    }
    return valid;
}
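// Shape-handling example (illustrative): an input declared as [-1, 3, 224, 224]
// becomes [1, 3, 224, 224] and stays valid, because only a -1 batch dim is
// patched; an input declared as [1, -1, 224, 224] leaves valid == false, which
// schedule() records in schedule.validForResize.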
static MNNForwardType _getApprociateType(const ScheduleConfig& config, const Net* net, const std::vector<std::shared_ptr<Tensor>>& allTensors, bool inputShapeValid) {
    MNNForwardType type = config.type;
    if (MNN_FORWARD_AUTO == config.type) {
        // Search backends, excluding MNN_FORWARD_CPU
        for (int i = 1; i < MNN_FORWARD_ALL; ++i) { // probe each backend type and take the first one with a registered creator
            if (MNNGetExtraBackendCreator((MNNForwardType)i) != nullptr) {
                type = (MNNForwardType)i;
                break;
            }
        }
    }
    auto creator = MNNGetExtraBackendCreator(type); // look up the backend creator for this type; backend creation itself is analyzed later
    if (nullptr == creator) {
        MNN_PRINT("Can't Find type=%d backend, use %d instead\n", type, config.backupType);
        type = config.backupType;
    }
    return type;
}
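// Fallback example (illustrative; assumes ScheduleConfig's backupType defaults
// to MNN_FORWARD_CPU as in MNN's public header): requesting MNN_FORWARD_OPENCL
// on a build without the OpenCL backend registered finds no creator, logs
// "Can't Find type=..." and silently falls back to config.backupType.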