Files related to build and run

Build-related

Note: tvm.build and tvm.relay.build are different.
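A minimal sketch of the difference, assuming the TVM ~v0.6-era API used throughout this note: tvm.build lowers a single tensor-expression schedule, while tvm.relay.build compiles a whole Relay module (graph plus weights).

    import tvm

    # tvm.build: one operator, described as a tensor expression plus a schedule
    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], target="llvm")   # -> a runtime.Module holding one kernel

    # tvm.relay.build: a whole model in Relay IR (see the walkthrough below)
    # graph, lib, params = tvm.relay.build(mod, target="llvm", params=params)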

  • Python

File path: python/tvm/relay/build_module.py
Class: BuildModule
Functions: build, optimize

  • C++

File path: src/relay/backend/build_module.cc
Class: RelayBuildModule
Functions: Build, Optimize

Run-related

  • Python

File path: python/tvm/contrib/graph_runtime.py
Class: GraphModule
Function: run

  • C++

File path: src/runtime/graph/graph_runtime.cc
Class: GraphRuntime

Compilation and execution process

The following walks through relay_quick_start.py to explain Relay's compilation and execution process in detail.

1 Load the network model and parameters

Source:

    batch_size = 1
    num_class = 1000
    image_shape = (3, 224, 224)
    data_shape = (batch_size,) + image_shape
    out_shape = (batch_size, num_class)
    mod, params = relay.testing.resnet.get_workload(
        num_layers=18, batch_size=batch_size, image_shape=image_shape)
    # set show_meta_data=True if you want to show meta data
    print(mod.astext(show_meta_data=False))

Notes:
mod : tvm.relay.Module
    The Relay module that contains a ResNet network.
params : dict of str to NDArray
    The parameters.
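As a quick sanity check, both objects can be inspected directly (a sketch using the names from the snippet above):

    print(mod["main"])                               # the top-level Relay function
    print({k: v.shape for k, v in params.items()})   # parameter name -> shape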

2 Compile

Source:

    opt_level = 3
    target = tvm.target.cuda()
    with relay.build_config(opt_level=opt_level):
        graph, lib, params = relay.build_module.build(mod, target, params=params)
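If needed, the three build artifacts can be saved for later deployment. A minimal sketch, assuming the same ~v0.6-era API (the file names are arbitrary):

    lib.export_library("deploy_lib.tar")             # compiled operators
    with open("deploy_graph.json", "w") as f:
        f.write(graph)                               # execution graph in JSON
    with open("deploy_params.bin", "wb") as f:
        f.write(relay.save_param_dict(params))       # weights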

The call chain is as follows (a Python-side sketch of the first two hops follows the list):
[python] relay.build = relay.build_module.build
[python] BuildModule.build
[c++] RelayBuildModule::Build - Build relay function for graph runtime
[c++] RelayBuildModule::BuildRelay - Compile a Relay function to runtime module.
[c++] RelayBuildModule::Optimize - Optimize input Relay Function and returns Relay Module; runs optimization passes over the Relay IR, e.g. operator fusion (FuseOps).
[c++] GraphCodegen::Codegen - Generate code for the updated function. The main work is Relay IR -> TE expression -> schedule -> TIR; external codegen is also invoked here.
[c++] tvm::build - Build for heterogeneous execution. This mainly lowers TIR to LLVM IR.
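The first two hops can be reproduced by using BuildModule directly; its methods are PackedFuncs bound to the C++ RelayBuildModule. A sketch, again assuming the ~v0.6-era API:

    from tvm import relay
    from tvm.relay import build_module

    bm = build_module.BuildModule()   # wraps the C++ RelayBuildModule
    with relay.build_config(opt_level=3):
        graph_json, lib, params = bm.build(mod, target, params=params)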

3 Run

Source:

    # create random input
    ctx = tvm.gpu()
    data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
    # create module - initialize device info and allocate memory (SRAM, DDR)
    module = graph_runtime.create(graph, lib, ctx)
    # set input and parameters - store data at the proper device locations: data -> SRAM, weights -> DDR
    module.set_input("data", data)
    module.set_input(**params)
    # run - start the computation
    module.run()
    # get output
    out = module.get_output(0, tvm.nd.empty(out_shape)).asnumpy()
    # print first 10 elements of the output
    print(out.flatten()[0:10])

Note: for how to choose target and ctx, refer to the discussion of target and ctx in "TVM build and run".
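The rule of thumb: target tells the compiler what code to generate, ctx tells the runtime where to place data and run, and the two must name the same device class. A sketch of typical pairs (~v0.6 API):

    import tvm

    # compile-time target            matching runtime ctx
    pairs = [
        ("llvm",            tvm.cpu(0)),   # CPU: LLVM codegen
        (tvm.target.cuda(), tvm.gpu(0)),   # NVIDIA GPU: CUDA codegen
    ]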

Function call chain:
graph_runtime.create(graph, lib, ctx) creates a module that the runtime can execute, i.e. a GraphModule.
Part of the graph_runtime.create source:

    if num_rpc_ctx == len(ctx):
        # all contexts are remote: fetch the creator from the RPC session so the
        # GraphRuntime is constructed on the remote device
        # hmod = rpc_base._ModuleHandle(libmod)
        fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
        # return GraphModule(fcreate(graph_json_str, hmod, name, *device_type_id))
    else:
        # local contexts: look the creator up in the global function registry
        fcreate = get_global_func("tvm.graph_runtime.create")
    return GraphModule(fcreate(graph_json_str, libmod, name, *device_type_id))

fcreate is in fact a function handle: it corresponds to the C++ function registered as "tvm.graph_runtime.create", which is registered as follows:

    TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      CHECK_GE(args.num_args, 4)
          << "The expected number of arguments for graph_runtime.create is "
             "at least 4, but it has "
          << args.num_args;
      const auto& contexts = GetAllContext(args);
      *rv = GraphRuntimeCreate(args[0], args[1], args[2], contexts);
    });
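Because the creator is just a registered global, it can also be fetched and called by hand from Python. A sketch following the vendored signature shown above, i.e. (json, lib, name, device_type, device_id, ...); note that stock TVM's tvm.graph_runtime.create takes no name argument, and the model name here is a made-up placeholder:

    import tvm

    fcreate = tvm.get_global_func("tvm.graph_runtime.create")
    gmod = fcreate(graph, lib, "resnet18", 1, 0)   # 1 = kDLCPU, device id 0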

Its return value is whatever the call to GraphRuntimeCreate returns:

    Module GraphRuntimeCreate(const std::string& sym_json,
                              const tvm::runtime::Module& m,
                              const std::string& name,
                              const std::vector<TVMContext>& ctxs) {
      auto exec = make_object<GraphRuntime>();
      exec->Init(sym_json, m, ctxs, name);
      return Module(exec);
    }

Note that GraphRuntimeCreate constructs a GraphRuntime and initializes it (inside GraphRuntime::Init, functions such as VaccGetDeviceCount, VaccGetDeviceIDs, VaccSetDevice, vaccrt_create_context, vaccrt_add_model, and vaccrt_begin_gen_model are called). It ultimately returns a C++ Module, the counterpart of the Python-side Module, i.e. a module that exposes a set of functions.

    void GraphRuntime::Init(const std::string& graph_json,
                            tvm::runtime::Module module,
                            const std::vector<TVMContext>& ctxs,
                            const std::string& name) {
    #ifndef _LIBCPP_SGX_NO_IOSTREAMS
      std::istringstream is(graph_json);
    #else
      std::string is = graph_json;
    #endif
      m_grt = this;
      dmlc::JSONReader reader(&is);
      this->Load(&reader);
      module_ = module;
      ctxs_ = ctxs;
      const auto& cit =
          std::find_if(ctxs_.begin(), ctxs_.end(), [](const TVMContext& c) {
            return kDLVacc == static_cast<int>(c.device_type);
          });
      if (cit != ctxs_.end()) {  // device is Vacc
        // Get the total number of devices
        uint32_t device_count = 0;
        uint32_t* p_count = &device_count;
        VACC_MEM_CALL(VaccGetDeviceCount(p_count));
        device_count = *p_count;
        VACC_LOG_DEBUG("Get the total number of devices, count=%u", device_count);
        // Get all device ids
        uint64_t* id_array = new uint64_t[device_count];
        VACC_MEM_CALL(VaccGetDeviceIDs(id_array, device_count));
        VACC_LOG_INFO("Get all devices id.");
        for (uint32_t i = 0; i < device_count; i++) {
          VACC_LOG_INFO("Device id list: %lu", *(id_array + i));
        }
        delete[] id_array;
        // Set device
        VACC_MEM_CALL(VaccSetDevice(static_cast<uint64_t>((*cit).device_id)));
        VACC_LOG_INFO("Choose a device by index.");
        dev_ctx_t ctx;
        INT8* modelName = (INT8*)name.c_str();
        module_name_ = modelName;
        dev_id = static_cast<UINT8>((*cit).device_id);
        ctx.next = nullptr;
        // call the device_runtime .so to create the device context
        VACC_LOG_INFO("Begin to create device context...");
        VACC_CALL(vaccrt_create_context(&ctx, dev_id));
        // add model
        VACC_LOG_INFO("Begin to add model...");
        vaccrt_add_model(&ctx, dev_id, modelName);
        // set model
        VACC_LOG_INFO("Begin to set model...");
        VACC_CALL(vaccrt_set_model(&ctx, dev_id, modelName));
        // init the selected model
        VACC_LOG_INFO("Begin to init the selected model...");
        VACC_CALL(vaccrt_begin_gen_model());
      }
      this->SetupStorage();
      this->SetupOpExecs();
      for (size_t i = 0; i < input_nodes_.size(); i++) {
        const uint32_t nid = input_nodes_[i];
        std::string& name = nodes_[nid].name;
        input_map_[name] = i;
      }
      this->GenerateCsr();
    }
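One detail worth noting: the loop near the end of Init builds input_map_ (input name to input index). This table is what makes name-based input binding from Python work, e.g. in the run example above:

    module.set_input("data", data)   # "data" is resolved through input_map_
    module.set_input(**params)       # each weight name is resolved the same way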

Now look back at return GraphModule(fcreate(graph_json_str, libmod, name, *device_type_id)): it wraps the result in a GraphModule object.
Next, look at GraphModule's initializer:

    def __init__(self, module):
        self.module = module
        self._set_input = module["set_input"]
        self._run = module["run"]
        self._get_output = module["get_output"]
        self._get_input = module["get_input"]
        self._get_num_outputs = module["get_num_outputs"]
        self._load_params = module["load_params"]
        self._share_params = module["share_params"]
        self._get_memory_address = module["get_memory_address"]

This shows that the Module created on the C++ side is bound to the GraphModule; the Python-side module is simply a wrapper around it. Consequently,
module.set_input actually calls GraphRuntime::SetInput on the C++ side,
module.run actually calls GraphRuntime::Run, and
module.get_output actually calls GraphRuntime::GetOutput.
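So the following three calls are equivalent ways to reach GraphRuntime::Run (a small sketch reusing module from the run example above):

    module.run()              # the GraphModule convenience wrapper
    module._run()             # the PackedFunc cached in __init__
    module.module["run"]()    # direct lookup on the underlying C++ Module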

In short, graph_runtime executes the program that Relay has compiled: the frontend framework's graph is converted to Relay IR and compiled, and inference is then run through the runtime.