¶ TVM PackedFunc 实现
为了便于 Python 和 C++ 混合编程,TVM 使用了统一的 PackedFunc 机制。PackedFunc 可以将 C++ 中的各类函数打包成统一的函数接口,并自动导出到 Python 模块中进行调用,并且也支持从 Python 中注册一个函数,并伪装成 PackedFunc 在 C++ 和 Python 中调用。
¶ 预备知识
¶ Python ctypes 混合编程
ctypes 是 Python 自带的跨语言函数调用库,ctypes 提供了简单的 C 数据类型,可以将 C/C++ 动态库中的函数包装成 Python 函数进行调用。
- 导出 C++ 函数
首先在 C++ 中定义一个全局函数,并编译生成 C++ 动态库。// test.hextern "C" {int add(int a, int b);}
// test.cc#include "test.h"int add(int a, int b) {return a + b;}
用 ctypes 模块在 Python 中加载生成的动态库(test.so),并调用 C++ 中的函数。
import ctypes# Load shared library_LIB = ctypes.CDLL("./test.so", ctypes.RTLD_GLOBAL)a = ctypes.c_int(1)b = ctypes.c_int(2)# Call C func in Pythonprint(_LIB.add(a, b))# Orprint(_LIB.add(1, 2))
- 传递 Python 函数到 C++
ctypes 也支持将 Python 函数转换成 C 类型的函数,并在 C/C++ 中进行调用。def add(a, b):return a + b
Python add 有两个参数 a 和 b,返回值类型与 a 和 b 的类型一致。在 C++ 中可以为 Python add 定义一个函数原型 int(int, int)。
extern "C" {typedef int (*PyCFunc)(int, int);int call_py_func(PyCFunc f, int a, int b);}
#include "test.h"int call_py_func(PyCFunc f, int a, int b) {return f(a, b);}
使用 ctypes 将 Python 函数转换成 C function,传入 C++ 中进行调用。
import ctypescfunc = ctypes.CFUNCTYPE(ctypes.c_int, # return typectypes.c_int, # arg0 typectypes.c_int # arg1 type)f = cfunc(add)# CFUNCTYPE is callable in Pythonprint(f(5, 1))# Call Python func in Cprint(_LIB.call_py_func(f, 5, 1))
¶ PackedFunc 实现
¶ PackedFunc 定义
ctypes 可以很方便的将 C/C++ 中的函数导出到 Python,调用时直接传入对应的参数即可,但如果需要将 Python 函数导入到 C/C++,则需要在 C/C++ 中提前定义好对应的函数原型(比如上面的 PyCFunc),并提供对应函数的调用入口(call_py_func)。为了支持更加灵活的函数定义,TVM 将不同类型的函数包装成统一的函数原型。
void(TVMArgs args, TVMRetValue *rv);
统一的函数原型被封装成 PackedFunc 对象,提供通用的调用接口,直接与调用者进行交互。
class PackedFunc {public:using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;template<typename... Args>inline TVMRetValue operator()(Args&& ...args) const;inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;private:/*! \brief internal container of packed function */FType body_;};
当获得一个 PackedFunc 对象时,我们就可以像调用普通函数一样调用 PackedFunc 打包的函数。比如:
PackedFunc f;// f(1, 2)首先会自动将参数1,2打包成TVMArgs,接着调用CallPacked,CallPacked最终的执行体是body_TVMRetValue ret = f(1, 2);
¶ 函数打包
TVM 支持对各类函数进行打包,包括一般的函数、类的成员函数以及 lamda 表达式。
- 函数原型萃取
萃取函数原型是为了得到函数的参数和返回值类型。TVM 中使用 decltype 和模版结构体 function_signature 来实现。
比如定义一个简单的 C 函数,int add(int a, int b) {return a + b;}
接下来就可以使用如下的代码来萃取 add 的函数原型,
template <typename R, typename ...Args>struct function_signature<R(Args...)> {using FType = R(Args...);};// 萃取add的函数原型using FType = function_signature<decltype(add)>::FType;
此外只需要特化 function_signature 就可以支持函数指针和 lambda 表达式。注意:TVM function_signature 不支持普通成员函数的类型萃取,因此 TVM 需要借助一个辅助 function_signature_helper 来对 lambda 表达式类型进行萃取,而我们这里的 function_signature 支持普通成员函数,因此 lambda 表达式类型萃取可以通过递归的 function_signature 来实现。
// 普通函数指针template <typename R, typename ...Args>struct function_signature<R(*)(Args...)> {using FType = R(Args...);};// 非const类的成员函数指针template <typename T, typename R, typename ...Args>struct function_signature<R(T::*)(Args...)> {using FType = R(Args...);};// const类的成员函数指针template <typename T, typename R, typename ...Args>struct function_signature<R(T::*)(Args...) const> {using FType = R(Args...);};// lambda表达式template<typename T>struct function_signature {using FType = typename function_signature<decltype(&T::operator())>::FType;};
- 函数打包
一旦萃取到了函数原型,TVM 就利用 TypedPackedFunc 对普通函数或 lambda 表达式进行打包。TypedPackedFunc 只支持对 R(Args…) 类型的函数打包,所以如果被打包的函数是一个函数指针,则需要创建一个 lambda 表达式,转换成 R(Args…) 类型之后再用 TypedPackedFunc 对创建的 lambda 表达式进行打包。template<typename R, typename ...Args>class TypedPackedFunc<R(Args...)> {public:using TSelf = TypedPackedFunc<R(Args...)>;template<typename FLambda,typename = typename std::enable_if<std::is_convertible<FLambda,std::function<R(Args...)>>::value>::type>TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)this->AssignTypedLambda(typed_lambda);}...private:...PackedFunc packed_;};
当被打包的函数用来实例化 TypedPackedFunc 对象时,会立刻调用 AssignTypedLambda 将被打包的函数打包成 PackedFunc。
template<typename R, typename ...Args>template<typename FType>inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);});}
AssignTypedLambda 实际上是将被打包的函数先封装成了一个函数原型为 void(const TVMArgs &args, TVMRetValue *rv) 的 lambda 表达式,然后将这个 lambda 表达式作为 PackedFunc 对象的一个成员,通过设置合适的接口(重载 operator ()),使得 PackedFunc 与被打包的源函数表现的完全一样了。
¶ 自动导出函数
TVM 将需要从 C++ 自动导出的函数打包成 PackedFunc,然后通过宏 TVM_REGISTER_GLOBAL 注册到全局的一个 map 中。比如:
TVM_REGISTER_GLOBAL("_Var").set_body_typed([](std::string s, DataType t) {return VarNode::make(t, s);});
当 Python 加载编译好的动态库时,会自动查询 map 中静态注册的函数,每个函数都包装成 Python 中的 Function 对象,最终添加到 Python 模块中。Function 重定义了函数调用接口,自动完成参数打包过程。
如果是在 Python 中动态注册的函数,则需要在 Python 中通过函数名和来查询 PackedFunc,返回一个 PackedFunc 的 handle(函数指针),并封装成 Function。
def get_global_func(name, allow_missing=False):handle = FunctionHandle()check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))if handle.value:return Function(handle, False)if allow_missing:return Noneraise ValueError("Cannot find global function %s" % name)
注:TVMFuncGetGlobal 是通过 ctypes 导出的 C++ 接口,FunctionHandle 是 ctypes 中表示 void 指针类型(c_void_p)。
¶ 从 Python 注册函数
由于 TVM 中 PackedFunc 的精心设计,我们只需要将 Python 中的函数转换成统一的函数原型 void(const TVMArgs, TVMRetValue),然后将函数转换成 PackedFunc 并动态地注册到全局的 map 中。
先将 Python 函数用 ctypes 转成 int(TVMValue , int , int, void , void ) 的 C 函数。
TVMPackedCFunc = ctypes.CFUNCTYPE(ctypes.c_int,ctypes.POINTER(TVMValue),ctypes.POINTER(ctypes.c_int),ctypes.c_int,ctypes.c_void_p,ctypes.c_void_p)
然后通过 TVMFuncCreateFromCFunc 将上面的 C 函数转换成统一的 PackedFunc 函数。
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,void* resource_handle,TVMPackedCFuncFinalizer fin,TVMFunctionHandle *out) {API_BEGIN();if (fin == nullptr) {*out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) {int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)args.num_args, rv, resource_handle);if (ret != 0) {throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());}});} else {...}API_END();}
最后通过接口 TVMFuncRegisterGlobal 注册到全局的 map 中。下面是从 Python 中注册一个函数,并在 Python 中调用的例子。
targs = (10, 10.0, "hello")@tvm.register_funcdef my_packed_func(*args):assert(tuple(args) == targs)return 10# Get it out from global function tablef = tvm.get_global_func("my_packed_func")assert isinstance(f, tvm.nd.Function)y = f(*targs)assert y == 10
https://hjchen2.github.io/2020/01/10/TVM-PackedFunc%E5%AE%9E%E7%8E%B0%E6%9C%BA%E5%88%B6/

