下面三个函数定义在 `python/tvm/relay/testing/__init__.py` 中:
    relay.testing.run_opt_pass
    relay.testing.run_infer_type
    relay.testing.rand

    # Run a single Relay pass over an expression (TVM upstream test helper).
    def run_opt_pass(expr, opt_pass):
        """Apply ``opt_pass`` to ``expr`` and return the transformed result.

        ``expr`` is wrapped into a module via ``relay.Module.from_expr`` and
        the pass is run on that module. If ``expr`` is a ``relay.Function``
        the transformed ``main`` function is returned; otherwise only the
        body of ``main`` is returned.
        """
        assert isinstance(opt_pass, transform.Pass)
        transformed = opt_pass(relay.Module.from_expr(expr))
        main_func = transformed["main"]
        if isinstance(expr, relay.Function):
            return main_func
        return main_func.body
    # Type inference helper (TVM upstream test helper).
    def run_infer_type(expr):
        """Return ``expr`` after running ``transform.InferType`` on it.

        Delegates to ``run_opt_pass``, so a ``relay.Function`` input yields
        a function and any other expression yields the inferred body.
        """
        infer_type_pass = transform.InferType()
        return run_opt_pass(expr, infer_type_pass)
    # Random test data with a given dtype and shape (TVM upstream test helper).
    def rand(dtype, *shape):
        """Return a ``tvm.nd`` array of the given ``shape`` with random values.

        Values come from ``np.random.rand`` (uniform in [0, 1)) and are then
        cast with ``astype(dtype)``; integer dtypes therefore truncate every
        sample to 0.

        Calling with no shape (a scalar tensor) is supported: ``np.random.rand()``
        returns a plain Python float, which has no ``astype``, so the sample is
        first wrapped into a 0-d NumPy array.
        """
        samples = np.asarray(np.random.rand(*shape))
        return tvm.nd.array(samples.astype(dtype))

    下面是一些自定义的函数:

    # Bind values from `params` to the same-named parameters of `func`.
    def bind_params(func, params):
        """Return ``func`` with some parameters replaced by constants.

        Each key of ``params`` is matched against the ``name_hint`` of the
        parameters of ``func``; matching parameters are bound to
        ``tvm.relay.expr.const(value)`` via ``tvm.relay.expr.bind``. Keys that
        match no parameter are silently ignored.

        Raises:
            ValueError: if a key of ``params`` matches the name of more than
                one parameter of ``func``.
        """
        # name -> Var, or None when several parameters share that name.
        var_by_name = {}
        for var in func.params:
            hint = var.name_hint
            var_by_name[hint] = None if hint in var_by_name else var

        bindings = {}
        for key, value in params.items():
            if key in var_by_name:
                target_var = var_by_name[key]
                if target_var is None:
                    raise ValueError("Multiple args in the function have name %s" % key)
                bindings[target_var] = tvm.relay.expr.const(value)
        return tvm.relay.expr.bind(func, bindings)
    # Bind parameters as constants, then run a pass (for testing passes whose
    # behavior depends on bound parameter values).
    # NOTE: if placed in the same module, this shadows the two-argument
    # relay.testing.run_opt_pass; unlike that helper it returns the whole
    # module, not the main function or its body.
    def run_opt_pass(expr, opt_pass, params=None):
        """Bind ``params`` into ``expr`` as constants and run ``opt_pass``.

        Args:
            expr: Relay expression or function; it becomes the module's ``main``.
            opt_pass: The ``transform.Pass`` to run.
            params: Optional mapping from parameter name to value. Names that
                match no parameter of ``main`` are ignored (see ``bind_params``).
                ``None`` or an empty mapping skips binding.

        Returns:
            The ``relay.Module`` returned by ``opt_pass``.
        """
        assert isinstance(opt_pass, transform.Pass)
        mod = relay.Module.from_expr(expr)
        if params:
            # Rebuild the module so that main is the function whose matching
            # parameters have been replaced by constants.
            mod = relay.Module.from_expr(bind_params(mod["main"], params))
        return opt_pass(mod)
    # Compile `expr` with its parameters, run it once and return output 0.
    def build_run(expr, params, target="llvm", ctx=None, opt_level=2):
        """Compile and run ``expr``; return its first output as a NumPy array.

        Args:
            expr: A ``relay.Module``, or a Relay expression/function that is
                wrapped into one with ``relay.Module.from_expr``.
            params: Mapping from parameter name to value, passed to
                ``relay.build``.
            target: Compilation target; a string is converted with
                ``tvm.target.create``. Defaults to ``"llvm"``.
            ctx: Runtime context. Defaults to ``tvm.cpu(0)``, which only suits
                CPU targets; pass a matching context for other targets.
            opt_level: Optimization level used in ``relay.build_config``.

        Returns:
            Output 0 of the executed graph, as returned by ``asnumpy()``.
        """
        mod = expr if isinstance(expr, relay.Module) else relay.Module.from_expr(expr)
        if isinstance(target, str):
            target = tvm.target.create(target)
        if ctx is None:
            ctx = tvm.cpu(0)
        with relay.build_config(opt_level=opt_level):
            graph, lib, built_params = relay.build(mod, target=target, params=params)
        module = graph_runtime.create(graph, lib, ctx)
        module.set_input(**built_params)
        module.run()
        # asnumpy() already produces a NumPy array, so no extra copy is needed.
        return module.get_output(0).asnumpy()
    # Build a graph with `graph_func`, bind its parameters, then either only
    # run passes (pass_only), only compile (build_only), or compile and run it.
    def build_and_run_with_graph(graph_func, params, input=None, pass_only=False,
                                 build_only=False, target="llvm", ctx=None,
                                 input_name="data"):
        """Construct, transform or compile, and optionally execute a Relay graph.

        ``graph_func(params)`` is expected to return a Relay expression; the
        values in ``params`` are then bound into its ``main`` function as
        constants via ``bind_params``.

        Args:
            graph_func: Callable taking ``params`` and returning a Relay expr.
            params: Mapping from parameter name to value.
            input: Optional array fed to the graph input ``input_name``
                (the name shadows the builtin but is kept for callers).
            pass_only: If true, run ``InferType`` and ``VaccGroupConvPatch``
                on the bound function, print the module and return ``None``.
            build_only: If true, compile with ``relay.build`` (opt_level=2),
                print a message and return ``None``.
            target: Compilation target for ``relay.build`` (previously read
                from an undefined global). Defaults to ``"llvm"``.
            ctx: Runtime context. Defaults to ``tvm.cpu(0)``, which only suits
                CPU targets.
            input_name: Graph input that receives ``input``. Defaults to
                ``"data"``.

        Returns:
            ``None`` in the ``pass_only``/``build_only`` modes; otherwise
            output 0 of the executed graph as a NumPy array.
        """
        expr = graph_func(params)
        entry = relay.Module.from_expr(expr)["main"]
        entry = bind_params(entry, params)

        if pass_only:
            mod = relay.Module({'main': entry})
            # NOTE(review): VaccGroupConvPatch is presumably a project-local
            # pass, not part of upstream TVM — confirm it is registered.
            seq = transform.Sequential([transform.InferType(),
                                        transform.VaccGroupConvPatch()])
            with transform.PassContext():
                mod = seq(mod)
            print(mod)
            print('\n')
            return None

        with relay.build_config(opt_level=2):
            graph, lib, built_params = relay.build(entry,
                                                   target=target,
                                                   params=params)
        if build_only:
            print('Test case module built.')
            return None

        if ctx is None:
            ctx = tvm.cpu(0)
        m = graph_runtime.create(graph, lib, ctx)
        # Parameters first, then the user input, as in the original flow.
        m.set_input(**built_params)
        if input is not None:
            m.set_input(input_name, tvm.nd.array(input))
        m.run()
        return m.get_output(0).asnumpy()
    # Compile `func`, feed `x` plus the parameters, run, and return output 0.
    def get_tvm_output(func, x, params, target, ctx,
                       out_shape=(1, 1000), input_name='image', dtype='float32'):
        """Compile ``func`` at opt_level 3, run it on ``x``, return output 0.

        Args:
            func: Relay function or module to compile with ``relay.build``.
            x: NumPy array; cast to ``dtype`` and fed to ``input_name``.
            params: Mapping from parameter name to value, passed to
                ``relay.build``.
            target: Compilation target.
            ctx: Runtime context matching ``target``.
            out_shape: Shape of the pre-allocated output buffer (allocated
                with ``dtype``). ``None`` skips pre-allocation and returns the
                output with whatever shape and dtype the graph produces.
            input_name: Name of the graph input that receives ``x``.
            dtype: dtype used for ``x`` and for the pre-allocated output.

        Returns:
            Output 0 of the executed graph as a NumPy array.
        """
        with relay.build_config(opt_level=3):
            graph, lib, built_params = relay.build(func, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
        m.set_input(**built_params)
        m.run()
        if out_shape is None:
            return m.get_output(0).asnumpy()
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()