_expr.Bind函数用来绑定expr中的变量,就是用参数字典中提供的参数值(expr)替换指定expr中的变量(var)
python中的定义
定义文件:python/tvm/relay/expr.py
def bind(expr, binds):"""Bind an free variables in expr or function arguments.We can bind parameters expr if it is a function.Parameters----------expr : tvm.relay.ExprThe input expression.binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]The specific bindings.Returns-------result : tvm.relay.ExprThe expression or function after binding."""return _expr.Bind(expr, binds)
_expr.Bind其实是在c++中注册的一个全局函数:relay._expr.Bind
c++注册
注册文件:src/relay/ir/expr_functor.cc
// Implement bind.class ExprBinder : public ExprMutator, PatternMutator {public:explicit ExprBinder(const tvm::Map<Var, Expr>& args_map): args_map_(args_map) {}Expr VisitExpr_(const LetNode* op) final {CHECK(!args_map_.count(op->var))<< "Cannot bind an internel variable in let";return ExprMutator::VisitExpr_(op);}Expr VisitExpr_(const FunctionNode* op) final {for (Var param : op->params) {CHECK(!args_map_.count(param))<< "Cannnot bind an internal function parameter";}return ExprMutator::VisitExpr_(op);}//使用变量对应的实际值(const)替换变量(var)Expr VisitExpr_(const VarNode* op) final {auto id = GetRef<Var>(op);//id是Varauto it = args_map_.find(id);if (it != args_map_.end()) {return (*it).second;//返回变量的实际值(const)} else {return std::move(id);}}Pattern VisitPattern(const Pattern& p) final {return PatternMutator::VisitPattern(p);}Clause VisitClause(const Clause& c) final {Pattern pat = VisitPattern(c->lhs);return ClauseNode::make(pat, VisitExpr(c->rhs));}Var VisitVar(const Var& v) final {CHECK(!args_map_.count(v))<< "Cannnot bind an internal pattern variable";return v;}private:const tvm::Map<Var, Expr>& args_map_;};Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {if (const FunctionNode* func = expr.as<FunctionNode>()) {Expr new_body = ExprBinder(args_map).VisitExpr(func->body);Array<Var> new_params;for (Var param : func->params) {if (!args_map.count(param)) {new_params.push_back(param);}}if (new_body.same_as(func->body) &&new_params.size() == func->params.size()) {return expr;}auto ret = FunctionNode::make(new_params,new_body,func->ret_type,func->type_params,func->attrs);std::unordered_set<Var, NodeHash, NodeEqual> set;for (const auto& v : FreeVars(expr)) {set.insert(v);}for (const auto& v : FreeVars(ret)) {if (set.count(v) == 0) {new_params.push_back(v);}}ret = FunctionNode::make(new_params,new_body,func->ret_type,func->type_params,func->attrs);CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());return std::move(ret);} else {return ExprBinder(args_map).VisitExpr(expr);}}TVM_REGISTER_API("relay._expr.Bind").set_body([](TVMArgs args, TVMRetValue* ret) {NodeRef input = args[0];if (input->IsInstance<ExprNode>()) {*ret = Bind(Downcast<Expr>(input), args[1]);} else {CHECK(input->IsInstance<TypeNode>());*ret = Bind(Downcast<Type>(input), args[1]);}});
详细解释:
Bind的第一个参数是const Expr& expr类型,是要绑定的表达式。
第二个参数是const tvm::Map& args_map,表示Var->Expr的映射, 对于参数来说,Expr一般都对应ConstantNode类型。
Bind函数会在expr所表示的图中,找到所有的VarNode,检查该VarNode的Var是否在args_map中,如果在,就用args_map中对应的Expr替换Var。
