_expr.Bind函数用来绑定expr中的变量,就是用参数字典中提供的参数值(expr)替换指定expr中的变量(var)
python中的定义
定义文件:python/tvm/relay/expr.py
def bind(expr, binds):
"""Bind an free variables in expr or function arguments.
We can bind parameters expr if it is a function.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
The specific bindings.
Returns
-------
result : tvm.relay.Expr
The expression or function after binding.
"""
return _expr.Bind(expr, binds)
_expr.Bind其实是在c++中注册的一个全局函数:relay._expr.Bind
c++注册
注册文件:src/relay/ir/expr_functor.cc
// Implement bind.
class ExprBinder : public ExprMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
}
Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var))
<< "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const FunctionNode* op) final {
for (Var param : op->params) {
CHECK(!args_map_.count(param))
<< "Cannnot bind an internal function parameter";
}
return ExprMutator::VisitExpr_(op);
}
//使用变量对应的实际值(const)替换变量(var)
Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);//id是Var
auto it = args_map_.find(id);
if (it != args_map_.end()) {
return (*it).second;//返回变量的实际值(const)
} else {
return std::move(id);
}
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Var VisitVar(const Var& v) final {
CHECK(!args_map_.count(v))
<< "Cannnot bind an internal pattern variable";
return v;
}
private:
const tvm::Map<Var, Expr>& args_map_;
};
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
for (Var param : func->params) {
if (!args_map.count(param)) {
new_params.push_back(param);
}
}
if (new_body.same_as(func->body) &&
new_params.size() == func->params.size()) {
return expr;
}
auto ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
std::unordered_set<Var, NodeHash, NodeEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
}
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
}
}
ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return std::move(ret);
} else {
return ExprBinder(args_map).VisitExpr(expr);
}
}
TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0];
if (input->IsInstance<ExprNode>()) {
*ret = Bind(Downcast<Expr>(input), args[1]);
} else {
CHECK(input->IsInstance<TypeNode>());
*ret = Bind(Downcast<Type>(input), args[1]);
}
});
详细解释:
Bind的第一个参数是const Expr& expr类型,是要绑定的表达式。
第二个参数是const tvm::Map& args_map,表示Var->Expr的映射, 对于参数来说,Expr一般都对应ConstantNode类型。
Bind函数会在expr所表示的图中,找到所有的VarNode,检查该VarNode的Var是否在args_map中,如果在,就用args_map中对应的Expr替换Var。