简介

TVM主要由c++和python实现,也提供了很多其他语言接口。为了实现多语言混合调用,TVM设计了一套机制。

PackedFunc封装函数

参见:

include/tvm/runtime/packed_func.h

TVM中python与底层c++交互接口使用了python自带的ctypes。
为了支持更加灵活的函数定义,TVM将不同类型的函数包装成统一的函数原型。它对参数类型做了限制,但对于深度学习大多数场景都足够,它通过c++模板来消除类型限制。

void(TVMArgs args, TVMRetValue *rv);

  1. class PackedFunc {
  2. public:
  3. using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
  4. template<typename... Args>
  5. inline TVMRetValue operator()(Args&& ...args) const;
  6. inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
  7. private:
  8. /*! \brief internal container of packed function */
  9. FType body_;
  10. };

全局函数注册管理

参见:

include/tvm/runtime/c_runtime_api.h include/tvm/runtime/registry.h python/tvm/_ffi

TVM中使用了一个Registry类来管理全局,通过它可以查看、添加、删除全局函数

  1. /*! \brief Registry for global function */
  2. class Registry {
  3. public:
  4. //设置函数体
  5. TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
  6. Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
  7. return set_body(PackedFunc(f));
  8. }
  9. //给一个任意函数,萃取函数签名
  10. template <typename FLambda>
  11. Registry& set_body_typed(FLambda f) {
  12. using FType = typename detail::function_signature<FLambda>::FType;
  13. return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed());
  14. }
  15. //给一个类成员函数、返回值、参数,使用lambda包装
  16. template <typename T, typename R, typename... Args>
  17. Registry& set_body_method(R (T::*f)(Args...)) {
  18. auto fwrap = [f](T target, Args... params) -> R {
  19. // call method pointer
  20. return (target.*f)(params...);
  21. };
  22. return set_body(TypedPackedFunc<R(T, Args...)>(fwrap, name_));
  23. }
  24. template <typename T, typename R, typename... Args>
  25. Registry& set_body_method(R (T::*f)(Args...) const) {
  26. auto fwrap = [f](const T target, Args... params) -> R {
  27. // call method pointer
  28. return (target.*f)(params...);
  29. };
  30. return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_));
  31. }
  32. //
  33. template <typename TObjectRef, typename TNode, typename R, typename... Args,
  34. typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
  35. Registry& set_body_method(R (TNode::*f)(Args...)) {
  36. auto fwrap = [f](TObjectRef ref, Args... params) {
  37. TNode* target = ref.operator->();
  38. // call method pointer
  39. return (target->*f)(params...);
  40. };
  41. return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
  42. }
  43. template <typename TObjectRef, typename TNode, typename R, typename... Args,
  44. typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
  45. Registry& set_body_method(R (TNode::*f)(Args...) const) {
  46. auto fwrap = [f](TObjectRef ref, Args... params) {
  47. const TNode* target = ref.operator->();
  48. // call method pointer
  49. return (target->*f)(params...);
  50. };
  51. return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
  52. }
  53. TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
  54. TVM_DLL static bool Remove(const std::string& name);
  55. TVM_DLL static const PackedFunc* Get(const std::string& name);
  56. TVM_DLL static std::vector<std::string> ListNames();
  57. struct Manager;
  58. protected:
  59. std::string name_;
  60. PackedFunc func_;
  61. friend struct Manager;
  62. };

注册方式

c++ 使用宏 TVM_REGISTER_GLOBAL

  1. TVM_REGISTER_GLOBAL("hello")
  2. .set_body([](TVMArgs args, TVMRetValue* rv) {
  3. PackedFunc f = args[0];
  4. f("hello world");
  5. });

python中使用装饰器 @tvm._ffi.register_func,内部调用了TVMFuncRegisterGlobal

  1. @tvm._ffi.register_func("relay.backend.lower_call")

主要接口

  1. TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override);
  2. TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
  3. TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
  4. TVM_DLL int TVMFuncRemoveGlobal(const char* name);
  1. def register_func(func_name, f=None, override=False):
  2. def get_global_func(name, allow_missing=False):
  3. def list_global_func_names():
  4. def remove_global_func(name)

Python关联全局函数到不同模块

Python中通过 _init_api这个接口,根据设定的前缀来动态设定属性完成绑定,实现如下:

  1. # python/tvm/_ffi/registry.py
  2. def _init_api(prefix, module_name):
  3. target_module = sys.modules[module_name]
  4. for name in list_global_func_names():
  5. if not name.startswith(prefix):
  6. continue
  7. fname = name[len(prefix) + 1 :]
  8. f = get_global_func(name)
  9. ff = _get_api(f)
  10. ff.__name__ = fname
  11. ff.__doc__ = "TVM PackedFunc %s. " % fname
  12. setattr(target_module, ff.__name__, ff)