获取输出tensor

  1. /**
  2. * @brief get output tensor for given name.
  3. * @param session given session.
  4. * @param name given name. if NULL, return first output.
  5. * @return tensor if found, NULL otherwise.
  6. */
  7. Tensor* getSessionOutput(const Session* session, const char* name);
  8. /**
  9. * @brief get all output tensors.
  10. * @param session given session.
  11. * @return all output tensors mapped with name.
  12. */
  13. const std::map<std::string, Tensor*>& getSessionOutputAll(const Session* session) const;

Interpreter上提供了两个用于获取输出Tensor的方法:getSessionOutput用于获取单个输出tensor,
getSessionOutputAll用于获取输出tensor映射。

在只有一个输出tensor时,可以在调用getSessionOutput时传入NULL以获取tensor。

拷贝数据

不熟悉MNN源码的用户,必须使用这种方式获取输出!!!
NCHW (适用于 Caffe / TorchScript / Onnx 转换而来的模型)示例:

  1. auto outputTensor = interpreter->getSessionOutput(session, NULL);
  2. auto nchwTensor = new Tensor(outputTensor, Tensor::CAFFE);
  3. outputTensor->copyToHostTensor(nchwTensor);
  4. auto score = nchwTensor->host<float>()[0];
  5. auto index = nchwTensor->host<float>()[1];
  6. // ...
  7. delete nchwTensor;

NHWC (适用于 Tensorflow / Tflite 转换而来的模型)示例:

  1. auto outputTensor = interpreter->getSessionOutput(session, NULL);
  2. auto nhwcTensor = new Tensor(outputTensor, Tensor::TENSORFLOW);
  3. outputTensor->copyToHostTensor(nhwcTensor);
  4. auto score = nhwcTensor->host<float>()[0];
  5. auto index = nhwcTensor->host<float>()[1];
  6. // ...
  7. delete nhwcTensor;

通过这类拷贝数据的方式,用户只需要关注自己创建的tensor的数据布局,**copyToHostTensor**会负责处理数据布局上的转换(如需)和后端间的数据拷贝(如需)。

直接读取数据

由于绝大多数用户都不熟悉MNN底层数据布局,所以不要使用这种方式!!!

  1. auto outputTensor = interpreter->getSessionOutput(session, NULL);
  2. auto score = outputTensor->host<float>()[0];
  3. auto index = outputTensor->host<float>()[1];
  4. // ...

Tensor上最简洁的输出方式是直接读取host数据,但这种使用方式仅限于CPU后端,其他后端需要通过deviceid来读取数据。另一方面,用户需要自行处理NC4HW4NHWC数据格式上的差异。

对于非CPU后端,或不熟悉数据布局的用户,宜使用拷贝数据接口。