Files related to build and run
build
Note: tvm.build and tvm.relay.build are different functions.
- python
File path: python/tvm/relay/build_module.py
Class: BuildModule
Functions: build, optimize
- c++
File path: src/relay/backend/build_module.cc
Class: RelayBuildModule
Functions: build, optimize
run
- python
File path: python/tvm/contrib/graph_runtime.py
Class: GraphModule
Function: run
- c++
File path: src/runtime/graph/graph_runtime.cc
Class: GraphRuntime
The compile and run process
Below we use relay_quick_start.py as an example to walk through Relay's compile and run process in detail.
1 Load the network model and parameters
Source:
```python
batch_size = 1
num_class = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_class)

mod, params = relay.testing.resnet.get_workload(
    num_layers=18, batch_size=batch_size, image_shape=image_shape)

# set show_meta_data=True if you want to show meta data
print(mod.astext(show_meta_data=False))
```
Notes:
mod : tvm.relay.Module
The relay module that contains a ResNet network.
params : dict of str to NDArray
The parameters.
2 Compile
Source:
```python
opt_level = 3
target = tvm.target.cuda()
with relay.build_config(opt_level=opt_level):
    graph, lib, params = relay.build_module.build(mod, target, params=params)
```
The call chain is as follows:
[python] relay.build = relay.build_module.build
[python] BuildModule.build
[c++] RelayBuildModule::Build - Build relay function for graph runtime
[c++] RelayBuildModule::BuildRelay - Compile a Relay function to runtime module.
[c++] RelayBuildModule::Optimize - Optimize input Relay Function and returns Relay Module. Runs a series of passes over the Relay IR, e.g. operator fusion (FuseOps).
[c++] GraphCodegen::Codegen - Generate code for the updated function. The main work is Relay IR -> TE expression -> schedule -> TIR; this step also invokes any external codegen.
[c++] tvm::build - Build for heterogeneous execution. Mainly lowers TIR to LLVM IR.
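The ordering of the stages above can be sketched as a toy pipeline in plain Python. The function bodies below are placeholders that just record each stage; the names mirror the C++ methods in the chain, but this is not TVM's real logic:

```python
# Toy sketch of the relay.build call chain. Each stage appends its name
# so the final string shows the order in which the stages run.

def optimize(mod):
    # RelayBuildModule::Optimize - run passes such as operator fusion
    return mod + " -> optimized"

def graph_codegen(mod):
    # GraphCodegen::Codegen - Relay IR -> TE expression -> schedule -> TIR
    return mod + " -> tir"

def tvm_build(lowered):
    # tvm::build - lower TIR to LLVM IR / machine code
    return lowered + " -> llvm"

def build_relay(mod):
    # RelayBuildModule::BuildRelay glues the stages together
    return tvm_build(graph_codegen(optimize(mod)))

result = build_relay("relay_ir")
print(result)  # relay_ir -> optimized -> tir -> llvm
```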
3 Run
Source:
```python
# create random input
ctx = tvm.gpu()
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
# create module - initialize device info and allocate memory (SRAM, DDR)
module = graph_runtime.create(graph, lib, ctx)
# set input and parameters - copy data to the device: data -> SRAM, weights -> DDR
module.set_input("data", data)
module.set_input(**params)
# run - start the computation
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape)).asnumpy()
# Print first 10 elements of output
print(out.flatten()[0:10])
```
Note: for the choice of target and ctx, refer to the target and ctx sections in "TVM build and run".
Function call chain:
graph_runtime.create(graph, lib, ctx) creates a module that the runtime can execute, i.e. a GraphModule.
Part of the source of graph_runtime.create:
```python
if num_rpc_ctx == len(ctx):
    # hmod = rpc_base._ModuleHandle(libmod)
    fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
    # return GraphModule(fcreate(graph_json_str, hmod, name, *device_type_id))
else:
    fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, name, *device_type_id))
```
fcreate is in fact a function: it corresponds to the C++ function registered under the name "tvm.graph_runtime.create", which is registered as follows:
```cpp
TVM_REGISTER_GLOBAL("tvm.graph_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
  CHECK_GE(args.num_args, 4)
      << "The expected number of arguments for graph_runtime.create is "
         "at least 4, but it has " << args.num_args;
  const auto& contexts = GetAllContext(args);
  *rv = GraphRuntimeCreate(args[0], args[1], args[2], contexts);
});
```
The return value of that registered function is the return value of GraphRuntimeCreate:
```cpp
Module GraphRuntimeCreate(const std::string& sym_json,
                          const tvm::runtime::Module& m,
                          const std::string& name,
                          const std::vector<TVMContext>& ctxs) {
  auto exec = make_object<GraphRuntime>();
  exec->Init(sym_json, m, ctxs, name);
  return Module(exec);
}
```
Note that GraphRuntimeCreate constructs a GraphRuntime and initializes it (inside GraphRuntime::Init, functions such as VaccGetDeviceCount, VaccGetDeviceIDs, VaccSetDevice, vaccrt_create_context, vaccrt_add_model and vaccrt_begin_gen_model are called). It ultimately returns a C++ Module, which corresponds to the Module on the Python side: a module that exposes many functions.
```cpp
void GraphRuntime::Init(const std::string& graph_json,
                        tvm::runtime::Module module,
                        const std::vector<TVMContext>& ctxs,
                        const std::string& name) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
  std::istringstream is(graph_json);
#else
  std::string is = graph_json;
#endif
  m_grt = this;
  dmlc::JSONReader reader(&is);
  this->Load(&reader);
  module_ = module;
  ctxs_ = ctxs;
  const auto& cit =
      std::find_if(ctxs_.begin(), ctxs_.end(), [](const TVMContext& c) {
        return kDLVacc == static_cast<int>(c.device_type);
      });
  if (cit != ctxs_.end()) {  // device is Vacc
    // Get total number of devices
    uint32_t device_count = 0;
    uint32_t *p_count = &device_count;
    VACC_MEM_CALL(VaccGetDeviceCount(p_count));
    device_count = *p_count;
    VACC_LOG_DEBUG("Get the total number of devices, count=%u", device_count);
    // Get all devices id
    uint64_t *id_array = new uint64_t[device_count];
    VACC_MEM_CALL(VaccGetDeviceIDs(id_array, device_count));
    VACC_LOG_INFO("Get all devices id.");
    for (uint32_t i = 0; i < device_count; i++) {
      VACC_LOG_INFO("Device id list: %lu", *(id_array + i));
    }
    delete[] id_array;
    // Set device
    VACC_MEM_CALL(VaccSetDevice(static_cast<uint64_t>((*cit).device_id)));
    VACC_LOG_INFO("Choose a device by index.");
    dev_ctx_t ctx;
    INT8 *modelName = (INT8*)name.c_str();
    module_name_ = modelName;
    dev_id = static_cast<UINT8>((*cit).device_id);
    ctx.next = nullptr;
    // call device_runtime so to create device context
    VACC_LOG_INFO("Begin to create device context...");
    VACC_CALL(vaccrt_create_context(&ctx, dev_id))
    // add model
    VACC_LOG_INFO("Begin to add model...");
    vaccrt_add_model(&ctx, dev_id, modelName);
    // set model
    VACC_LOG_INFO("Begin to set model...");
    VACC_CALL(vaccrt_set_model(&ctx, dev_id, modelName));
    // init the selected model
    VACC_LOG_INFO("Begin to init the selected model...")
    VACC_CALL(vaccrt_begin_gen_model());
  }
  this->SetupStorage();
  this->SetupOpExecs();
  for (size_t i = 0; i < input_nodes_.size(); i++) {
    const uint32_t nid = input_nodes_[i];
    std::string& name = nodes_[nid].name;
    input_map_[name] = i;
  }
  this->GenerateCsr();
}
```
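Setting the device-specific calls aside, the host-side bookkeeping in Init is: parse the graph JSON, then map each input node's name to its positional index (the input_map_ loop at the end). That part can be sketched in plain Python with a toy graph JSON (the field names mirror TVM's graph format, but this JSON is made up for illustration):

```python
# Minimal sketch of the input_map_ construction in GraphRuntime::Init.
import json

graph_json = json.dumps({
    "nodes": [{"name": "data"}, {"name": "weight"}, {"name": "conv"}],
    "arg_nodes": [0, 1],   # node indices that are graph inputs
})

graph = json.loads(graph_json)          # this->Load(&reader)
input_map = {}
for i, nid in enumerate(graph["arg_nodes"]):
    name = graph["nodes"][nid]["name"]
    input_map[name] = i                 # input_map_[name] = i

print(input_map)  # {'data': 0, 'weight': 1}
```

This is the table that later lets set_input("data", ...) find the right input slot by name.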
Now look back at return GraphModule(fcreate(graph_json_str, libmod, name, *device_type_id)): this amounts to returning a GraphModule object.
Next, look at GraphModule's initializer:
```python
def __init__(self, module):
    self.module = module
    self._set_input = module["set_input"]
    self._run = module["run"]
    self._get_output = module["get_output"]
    self._get_input = module["get_input"]
    self._get_num_outputs = module["get_num_outputs"]
    self._load_params = module["load_params"]
    self._share_params = module["share_params"]
    self._get_memory_address = module["get_memory_address"]
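The binding pattern here is that the C++ Module behaves like a dict of callables (PackedFuncs), and the Python wrapper caches the entry points it needs at construction time. A toy version of that pattern, with FakeModule standing in for the real C++ Module:

```python
# Sketch of the GraphModule binding pattern (not TVM's real classes).

class FakeModule:
    """Stand-in for the C++ Module: indexing by name yields a callable."""
    def __init__(self):
        self._state = {}
        self._funcs = {
            "set_input": self._state.__setitem__,
            "run": lambda: self._state.setdefault("ran", True),
            "get_output": self._state.get,
        }

    def __getitem__(self, name):
        # module["set_input"] returns a callable, like a tvm PackedFunc
        return self._funcs[name]

class MiniGraphModule:
    def __init__(self, module):
        self.module = module
        # cache the entry points once, exactly as GraphModule.__init__ does
        self._set_input = module["set_input"]
        self._run = module["run"]
        self._get_output = module["get_output"]

    def set_input(self, key, value):
        self._set_input(key, value)

    def run(self):
        self._run()

m = MiniGraphModule(FakeModule())
m.set_input("data", [1, 2, 3])
m.run()
print(m.module._state["ran"])  # True
```

Caching the functions in __init__ avoids a name lookup on every call, which matters when run() sits in a hot loop.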
This shows that the Module created on the C++ side is bound to this GraphModule; the Python-side module is a thin wrapper around it. Then:
module.set_input actually calls GraphRuntime::SetInput on the C++ side
module.run actually calls GraphRuntime::Run on the C++ side
module.get_output actually calls GraphRuntime::GetOutput on the C++ side
…
In short, graph_runtime executes a program compiled by Relay: the frontend framework's graph is converted to Relay IR and compiled, and inference is then run through the runtime.
