Files related to build and run
Build-related
Note: tvm.build and tvm.relay.build are different (see the sketch after this file list).
- python
File path: python/tvm/relay/build_module.py
Class: BuildModule
Functions: build, optimize
- c++
File path: src/relay/backend/build_module.cc
Class: RelayBuildModule
Functions: build, optimize
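To make that note concrete, here is a minimal sketch of the difference, assuming the pre-0.7 Python API used throughout this note (newer releases move placeholder/compute/create_schedule under tvm.te):

import tvm
from tvm import relay

# tvm.build compiles ONE low-level compute + schedule (TE/TIR) into a kernel.
n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
kernel = tvm.build(s, [A, B], target="llvm")  # a single compiled function

# tvm.relay.build compiles a WHOLE Relay module (a graph of many operators):
# graph, lib, params = relay.build(mod, target="llvm", params=params)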
Run-related
- python
File path: python/tvm/contrib/graph_runtime.py
Class: GraphModule
Function: run
- c++
File path: src/runtime/graph/graph_runtime.cc
Class: GraphRuntime
Compilation and execution flow
The following walks through relay_quick_start.py to explain in detail how Relay compiles and runs a model.
1 Load the network model and parameters
Source:
batch_size = 1
num_class = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_class)
mod, params = relay.testing.resnet.get_workload(
    num_layers=18, batch_size=batch_size, image_shape=image_shape)
# set show_meta_data=True if you want to show meta data
print(mod.astext(show_meta_data=False))
Notes:
mod : tvm.relay.Module
The relay module that contains a ResNet network.
params : dict of str to NDArray
The parameters.
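To make the two return values concrete, a small inspection snippet (hedged: exact attribute names can vary across TVM versions):

print(type(mod))                       # tvm.relay.Module wrapping the ResNet-18 graph
print(mod["main"].params)              # input variables of the module's main function
for k, v in list(params.items())[:3]:  # params maps parameter name -> tvm.nd.NDArray
    print(k, v.shape, v.dtype)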
2 Compile
Source:
opt_level = 3
target = tvm.target.cuda()
with relay.build_config(opt_level=opt_level):
    graph, lib, params = relay.build_module.build(mod, target, params=params)
The call chain is as follows:
[python] relay.build = relay.build_module.build
[python] BuildModule.build
[c++] RelayBuildModule::Build - Build relay function for graph runtime
[c++] RelayBuildModule::BuildRelay - Compile a Relay function to runtime module.
[c++] RelayBuildModule::Optimize - Optimize the input Relay function and return a Relay module; runs optimization passes over the Relay IR, e.g. operator fusion (FuseOps).
[c++] GraphCodegen::Codegen - Generate code for the updated function. The main work is lowering Relay IR -> TE expression -> schedule -> TIR; this step also invokes any external codegen.
[c++] tvm::build - Build for heterogeneous execution; mainly lowers TIR to LLVM IR (or other target code).
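The Python half of this chain is thin glue: BuildModule holds the C++ RelayBuildModule and looks up its PackedFuncs by name. An abridged sketch of python/tvm/relay/build_module.py (member names may differ slightly between versions):

class BuildModule(object):
    def __init__(self):
        # _BuildModule() constructs the C++ RelayBuildModule; each self.mod["..."]
        # lookup below returns a PackedFunc implemented on the C++ side.
        self.mod = _build_module._BuildModule()
        self._build = self.mod["build"]
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._get_params = self.mod["get_params"]

    def build(self, mod, target=None, target_host=None, params=None):
        self._build(mod, target, target_host)  # runs RelayBuildModule::Build in C++
        return self._get_graph_json(), self._get_module(), self._get_params()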
3 Run
Source:
# create random input
ctx = tvm.gpu()
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
# create module - initialize device information and allocate device memory (SRAM, DDR)
module = graph_runtime.create(graph, lib, ctx)
# set input and parameters - copy data to the proper device locations: data -> SRAM, weights -> DDR
module.set_input("data", data)
module.set_input(**params)
# run - start the computation
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape)).asnumpy()
# Print first 10 elements of output
print(out.flatten()[0:10])
Note: for how to choose target and ctx, refer to the target and ctx selection section of "TVM build and run".
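The three artifacts returned by relay.build can also be saved to disk and reloaded later, as in the save-and-load part of relay_quick_start.py. A hedged sketch (tvm.module.load is the pre-0.7 name; newer versions use tvm.runtime.load_module):

from tvm.contrib import util

temp = util.tempdir()
lib.export_library(temp.relpath("deploy_lib.tar"))   # compiled kernels
with open(temp.relpath("deploy_graph.json"), "w") as f:
    f.write(graph)                                   # execution graph (JSON)
with open(temp.relpath("deploy_param.params"), "wb") as f:
    f.write(relay.save_param_dict(params))           # weights

loaded_lib = tvm.module.load(temp.relpath("deploy_lib.tar"))
loaded_json = open(temp.relpath("deploy_graph.json")).read()
module = graph_runtime.create(loaded_json, loaded_lib, ctx)
module.load_params(bytearray(open(temp.relpath("deploy_param.params"), "rb").read()))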
Function call chain:
graph_runtime.create(graph, lib, ctx) creates a module that the runtime can execute, namely a GraphModule.
Part of the source of graph_runtime.create:
if num_rpc_ctx == len(ctx):
    # hmod = rpc_base._ModuleHandle(libmod)
    fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
    # return GraphModule(fcreate(graph_json_str, hmod, name, *device_type_id))
else:
    fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, name, *device_type_id))
fcreate is actually a function: it corresponds to the C++ function registered under "tvm.graph_runtime.create", whose registration looks like this:
TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      CHECK_GE(args.num_args, 4)
          << "The expected number of arguments for graph_runtime.create is "
             "at least 4, but it has "
          << args.num_args;
      const auto& contexts = GetAllContext(args);
      *rv = GraphRuntimeCreate(args[0], args[1], args[2], contexts);
    });
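The same global-function registry is reachable from Python, which is exactly how fcreate is obtained above. A minimal round trip (the name "demo.add_one" is made up for illustration):

import tvm

@tvm.register_func("demo.add_one")       # Python analogue of TVM_REGISTER_GLOBAL
def add_one(x):
    return x + 1

f = tvm.get_global_func("demo.add_one")  # same lookup mechanism that finds fcreate
assert f(41) == 42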
The return value of this registered function is the result of calling GraphRuntimeCreate:
Module GraphRuntimeCreate(const std::string& sym_json,
                          const tvm::runtime::Module& m,
                          const std::string& name,
                          const std::vector<TVMContext>& ctxs) {
  auto exec = make_object<GraphRuntime>();
  exec->Init(sym_json, m, ctxs, name);
  return Module(exec);
}
Note that GraphRuntimeCreate constructs a GraphRuntime and initializes it (inside GraphRuntime::Init it calls VaccGetDeviceCount, VaccGetDeviceIDs, VaccSetDevice, vaccrt_create_context, vaccrt_add_model, vaccrt_begin_gen_model, and so on), and finally returns a C++ Module. This is what the Python-side Module corresponds to: a module that exposes many functions.
void GraphRuntime::Init(const std::string& graph_json,
                        tvm::runtime::Module module,
                        const std::vector<TVMContext>& ctxs,
                        const std::string& name) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
  std::istringstream is(graph_json);
#else
  std::string is = graph_json;
#endif
  m_grt = this;
  dmlc::JSONReader reader(&is);
  this->Load(&reader);
  module_ = module;
  ctxs_ = ctxs;
  const auto& cit =
      std::find_if(ctxs_.begin(), ctxs_.end(), [](const TVMContext& c) {
        return kDLVacc == static_cast<int>(c.device_type);
      });
  if (cit != ctxs_.end()) {  // device is Vacc
    // Get the total number of devices
    uint32_t device_count = 0;
    uint32_t* p_count = &device_count;
    VACC_MEM_CALL(VaccGetDeviceCount(p_count));
    device_count = *p_count;
    VACC_LOG_DEBUG("Get the total number of devices, count=%u", device_count);
    // Get all device ids
    uint64_t* id_array = new uint64_t[device_count];
    VACC_MEM_CALL(VaccGetDeviceIDs(id_array, device_count));
    VACC_LOG_INFO("Get all devices id.");
    for (uint32_t i = 0; i < device_count; i++) {
      VACC_LOG_INFO("Device id list: %lu", *(id_array + i));
    }
    delete[] id_array;
    // Set device
    VACC_MEM_CALL(VaccSetDevice(static_cast<uint64_t>((*cit).device_id)));
    VACC_LOG_INFO("Choose a device by index.");
    dev_ctx_t ctx;
    INT8* modelName = (INT8*)name.c_str();
    module_name_ = modelName;
    dev_id = static_cast<UINT8>((*cit).device_id);
    ctx.next = nullptr;
    // Call the device_runtime .so to create a device context
    VACC_LOG_INFO("Begin to create device context...");
    VACC_CALL(vaccrt_create_context(&ctx, dev_id));
    // Add model
    VACC_LOG_INFO("Begin to add model...");
    vaccrt_add_model(&ctx, dev_id, modelName);
    // Set model
    VACC_LOG_INFO("Begin to set model...");
    VACC_CALL(vaccrt_set_model(&ctx, dev_id, modelName));
    // Init the selected model
    VACC_LOG_INFO("Begin to init the selected model...");
    VACC_CALL(vaccrt_begin_gen_model());
  }
  this->SetupStorage();
  this->SetupOpExecs();
  for (size_t i = 0; i < input_nodes_.size(); i++) {
    const uint32_t nid = input_nodes_[i];
    std::string& name = nodes_[nid].name;
    input_map_[name] = i;
  }
  this->GenerateCsr();
}
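The this->Load(&reader) call near the top parses the graph JSON produced by relay.build. Its top-level layout can be inspected directly from Python (key names follow the mainline graph runtime; a vendor fork may add fields):

import json

g = json.loads(graph)  # `graph` is the JSON string returned by relay.build
print(list(g.keys()))  # typically: nodes, arg_nodes, heads, attrs, node_row_ptr
print(g["nodes"][0])   # e.g. {'op': 'null', 'name': 'data', 'inputs': []}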
Now look back at return GraphModule(fcreate(graph_json_str, libmod, name, *device_type_id)): this returns a GraphModule object.
Next, look at GraphModule's initializer:
def __init__(self, module):
    self.module = module
    self._set_input = module["set_input"]
    self._run = module["run"]
    self._get_output = module["get_output"]
    self._get_input = module["get_input"]
    self._get_num_outputs = module["get_num_outputs"]
    self._load_params = module["load_params"]
    self._share_params = module["share_params"]
    self._get_memory_address = module["get_memory_address"]
At this point the Module created on the C++ side is bound to this GraphModule, and the Python-side module simply forwards to it:
module.set_input actually calls GraphRuntime::SetInput on the C++ side
module.run actually calls GraphRuntime::Run on the C++ side
module.get_output actually calls GraphRuntime::GetOutput on the C++ side
…
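Since each module["..."] lookup returns a PackedFunc backed by the C++ GraphRuntime, the wrapper methods and the raw lookups are interchangeable:

module.run()            # via the GraphModule wrapper
module.module["run"]()  # the same C++ GraphRuntime::Run, fetched directly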
graph_runtime executes the program that Relay compiled: the graph from a frontend framework is converted to Relay IR and compiled, and inference finally runs through the runtime.