获取输入tensor

  1. /**
  2. * @brief get input tensor for given name.
  3. * @param session given session.
  4. * @param name given name. if NULL, return first input.
  5. * @return tensor if found, NULL otherwise.
  6. */
  7. Tensor* getSessionInput(const Session* session, const char* name);
  8. /**
  9. * @brief get all input tensors.
  10. * @param session given session.
  11. * @return all output tensors mapped with name.
  12. */
  13. const std::map<std::string, Tensor*>& getSessionInputAll(const Session* session) const;

Interpreter上提供了两个用于获取输入Tensor的方法:getSessionInput用于获取单个输入tensor,
getSessionInputAll用于获取输入tensor映射。

在只有一个输入tensor时,可以在调用getSessionInput时传入NULL以获取tensor。

拷贝数据

NCHW示例,适用 ONNX / Caffe / Torchscripts 转换而来的模型:

  1. auto inputTensor = interpreter->getSessionInput(session, NULL);
  2. auto nchwTensor = new Tensor(inputTensor, Tensor::CAFFE);
  3. // nchwTensor-host<float>()[x] = ...
  4. inputTensor->copyFromHostTensor(nchwTensor);
  5. delete nchwTensor;

NHWC示例,适用于由 Tensorflow / Tflite 转换而来的模型:

  1. auto inputTensor = interpreter->getSessionInput(session, NULL);
  2. auto nhwcTensor = new Tensor(inputTensor, Tensor::TENSORFLOW);
  3. // nhwcTensor-host<float>()[x] = ...
  4. inputTensor->copyFromHostTensor(nhwcTensor);
  5. delete nhwcTensor;

通过这类拷贝数据的方式,用户只需要关注自己创建的tensor的数据布局,copyFromHostTensor会负责处理数据布局上的转换(如需)和后端间的数据拷贝(如需)。

直接填充数据

  1. auto inputTensor = interpreter->getSessionInput(session, NULL);
  2. inputTensor->host<float>()[0] = 1.f;

Tensor上最简洁的输入方式是直接利用host填充数据,但这种使用方式仅限于CPU后端,其他后端需要通过deviceid来输入。另一方面,用户需要自行处理NC4HW4NHWC数据格式上的差异。

对于非CPU后端,或不熟悉数据布局的用户,宜使用拷贝数据接口。

图像处理

MNN中提供了CV模块,可以帮助用户简化图像的处理,还可以免于引入opencv、libyuv等图片处理库。

1、支持目标Tensor为float或 uint8_t 的数据格式 2、支持目标Tensor为NC4HW4或NHWC的维度格式 3、CV模块支持直接输入Device Tensor,也即由Session中获取的Tensor。

