_expr.Bind函数用来绑定expr中的变量,就是用参数字典中提供的参数值(expr)替换指定expr中的变量(var)

python中的定义

定义文件:python/tvm/relay/expr.py

  1. def bind(expr, binds):
  2. """Bind an free variables in expr or function arguments.
  3. We can bind parameters expr if it is a function.
  4. Parameters
  5. ----------
  6. expr : tvm.relay.Expr
  7. The input expression.
  8. binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
  9. The specific bindings.
  10. Returns
  11. -------
  12. result : tvm.relay.Expr
  13. The expression or function after binding.
  14. """
  15. return _expr.Bind(expr, binds)

_expr.Bind其实是在c++中注册的一个全局函数:relay._expr.Bind

c++注册

注册文件:src/relay/ir/expr_functor.cc

  1. // Implement bind.
  2. class ExprBinder : public ExprMutator, PatternMutator {
  3. public:
  4. explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
  5. : args_map_(args_map) {
  6. }
  7. Expr VisitExpr_(const LetNode* op) final {
  8. CHECK(!args_map_.count(op->var))
  9. << "Cannot bind an internel variable in let";
  10. return ExprMutator::VisitExpr_(op);
  11. }
  12. Expr VisitExpr_(const FunctionNode* op) final {
  13. for (Var param : op->params) {
  14. CHECK(!args_map_.count(param))
  15. << "Cannnot bind an internal function parameter";
  16. }
  17. return ExprMutator::VisitExpr_(op);
  18. }
  19. //使用变量对应的实际值(const)替换变量(var)
  20. Expr VisitExpr_(const VarNode* op) final {
  21. auto id = GetRef<Var>(op);//id是Var
  22. auto it = args_map_.find(id);
  23. if (it != args_map_.end()) {
  24. return (*it).second;//返回变量的实际值(const)
  25. } else {
  26. return std::move(id);
  27. }
  28. }
  29. Pattern VisitPattern(const Pattern& p) final {
  30. return PatternMutator::VisitPattern(p);
  31. }
  32. Clause VisitClause(const Clause& c) final {
  33. Pattern pat = VisitPattern(c->lhs);
  34. return ClauseNode::make(pat, VisitExpr(c->rhs));
  35. }
  36. Var VisitVar(const Var& v) final {
  37. CHECK(!args_map_.count(v))
  38. << "Cannnot bind an internal pattern variable";
  39. return v;
  40. }
  41. private:
  42. const tvm::Map<Var, Expr>& args_map_;
  43. };
  44. Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
  45. if (const FunctionNode* func = expr.as<FunctionNode>()) {
  46. Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
  47. Array<Var> new_params;
  48. for (Var param : func->params) {
  49. if (!args_map.count(param)) {
  50. new_params.push_back(param);
  51. }
  52. }
  53. if (new_body.same_as(func->body) &&
  54. new_params.size() == func->params.size()) {
  55. return expr;
  56. }
  57. auto ret = FunctionNode::make(new_params,
  58. new_body,
  59. func->ret_type,
  60. func->type_params,
  61. func->attrs);
  62. std::unordered_set<Var, NodeHash, NodeEqual> set;
  63. for (const auto& v : FreeVars(expr)) {
  64. set.insert(v);
  65. }
  66. for (const auto& v : FreeVars(ret)) {
  67. if (set.count(v) == 0) {
  68. new_params.push_back(v);
  69. }
  70. }
  71. ret = FunctionNode::make(new_params,
  72. new_body,
  73. func->ret_type,
  74. func->type_params,
  75. func->attrs);
  76. CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
  77. return std::move(ret);
  78. } else {
  79. return ExprBinder(args_map).VisitExpr(expr);
  80. }
  81. }
  82. TVM_REGISTER_API("relay._expr.Bind")
  83. .set_body([](TVMArgs args, TVMRetValue* ret) {
  84. NodeRef input = args[0];
  85. if (input->IsInstance<ExprNode>()) {
  86. *ret = Bind(Downcast<Expr>(input), args[1]);
  87. } else {
  88. CHECK(input->IsInstance<TypeNode>());
  89. *ret = Bind(Downcast<Type>(input), args[1]);
  90. }
  91. });

详细解释:
Bind的第一个参数是const Expr& expr类型,是要绑定的表达式。
第二个参数是const tvm::Map& args_map,表示Var->Expr的映射, 对于参数来说,Expr一般都对应ConstantNode类型。
Bind函数会在expr所表示的图中,找到所有的VarNode,检查该VarNode的Var是否在args_map中,如果在,就用args_map中对应的Expr替换Var。