1. 通过 relay.create_executor 来执行 Relay 表达式
def test_any_concat():
    """Check the Relay concatenate operator against NumPy.

    Builds a Relay function that concatenates ``x`` (first dimension declared
    as ``relay.Any()``) with ``y`` (shape ``(1, 2)``) along axis 0, runs it
    through an executor created by ``relay.create_executor``, and asserts that
    the result matches ``np.concatenate`` on the same random inputs.
    """
    # Symbolic inputs: the row count of `x` is left dynamic, `y` is one row.
    x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
    y = relay.var('y', shape=(1, 2), dtype="float32")
    joined = relay.op.concatenate([x, y], axis=0)

    # Wrap the expression in a function and build the module from it.
    # (The original notes also kept a disabled alternative: create an empty
    # Module and assign the function to mod["main"].)
    mod = relay.module.Module.from_expr(relay.Function([x, y], joined))

    # Concrete inputs and the NumPy reference result.
    x_np = np.random.uniform(size=(3, 2)).astype('float32')
    y_np = np.random.uniform(size=(1, 2)).astype('float32')
    expected = np.concatenate([x_np, y_np], axis=0)
    print("\n----------------------------\nexpected result:")
    print(expected)

    for kind in ["debug"]:
        # Create the executor for this kind on the CPU with the llvm target.
        executor = relay.create_executor(
            kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        # evaluate() yields a callable; invoking it with the NumPy inputs
        # executes the expression and returns the result.
        evaluate_fn = executor.evaluate()
        actual = evaluate_fn(x_np, y_np).asnumpy()
        print("\n----------------------------\nreal result:")
        print(actual)
        # Compare the executor output with the NumPy reference.
        tvm.testing.assert_allclose(actual, expected)
2. 采用 build + run 获得运行结果
def tvm_run(input_val, params, inputs, level=2):
    """Build and run a Relay expression, returning its first output.

    A workaround for obtaining the concrete value of an expression by
    compiling and executing just that portion of the Relay graph. This is
    typically needed when an output shape depends on the value of a tensor.

    Parameters
    ----------
    input_val
        The Relay expression to evaluate. Its free variables become the
        parameters of the function that is built.
    params
        Forwarded to ``tvm.relay.build`` as ``params``.
    inputs
        Keyword arguments for the runtime module's ``set_input``, or ``None``
        to skip setting explicit inputs.
    level
        Optimization level passed to ``tvm.relay.build_config``.

    Returns
    -------
    The value returned by ``get_output(0)`` on the graph runtime module.
    """
    import tvm
    from tvm.relay import expr as _expr
    from tvm.relay import analysis
    from tvm.contrib import graph_runtime

    divider = "\n------------------------------------------------------------"

    # NOTE: the original notes kept a disabled check asserting that every free
    # variable of `input_val` has an entry in `params`; it is not enforced.
    free_variables = analysis.free_vars(input_val)
    wrapped = _expr.Function(free_variables, input_val)
    print(divider)
    print(wrapped.astext(show_meta_data=False))

    # CPU/llvm is used here; the original notes also kept a disabled variant
    # targeting tvm.target.vacc() / tvm.vacc(0).
    target = "llvm"
    ctx = tvm.cpu(0)

    with tvm.relay.build_config(opt_level=level):
        graph, lib, built_params = tvm.relay.build(
            wrapped, target=target, params=params)

    print(divider)
    print(built_params)

    runtime_module = graph_runtime.create(graph, lib, ctx)
    if inputs is not None:
        runtime_module.set_input(**inputs)
    # The parameters returned by the build are always loaded.
    runtime_module.set_input(**built_params)
    runtime_module.run()

    output = runtime_module.get_output(0)
    print(divider)
    print(output)
    return output