简介
TVM主要由c++和python实现,也提供了很多其他语言接口。为了实现多语言混合调用,TVM设计了一套机制。
PackedFunc封装函数
参见:
include/tvm/runtime/packed_func.h
TVM中python与底层c++交互接口使用了python自带的ctypes。
为了支持更加灵活的函数定义,TVM将不同类型的函数包装成统一的函数原型。它对参数类型做了限制,但对于深度学习大多数场景都足够,它通过c++模板来消除类型限制。
void(TVMArgs args, TVMRetValue *rv);
class PackedFunc {
public:
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const;
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
private:
/*! \brief internal container of packed function */
FType body_;
};
全局函数注册管理
参见:
include/tvm/runtime/c_runtime_api.h include/tvm/runtime/registry.h python/tvm/_ffi
TVM中使用了一个Registry类来管理全局,通过它可以查看、添加、删除全局函数
/*! \brief Registry for global function */
class Registry {
public:
//设置函数体
TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
//给一个任意函数,萃取函数签名
template <typename FLambda>
Registry& set_body_typed(FLambda f) {
using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed());
}
//给一个类成员函数、返回值、参数,使用lambda包装
template <typename T, typename R, typename... Args>
Registry& set_body_method(R (T::*f)(Args...)) {
auto fwrap = [f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
};
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap, name_));
}
template <typename T, typename R, typename... Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
};
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_));
}
//
template <typename TObjectRef, typename TNode, typename R, typename... Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
auto fwrap = [f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
}
template <typename TObjectRef, typename TNode, typename R, typename... Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
auto fwrap = [f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
}
TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
TVM_DLL static bool Remove(const std::string& name);
TVM_DLL static const PackedFunc* Get(const std::string& name);
TVM_DLL static std::vector<std::string> ListNames();
struct Manager;
protected:
std::string name_;
PackedFunc func_;
friend struct Manager;
};
注册方式
c++ 使用宏 TVM_REGISTER_GLOBAL
TVM_REGISTER_GLOBAL("hello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
f("hello world");
});
python中使用装饰器 @tvm._ffi.register_func,内部调用了TVMFuncRegisterGlobal
@tvm._ffi.register_func("relay.backend.lower_call")
主要接口
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override);
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
TVM_DLL int TVMFuncRemoveGlobal(const char* name);
def register_func(func_name, f=None, override=False):
def get_global_func(name, allow_missing=False):
def list_global_func_names():
def remove_global_func(name)
Python关联全局函数到不同模块
Python中通过 _init_api这个接口,根据设定的前缀来动态设定属性完成绑定,实现如下:
# python/tvm/_ffi/registry.py
def _init_api(prefix, module_name):
target_module = sys.modules[module_name]
for name in list_global_func_names():
if not name.startswith(prefix):
continue
fname = name[len(prefix) + 1 :]
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = "TVM PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)