下面三个是定义在 `python/tvm/relay/testing/__init__.py` 中的函数：
relay.testing.run_opt_pass
relay.testing.run_infer_type
relay.testing.rand
# Run a single Relay pass over an expression.
def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to ``expr`` and return the transformed result.

    ``expr`` is wrapped into a fresh module with ``relay.Module.from_expr``,
    the pass is run on that module, and its ``main`` function is read back.

    Returns the whole transformed ``main`` function when ``expr`` is a
    ``relay.Function``, otherwise only the body of ``main``.
    """
    assert isinstance(opt_pass, transform.Pass)
    transformed = opt_pass(relay.Module.from_expr(expr))
    main_func = transformed["main"]
    if isinstance(expr, relay.Function):
        return main_func
    return main_func.body
# Infer types.
def run_infer_type(expr):
    """Run Relay type inference on ``expr`` and return the result.

    Delegates to ``run_opt_pass`` with a ``transform.InferType()`` pass.

    NOTE(review): this relies on the two-argument ``run_opt_pass``; a
    three-argument ``run_opt_pass`` with the same name is also defined in these
    notes, and if both end up in one module the later one is what gets called.
    """
    infer_pass = transform.InferType()
    return run_opt_pass(expr, infer_pass)
# Create random values with a given dtype and shape.
def rand(dtype, *shape):
    """Return a ``tvm.nd.array`` of shape ``shape`` filled with random values.

    Values come from ``np.random.rand`` (uniform on [0, 1)) and are cast to
    ``dtype`` before being wrapped as a TVM NDArray.
    """
    values = np.random.rand(*shape)
    return tvm.nd.array(values.astype(dtype))
下面是一些自定义的函数：
# Bind parameter values into a Relay function.
def bind_params(func, params):
    """Replace parameters of ``func`` with constants taken from ``params``.

    Each key of ``params`` is matched against the ``name_hint`` of the
    function's parameters. Keys that match no parameter are ignored.

    Parameters
    ----------
    func : relay.Function
        Function whose parameters may be bound. Its ``params`` must expose
        ``name_hint``.
    params : dict
        Mapping from parameter name to value. Each matched value is wrapped
        with ``tvm.relay.expr.const``.

    Returns
    -------
    The result of ``tvm.relay.expr.bind(func, ...)`` with the matched
    parameters substituted.

    Raises
    ------
    ValueError
        If a key in ``params`` names more than one parameter of ``func``.
    """
    # Map each name to its unique parameter. ``None`` marks a name shared by
    # several parameters, so binding by that name would be ambiguous.
    by_name = {}
    for var in func.params:
        hint = var.name_hint
        by_name[hint] = None if hint in by_name else var

    binding = {}
    for key, value in params.items():
        if key in by_name:
            target_var = by_name[key]
            if target_var is None:
                raise ValueError("Multiple args in the function have name %s" % key)
            binding[target_var] = tvm.relay.expr.const(value)
    return tvm.relay.expr.bind(func, binding)
# Bind parameters, then run a pass (for testing passes that need to see the
# parameter values as constants).
def run_opt_pass(expr, opt_pass, params=None):
    """Bind ``params`` into ``expr`` as constants, run ``opt_pass``, return the module.

    Parameters
    ----------
    expr : Relay expression
        Wrapped into a module with ``relay.Module.from_expr``. That module's
        ``main`` is the function whose parameters get bound.
    opt_pass : transform.Pass
        The pass to run.
    params : dict, optional
        Name -> value mapping handed to ``bind_params``. ``None`` (the default)
        skips binding and runs the pass on the module built from ``expr``.

    Returns
    -------
    relay.Module
        The transformed module, not its ``main`` function or body.

    NOTE: this shares its name with ``relay.testing.run_opt_pass``, which takes
    two arguments and returns the function or body. If both are defined in
    the same module, this definition replaces the other one. Callers such as
    ``run_infer_type`` then receive a module instead of an expression.
    """
    assert isinstance(opt_pass, transform.Pass)
    mod = relay.Module.from_expr(expr)
    if params is not None:
        # Bind into a copy of ``main`` and rebuild the module around it.
        entry = bind_params(mod["main"], params)
        mod = relay.Module.from_expr(entry)
    return opt_pass(mod)
# Bind parameters, compile, run, and return the run result.
def build_run(expr, params, target="llvm", ctx=None, opt_level=2):
    """Compile ``expr`` with ``params``, run it once, and return output 0.

    Parameters
    ----------
    expr : relay.Module or Relay expression
        Anything that is not a ``relay.Module`` is wrapped with
        ``relay.Module.from_expr``.
    params : dict
        Parameter values passed to ``relay.build``.
    target : str or Target, optional
        Passed to ``tvm.target.create``. Defaults to ``"llvm"``.
    ctx : TVMContext, optional
        Device used by the graph runtime. Defaults to ``tvm.cpu(0)`` and
        must match ``target``.
    opt_level : int, optional
        Optimization level for ``relay.build_config``. Defaults to 2.

    Returns
    -------
    numpy.ndarray
        The module's first output.

    NOTE(review): only the parameter dict returned by ``relay.build`` is fed
    to the runtime. No other graph inputs are set, so presumably every input
    must be covered by ``params``. Confirm this for the graph under test.
    """
    mod = expr if isinstance(expr, relay.Module) else relay.Module.from_expr(expr)
    if ctx is None:
        ctx = tvm.cpu(0)
    with relay.build_config(opt_level=opt_level):
        # Use a distinct name rather than rebinding the ``params`` argument.
        # The runtime must receive the dict that relay.build returns.
        graph, lib, built_params = relay.build(
            mod, target=tvm.target.create(target), params=params
        )
    module = graph_runtime.create(graph, lib, ctx)
    module.set_input(**built_params)
    module.run()
    # asnumpy() already returns a fresh host array; no extra copy needed.
    return module.get_output(0).asnumpy()
# Bind parameters (and input), compile, run, and return the run result.
def build_and_run_with_graph(graph_func, params, input=None, pass_only=False,
                             build_only=False, target="llvm", ctx=None):
    """Build ``graph_func(params)`` with ``params`` bound as constants, then run it.

    Parameters
    ----------
    graph_func : callable
        Called as ``graph_func(params)``. Must return an expression accepted
        by ``relay.Module.from_expr``.
    params : dict
        Name -> value mapping. It is bound into ``main`` via ``bind_params``
        and also passed to ``relay.build``.
    input : array-like, optional
        If given, fed to the graph input named ``'data'``. (The name shadows
        the builtin but is kept for keyword-argument compatibility.)
    pass_only : bool
        Only run ``InferType`` and ``VaccGroupConvPatch`` on the bound module,
        print the result, and return ``None``. Nothing is compiled.
    build_only : bool
        Compile the module, print a message, and return ``None`` without
        running it.
    target : str or Target, optional
        Compilation target passed to ``relay.build``. Defaults to ``"llvm"``.
    ctx : TVMContext, optional
        Runtime device. Defaults to ``tvm.cpu(0)`` and must match ``target``.

    Returns
    -------
    numpy.ndarray or None
        Output 0 of the executed module, or ``None`` in the ``pass_only`` and
        ``build_only`` modes.

    NOTE: ``target`` and ``ctx`` used to be read from names that were never
    defined in this function. They are now explicit arguments whose defaults
    match ``build_run``.
    """
    expr = graph_func(params)
    mod = relay.Module.from_expr(expr)
    entry = bind_params(mod['main'], params)

    if pass_only:
        # Inspect the effect of the passes under test on the bound module.
        mod = relay.Module({'main': entry})
        seq = transform.Sequential([transform.InferType(), transform.VaccGroupConvPatch()])
        with transform.PassContext():
            mod = seq(mod)
        print(mod)
        print('\n')
        return None

    with relay.build_config(opt_level=2):
        graph, lib, built_params = relay.build(entry, target=target, params=params)
    if build_only:
        print('Test case module built.')
        return None

    if ctx is None:
        ctx = tvm.cpu(0)
    m = graph_runtime.create(graph, lib, ctx)
    # The runtime needs the parameter dict produced by relay.build.
    m.set_input(**built_params)
    if input is not None:
        m.set_input('data', tvm.nd.array(input))
    m.run()
    return m.get_output(0).asnumpy()
# Bind parameters, compile, run, and return the run result.
def get_tvm_output(func, x, params, target, ctx,
                   out_shape=(1, 1000), input_name='image', dtype='float32'):
    """Build ``func`` at opt_level 3, feed ``x`` as ``input_name``, return output 0.

    Parameters
    ----------
    func : Relay function or module accepted by ``relay.build``
    x : numpy.ndarray
        Input data. It is cast to ``dtype`` before being fed.
    params : dict
        Parameter values passed to ``relay.build``.
    target, ctx
        Compilation target and runtime device.
    out_shape : tuple
        Shape of the preallocated output buffer.
    input_name : str
        Name of the graph input that receives ``x``.
    dtype : str
        Dtype used for both ``x`` and the output buffer.

    Returns
    -------
    numpy.ndarray
        Output 0, copied into an ``out_shape`` / ``dtype`` buffer.
    """
    with relay.transform.build_config(opt_level=3):
        graph, lib, compiled_params = relay.build(func, target, params=params)
    runtime = graph_runtime.create(graph, lib, ctx)
    runtime.set_input(input_name, tvm.nd.array(x.astype(dtype)))
    runtime.set_input(**compiled_params)
    runtime.run()
    output_buffer = tvm.nd.empty(out_shape, dtype)
    return runtime.get_output(0, output_buffer).asnumpy()