1. 通过 relay.create_executor 执行 Relay 表达式

  def test_any_concat():
      """Test the concatenate operator with an input whose first dim is relay.Any().

      Concatenates x (declared shape (relay.Any(), 2)) and y (shape (1, 2))
      along axis 0, runs the Relay module through relay.create_executor,
      prints the NumPy reference and the executor's result, and asserts
      that the two match.
      """
      lhs = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
      rhs = relay.var('y', shape=(1, 2), dtype="float32")
      concat = relay.op.concatenate([lhs, rhs], axis=0)
      mod = relay.module.Module.from_expr(relay.Function([lhs, rhs], concat))

      # Concrete inputs: 3 rows for x (its relay.Any() dimension), 1 row for y.
      lhs_data = np.random.uniform(size=(3, 2)).astype('float32')
      rhs_data = np.random.uniform(size=(1, 2)).astype('float32')
      expected = np.concatenate([lhs_data, rhs_data], axis=0)
      print("\n----------------------------\nexpected result:")
      print(expected)

      # Only the "debug" executor kind is exercised here.
      for executor_kind in ("debug",):
          # Create an executor for the module on CPU with the llvm target,
          # evaluate it on the inputs, and compare against the reference.
          executor = relay.create_executor(
              executor_kind, mod=mod, ctx=tvm.cpu(), target="llvm")
          actual = executor.evaluate()(lhs_data, rhs_data).asnumpy()
          print("\n----------------------------\nreal result:")
          print(actual)
          tvm.testing.assert_allclose(actual, expected)

2. 采用 relay.build 编译 + graph_runtime 运行获得结果

  def tvm_run(input_val, params, inputs, level=2, target="llvm", ctx=None):
      """Build and run a Relay expression, returning its first output.

      A hack for getting the value of an expression by evaluating a
      portion of the relay graph. This is often needed for functions
      whose output shape depends on the value of a tensor.

      The expression is wrapped in a Relay function whose parameters are
      its free variables, compiled with ``relay.build`` and executed with
      the graph runtime.

      Args:
          input_val: Relay expression to evaluate.
          params: dict mapping free-variable names to values bound at build
              time (forwarded to ``relay.build`` as ``params``); may be None.
          inputs: dict mapping free-variable names to values fed to the
              runtime module via ``set_input``, or None.
          level: optimization level passed to ``relay.build_config``.
          target: compilation target passed to ``relay.build``
              (default "llvm", the previously hard-coded value).
          ctx: device context for the graph runtime. Defaults to
              ``tvm.cpu(0)``; pass a matching context explicitly when
              ``target`` is not a CPU target.

      Returns:
          The first output of the graph, as returned by the runtime
          module's ``get_output(0)``.

      Raises:
          ValueError: if a free variable of ``input_val`` is supplied by
              neither ``params`` nor ``inputs`` (it would otherwise never be
              passed to ``set_input`` before the graph runs).
      """
      import tvm
      from tvm.relay import expr as _expr
      from tvm.relay import analysis
      from tvm.contrib import graph_runtime

      free_vars = analysis.free_vars(input_val)
      # Every free variable must be fed either at build time (params) or at
      # run time (inputs); fail loudly instead of running on an unset input.
      provided = set(params or {}) | set(inputs or {})
      missing = [var.name_hint for var in free_vars
                 if var.name_hint not in provided]
      if missing:
          raise ValueError(
              "tvm_run: no value supplied for free variable(s) %s; "
              "pass them via params or inputs." % ", ".join(missing))

      func = _expr.Function(free_vars, input_val)
      print("\n------------------------------------------------------------")
      print(func.astext(show_meta_data=False))

      if ctx is None:
          ctx = tvm.cpu(0)
      with tvm.relay.build_config(opt_level=level):
          # relay.build returns its own params dict for the compiled graph;
          # keep it separate from the caller's ``params`` argument.
          graph, lib, built_params = tvm.relay.build(
              func, target=target, params=params)

      print("\n------------------------------------------------------------")
      print(built_params)
      module = graph_runtime.create(graph, lib, ctx)
      if inputs is not None:
          module.set_input(**inputs)
      module.set_input(**built_params)
      module.run()
      out = module.get_output(0)
      print("\n------------------------------------------------------------")
      print(out)
      return out