下面三个是定义在 `python/tvm/relay/testing/__init__.py` 中的函数：
relay.testing.run_opt_pass
relay.testing.run_infer_type
relay.testing.rand
def run_opt_pass(expr, opt_pass):
    """Run ``opt_pass`` on ``expr`` and return the transformed result.

    ``expr`` is wrapped into a fresh module via ``relay.Module.from_expr``
    and the pass is applied to that module. If ``expr`` was itself a
    ``relay.Function``, the transformed ``main`` function is returned;
    otherwise only the body of the transformed ``main`` is returned.
    """
    assert isinstance(opt_pass, transform.Pass)
    transformed = opt_pass(relay.Module.from_expr(expr))
    main_fn = transformed["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body


def run_infer_type(expr):
    """Run Relay type inference (``transform.InferType``) on ``expr``."""
    return run_opt_pass(expr, transform.InferType())


def rand(dtype, *shape):
    """Return a ``tvm.nd.array`` of the given shape with random values cast to ``dtype``.

    Values come from ``np.random.rand`` (uniform on [0, 1)) before the cast.
    """
    samples = np.random.rand(*shape)
    return tvm.nd.array(samples.astype(dtype))
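下面是一些自定义的函数。注意：其中的 `run_opt_pass` 与上面 TVM 自带的同名函数签名和返回值都不同（多了 `params` 参数，返回的是整个 `Module` 而不是表达式），两者不要放在同一个模块中，否则后定义的会覆盖先定义的。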
def bind_params(func, params):
    """Bind named parameter values into a Relay function as constants.

    Each key of ``params`` is matched against the ``name_hint`` of the
    function's parameters; every matching parameter is mapped to
    ``tvm.relay.expr.const(value)`` and the mapping is applied with
    ``tvm.relay.expr.bind``. Keys that match no parameter are ignored.

    Args:
        func: the ``relay.Function`` whose parameters should be bound.
        params: mapping from parameter name to a value accepted by
            ``tvm.relay.expr.const``.

    Returns:
        The result of ``tvm.relay.expr.bind(func, ...)``.

    Raises:
        ValueError: if a key in ``params`` names more than one function
            parameter, because the binding would be ambiguous.
    """
    name_dict = {}
    for arg in func.params:
        name = arg.name_hint
        # A None entry marks a name shared by several parameters.
        name_dict[name] = None if name in name_dict else arg
    bind_dict = {}
    for k, v in params.items():
        if k not in name_dict:
            continue
        arg = name_dict[k]
        if arg is None:
            raise ValueError("Multiple args in the function have name %s" % k)
        bind_dict[arg] = tvm.relay.expr.const(v)
    return tvm.relay.expr.bind(func, bind_dict)


def run_opt_pass(expr, opt_pass, params=None):
    """Optionally bind ``params`` into ``expr``, then run ``opt_pass`` on it.

    Useful for testing passes that need to see parameter values as
    constants.

    Args:
        expr: Relay expression wrapped into a module via
            ``relay.Module.from_expr``.
        opt_pass: the ``transform.Pass`` to run.
        params: optional mapping from parameter name to value. When
            ``None``, no binding is performed (previously this crashed).

    Returns:
        The transformed ``relay.Module`` (the whole module, not the
        expression).
    """
    assert isinstance(opt_pass, transform.Pass)
    mod = relay.Module.from_expr(expr)
    if params is not None:
        entry = bind_params(mod["main"], params)
        mod = relay.Module.from_expr(entry)
    return opt_pass(mod)


def build_run(expr, params):
    """Compile ``expr`` for LLVM, run it on CPU 0, and return output 0.

    Args:
        expr: a ``relay.Module``, or an expression accepted by
            ``relay.Module.from_expr``.
        params: mapping from parameter name to value, passed to
            ``relay.build``.

    Returns:
        Output 0 of the executed graph as a NumPy array.
    """
    mod = expr if isinstance(expr, relay.Module) else relay.Module.from_expr(expr)
    target = tvm.target.create("llvm")
    ctx = tvm.cpu(0)
    with relay.build_config(opt_level=2):
        graph, lib, built_params = relay.build(mod, target=target, params=params)
    module = graph_runtime.create(graph, lib, ctx)
    module.set_input(**built_params)
    module.run()
    tvm_output = module.get_output(0)
    return np.array(tvm_output.asnumpy())


def build_and_run_with_graph(graph_func, params, input=None, pass_only=False,
                             build_only=False, target=None, ctx=None):
    """Build a graph with ``graph_func``, bind ``params``, then compile and run it.

    Args:
        graph_func: callable that takes ``params`` and returns a Relay
            expression.
        params: mapping from parameter name to value; bound into ``main``
            as constants and also passed to ``relay.build``.
        input: optional array fed to the graph input named ``'data'``.
            (The name shadows the builtin but is kept for keyword callers.)
        pass_only: if true, only run ``InferType`` followed by
            ``VaccGroupConvPatch``, print the resulting module, and return
            ``None``.
        build_only: if true, stop right after ``relay.build`` and return
            ``None``.
        target: compilation target; defaults to LLVM (the original read an
            undefined name here).
        ctx: runtime device context; defaults to ``tvm.cpu(0)`` (the
            original read an undefined name here).

    Returns:
        Output 0 as a NumPy array, or ``None`` when ``pass_only`` or
        ``build_only`` is set.
    """
    expr = graph_func(params)
    mod = relay.Module.from_expr(expr)
    entry = bind_params(mod['main'], params)

    if pass_only:
        mod = relay.Module({'main': entry})
        seq = transform.Sequential([transform.InferType(),
                                    transform.VaccGroupConvPatch()])
        with transform.PassContext():
            mod = seq(mod)
        print(mod)
        print('\n')
        return None

    if target is None:
        target = tvm.target.create("llvm")
    with relay.build_config(opt_level=2):
        graph, lib, built_params = relay.build(entry, target=target, params=params)
    if build_only:
        print('Test case module built.')
        return None

    if ctx is None:
        ctx = tvm.cpu(0)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**built_params)
    if input is not None:
        m.set_input('data', tvm.nd.array(input))
    m.run()
    return m.get_output(0).asnumpy()


def get_tvm_output(func, x, params, target, ctx, out_shape=(1, 1000),
                   input_name='image', dtype='float32'):
    """Compile ``func`` at opt_level 3, run it on ``x``, and return output 0.

    Args:
        func: Relay function or module to compile.
        x: input array (must support ``.astype``, e.g. NumPy); it is cast
            to ``dtype`` and fed to the graph input ``input_name``.
        params: mapping from parameter name to value, passed to
            ``relay.build``.
        target: compilation target.
        ctx: runtime device context.
        out_shape: shape of the buffer that output 0 is copied into.
        input_name: name of the graph input set from ``x``.
        dtype: dtype used for both the input cast and the output buffer.

    Returns:
        Output 0 as a NumPy array of shape ``out_shape``.
    """
    # relay.build_config for consistency with the other helpers here.
    with relay.build_config(opt_level=3):
        graph, lib, built_params = relay.build(func, target, params=params)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
    m.set_input(**built_params)
    m.run()
    out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
    return out.asnumpy()