图像处理配置

  1. struct Config
  2. {
  3. Filter filterType = NEAREST;
  4. ImageFormat sourceFormat = RGBA;
  5. ImageFormat destFormat = RGBA;
  6. //Only valid if the dest type is float
  7. float mean[4] = {0.0f,0.0f,0.0f, 0.0f};
  8. float normal[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  9. };

CV::ImageProcess::Config

  • 通过sourceFormatdestFormat指定输入和输出的格式,当前支持RGBARGBBGRGRAYBGRAYUV_NV21、YUV_NV12
  • 通过filterType指定插值的类型,当前支持NEARESTBILINEARBICUBIC三种插值方式
  • 通过meannormal指定均值归一化,但数据类型不是浮点类型时,设置会被忽略

图像变换矩阵

CV::Matrix移植自Android 系统使用的Skia引擎,用法可参考Skia的Matrix:https://skia.org/user/api/SkMatrix_Reference

需要注意的是,ImageProcess中设置的Matrix是从目标图像到源图像的变换矩阵。使用时,可以按源图像到目标图像的变换设定,最后取逆。例如:

  1. // 源图像:1280x720
  2. // 目标图像:逆时针旋转90度再缩小到原来的1/10,即变为72x128
  3. Matrix matrix;
  4. // 重设为单位矩阵
  5. matrix.setIdentity();
  6. // 缩小,变换到 [0,1] 区间:
  7. matrix.postScale(1.0f / 1280, 1.0f / 720);
  8. // 以中心点[0.5, 0.5]旋转90度
  9. matrix.postRotate(90, 0.5f, 0.5f);
  10. // 放大回 72x128
  11. matrix.postScale(72.0f, 128.0f);
  12. // 转变为 目标图像 -> 源图的变换矩阵
  13. matrix.invert(&matrix);

图像处理实例

MNN中使用CV::ImageProcess进行图像处理。ImageProcess内部包含一系列缓存,为了避免内存的重复申请释放,建议将其作缓存,仅创建一次。我们使用ImageProcessconvert填充tensor数据。

  1. /*
  2. * source: 源图像地址
  3. * iw: 源图像宽
  4. * ih:源图像高,
  5. * stride:源图像对齐后的一行byte数(若不需要对齐,设成 0(相当于 iw*bpp))
  6. * dest: 目标 tensor,可以为 uint8 或 float 类型
  7. */
  8. ErrorCode convert(const uint8_t* source, int iw, int ih, int stride, Tensor* dest);

完整示例

  1. auto input = net->getSessionInput(session, NULL);
  2. auto output = net->getSessionOutput(session, NULL);
  3. auto dims = input->shape();
  4. int bpp = dims[1];
  5. int size_h = dims[2];
  6. int size_w = dims[3];
  7. auto inputPatch = argv[2];
  8. FREE_IMAGE_FORMAT f = FreeImage_GetFileType(inputPatch);
  9. FIBITMAP* bitmap = FreeImage_Load(f, inputPatch);
  10. auto newBitmap = FreeImage_ConvertTo32Bits(bitmap);
  11. auto width = FreeImage_GetWidth(newBitmap);
  12. auto height = FreeImage_GetHeight(newBitmap);
  13. FreeImage_Unload(bitmap);
  14. Matrix trans;
  15. //Dst -> [0, 1]
  16. trans.postScale(1.0/size_w, 1.0/size_h);
  17. //Flip Y (因为 FreeImage 解出来的图像排列是Y方向相反的)
  18. trans.postScale(1.0,-1.0, 0.0, 0.5);
  19. //[0, 1] -> Src
  20. trans.postScale(width, height);
  21. ImageProcess::Config config;
  22. config.filterType = NEAREST;
  23. float mean[3] = {103.94f, 116.78f, 123.68f};
  24. float normals[3] = {0.017f,0.017f,0.017f};
  25. ::memcpy(config.mean, mean, sizeof(mean));
  26. ::memcpy(config.normal, normals, sizeof(normals));
  27. config.sourceFormat = RGBA;
  28. config.destFormat = BGR;
  29. std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));
  30. pretreat->setMatrix(trans);
  31. pretreat->convert((uint8_t*)FreeImage_GetScanLine(newBitmap, 0), width, height, 0, input);
  32. net->runSession(session);

可变维度

  1. /**
  2. * @brief resize given tensor.
  3. * @param tensor given tensor.
  4. * @param dims new dims. at most 6 dims.
  5. */
  6. void resizeTensor(Tensor* tensor, const std::vector<int>& dims);
  7. /**
  8. * @brief resize given tensor by nchw.
  9. * @param batch / N.
  10. * @param channel / C.
  11. * @param height / H.
  12. * @param width / W
  13. */
  14. void resizeTensor(Tensor* tensor, int batch, int channel, int height, int width);
  15. /**
  16. * @brief call this function to get tensors ready. output tensor buffer (host or deviceId) should be retrieved
  17. * after resize of any input tensor.
  18. * @param session given session.
  19. */
  20. void resizeSession(Session* session);

在输入Tensor维度不确定或需要修改时,需要调用resizeTensor来更新维度信息。这种情况一般发生在未设置输入维度和输入维度信息可变的情况。更新完所有Tensor的维度信息之后,需要再调用resizeSession来进行预推理,进行内存分配及复用。示例如下:

  1. auto inputTensor = interpreter->getSessionInput(session, NULL);
  2. interpreter->resizeTensor(inputTensor, {newBatch, newChannel, newHeight, newWidth});
  3. interpreter->resizeSession(session);
  4. inputTensor->copyFromHostTensor(imageTensor);
  5. interpreter->runSession(session);