1. 通过 `relay.create_executor` 执行 Relay 表达式
def test_any_concat():
    """Check relay.op.concatenate when one input has a dynamic (Any) first dim.

    A (?, 2) float32 tensor is concatenated with a (1, 2) float32 tensor along
    axis 0. The Relay function is executed through ``relay.create_executor``
    and its output is compared against ``np.concatenate`` on the same data.
    Both the expected and the actual result are printed for inspection.
    """
    # Relay graph: the first operand has an unknown number of rows.
    lhs = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
    rhs = relay.var('y', shape=(1, 2), dtype="float32")
    concat_expr = relay.op.concatenate([lhs, rhs], axis=0)
    module = relay.module.Module.from_expr(relay.Function([lhs, rhs], concat_expr))

    # Concrete data: 3 rows for the dynamically-shaped operand, 1 for the other.
    lhs_data = np.random.uniform(size=(3, 2)).astype('float32')
    rhs_data = np.random.uniform(size=(1, 2)).astype('float32')
    expected = np.concatenate([lhs_data, rhs_data], axis=0)

    print("\n----------------------------\nexpected result:")
    print(expected)

    for executor_kind in ("debug",):
        executor = relay.create_executor(
            executor_kind, mod=module, ctx=tvm.cpu(), target="llvm")
        # evaluate() is called without an expression; presumably this runs the
        # module's main function -- confirm against the TVM version in use.
        result = executor.evaluate()(lhs_data, rhs_data)
        print("\n----------------------------\nreal result:")
        actual = result.asnumpy()
        print(actual)
        tvm.testing.assert_allclose(actual, expected)
2. 采用 `relay.build` 编译 + graph runtime 运行，获取运行结果
def tvm_run(input_val, params, inputs=None, level=2, target="llvm", ctx=None):
    """Evaluate a Relay expression by compiling it and running the result.

    This is a hack for getting the value of an expression by evaluating a
    portion of the relay graph. It is often needed for functions whose output
    shape depends on the value of a tensor.

    All free variables of ``input_val`` become the parameters of a wrapping
    Relay function. That function is compiled with ``tvm.relay.build`` and
    executed with the graph runtime. The function text, the parameters
    returned by the build, and the output are printed for inspection.

    Parameters
    ----------
    input_val : relay.Expr
        The expression to evaluate.
    params : dict or None
        Forwarded unchanged to ``tvm.relay.build(params=...)``.
    inputs : dict, optional
        Name -> value mapping passed to the runtime module's ``set_input``
        before the built parameters are set. Skipped when None.
    level : int
        Optimization level passed to ``tvm.relay.build_config(opt_level=...)``.
    target : str or target object
        Compilation target; defaults to "llvm", the previously hard-coded value.
        A custom target (e.g. ``tvm.target.vacc()``, if available in your TVM
        build) can be supplied together with a matching ``ctx``.
    ctx : device context, optional
        Device the graph runs on; defaults to ``tvm.cpu(0)``. It is the
        caller's responsibility to make it match ``target``.

    Returns
    -------
    The value returned by ``get_output(0)`` on the graph runtime module.
    """
    import tvm
    from tvm.relay import expr as _expr
    from tvm.relay import analysis
    from tvm.contrib import graph_runtime

    if ctx is None:
        ctx = tvm.cpu(0)

    # Free variables are deliberately not required to be present in
    # ``params``: any of them may instead be fed at run time via ``inputs``.
    func = _expr.Function(analysis.free_vars(input_val), input_val)
    print("\n------------------------------------------------------------")
    print(func.astext(show_meta_data=False))

    with tvm.relay.build_config(opt_level=level):
        # Keep the build's parameters separate from the caller's ``params``
        # argument instead of silently shadowing it.
        graph, lib, bound_params = tvm.relay.build(
            func, target=target, params=params)
    print("\n------------------------------------------------------------")
    print(bound_params)

    m = graph_runtime.create(graph, lib, ctx)
    if inputs is not None:
        m.set_input(**inputs)
    # Set after ``inputs``, so built parameters win on a name clash.
    m.set_input(**bound_params)
    m.run()
    out = m.get_output(0)
    print("\n------------------------------------------------------------")
    print(out)
    return out