¶ TVM PackedFunc 实现
为了便于 Python 和 C++ 混合编程,TVM 使用了统一的 PackedFunc 机制。PackedFunc 可以将 C++ 中的各类函数打包成统一的函数接口,并自动导出到 Python 模块中进行调用,并且也支持从 Python 中注册一个函数,并伪装成 PackedFunc 在 C++ 和 Python 中调用。
¶ 预备知识
¶ Python ctypes 混合编程
ctypes 是 Python 自带的跨语言函数调用库,ctypes 提供了简单的 C 数据类型,可以将 C/C++ 动态库中的函数包装成 Python 函数进行调用。
- 导出 C++ 函数
首先在 C++ 中定义一个全局函数,并编译生成 C++ 动态库。// test.h
extern "C" {
int add(int a, int b);
}
// test.cc
#include "test.h"
int add(int a, int b) {
return a + b;
}
用 ctypes 模块在 Python 中加载生成的动态库(test.so),并调用 C++ 中的函数。
import ctypes
# Load shared library
_LIB = ctypes.CDLL("./test.so", ctypes.RTLD_GLOBAL)
a = ctypes.c_int(1)
b = ctypes.c_int(2)
# Call C func in Python
print(_LIB.add(a, b))
# Or
print(_LIB.add(1, 2))
- 传递 Python 函数到 C++
ctypes 也支持将 Python 函数转换成 C 类型的函数,并在 C/C++ 中进行调用。def add(a, b):
return a + b
Python add 有两个参数 a 和 b,返回值类型与 a 和 b 的类型一致。在 C++ 中可以为 Python add 定义一个函数原型 int(int, int)。
extern "C" {
typedef int (*PyCFunc)(int, int);
int call_py_func(PyCFunc f, int a, int b);
}
#include "test.h"
int call_py_func(PyCFunc f, int a, int b) {
return f(a, b);
}
使用 ctypes 将 Python 函数转换成 C function,传入 C++ 中进行调用。
import ctypes
cfunc = ctypes.CFUNCTYPE(
ctypes.c_int, # return type
ctypes.c_int, # arg0 type
ctypes.c_int # arg1 type
)
f = cfunc(add)
# CFUNCTYPE is callable in Python
print(f(5, 1))
# Call Python func in C
print(_LIB.call_py_func(f, 5, 1))
¶ PackedFunc 实现
¶ PackedFunc 定义
ctypes 可以很方便的将 C/C++ 中的函数导出到 Python,调用时直接传入对应的参数即可,但如果需要将 Python 函数导入到 C/C++,则需要在 C/C++ 中提前定义好对应的函数原型(比如上面的 PyCFunc),并提供对应函数的调用入口(call_py_func)。为了支持更加灵活的函数定义,TVM 将不同类型的函数包装成统一的函数原型。
void(TVMArgs args, TVMRetValue *rv);
统一的函数原型被封装成 PackedFunc 对象,提供通用的调用接口,直接与调用者进行交互。
class PackedFunc {
public:
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const;
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
private:
/*! \brief internal container of packed function */
FType body_;
};
当获得一个 PackedFunc 对象时,我们就可以像调用普通函数一样调用 PackedFunc 打包的函数。比如:
PackedFunc f;
// f(1, 2)首先会自动将参数1,2打包成TVMArgs,接着调用CallPacked,CallPacked最终的执行体是body_
TVMRetValue ret = f(1, 2);
¶ 函数打包
TVM 支持对各类函数进行打包,包括一般的函数、类的成员函数以及 lamda 表达式。
- 函数原型萃取
萃取函数原型是为了得到函数的参数和返回值类型。TVM 中使用 decltype 和模版结构体 function_signature 来实现。
比如定义一个简单的 C 函数,int add(int a, int b) {
return a + b;
}
接下来就可以使用如下的代码来萃取 add 的函数原型,
template <typename R, typename ...Args>
struct function_signature<R(Args...)> {
using FType = R(Args...);
};
// 萃取add的函数原型
using FType = function_signature<decltype(add)>::FType;
此外只需要特化 function_signature 就可以支持函数指针和 lambda 表达式。注意:TVM function_signature 不支持普通成员函数的类型萃取,因此 TVM 需要借助一个辅助 function_signature_helper 来对 lambda 表达式类型进行萃取,而我们这里的 function_signature 支持普通成员函数,因此 lambda 表达式类型萃取可以通过递归的 function_signature 来实现。
// 普通函数指针
template <typename R, typename ...Args>
struct function_signature<R(*)(Args...)> {
using FType = R(Args...);
};
// 非const类的成员函数指针
template <typename T, typename R, typename ...Args>
struct function_signature<R(T::*)(Args...)> {
using FType = R(Args...);
};
// const类的成员函数指针
template <typename T, typename R, typename ...Args>
struct function_signature<R(T::*)(Args...) const> {
using FType = R(Args...);
};
// lambda表达式
template<typename T>
struct function_signature {
using FType = typename function_signature<decltype(&T::operator())>::FType;
};
- 函数打包
一旦萃取到了函数原型,TVM 就利用 TypedPackedFunc 对普通函数或 lambda 表达式进行打包。TypedPackedFunc 只支持对 R(Args…) 类型的函数打包,所以如果被打包的函数是一个函数指针,则需要创建一个 lambda 表达式,转换成 R(Args…) 类型之后再用 TypedPackedFunc 对创建的 lambda 表达式进行打包。template<typename R, typename ...Args>
class TypedPackedFunc<R(Args...)> {
public:
using TSelf = TypedPackedFunc<R(Args...)>;
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
}
...
private:
...
PackedFunc packed_;
};
当被打包的函数用来实例化 TypedPackedFunc 对象时,会立刻调用 AssignTypedLambda 将被打包的函数打包成 PackedFunc。
template<typename R, typename ...Args>
template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
}
AssignTypedLambda 实际上是将被打包的函数先封装成了一个函数原型为 void(const TVMArgs &args, TVMRetValue *rv) 的 lambda 表达式,然后将这个 lambda 表达式作为 PackedFunc 对象的一个成员,通过设置合适的接口(重载 operator ()),使得 PackedFunc 与被打包的源函数表现的完全一样了。
¶ 自动导出函数
TVM 将需要从 C++ 自动导出的函数打包成 PackedFunc,然后通过宏 TVM_REGISTER_GLOBAL 注册到全局的一个 map 中。比如:
TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) {
return VarNode::make(t, s);
});
当 Python 加载编译好的动态库时,会自动查询 map 中静态注册的函数,每个函数都包装成 Python 中的 Function 对象,最终添加到 Python 模块中。Function 重定义了函数调用接口,自动完成参数打包过程。
如果是在 Python 中动态注册的函数,则需要在 Python 中通过函数名和来查询 PackedFunc,返回一个 PackedFunc 的 handle(函数指针),并封装成 Function。
def get_global_func(name, allow_missing=False):
handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return Function(handle, False)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
注:TVMFuncGetGlobal 是通过 ctypes 导出的 C++ 接口,FunctionHandle 是 ctypes 中表示 void 指针类型(c_void_p)。
¶ 从 Python 注册函数
由于 TVM 中 PackedFunc 的精心设计,我们只需要将 Python 中的函数转换成统一的函数原型 void(const TVMArgs, TVMRetValue),然后将函数转换成 PackedFunc 并动态地注册到全局的 map 中。
先将 Python 函数用 ctypes 转成 int(TVMValue , int , int, void , void ) 的 C 函数。
TVMPackedCFunc = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p)
然后通过 TVMFuncCreateFromCFunc 将上面的 C 函数转换成统一的 PackedFunc 函数。
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) {
throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
}
});
} else {
...
}
API_END();
}
最后通过接口 TVMFuncRegisterGlobal 注册到全局的 map 中。下面是从 Python 中注册一个函数,并在 Python 中调用的例子。
targs = (10, 10.0, "hello")
@tvm.register_func
def my_packed_func(*args):
assert(tuple(args) == targs)
return 10
# Get it out from global function table
f = tvm.get_global_func("my_packed_func")
assert isinstance(f, tvm.nd.Function)
y = f(*targs)
assert y == 10
https://hjchen2.github.io/2020/01/10/TVM-PackedFunc%E5%AE%9E%E7%8E%B0%E6%9C%BA%E5%88%B6/