ExprVisitor 之 visit_counter_ 源码学习


std::unorderedmap ref_counter; Expr Reorder(const Expr& expr) { refcounter = GetExprRefCount(expr); return this->Mutate(expr); }

/! \brief Get reference counter of each internal ExprNode in body. \param body The body expression. \return The reference count mapping. /std::unordered_map<const Node, sizet>GetExprRefCount(const Expr& body) { class ExprRefCounter : private ExprVisitor { public: std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter); } }; return ExprRefCounter().Get(body);}

std::unorderedmap visit_counter;
void ExprVisitor::VisitExpr(const Expr& expr) { auto it = visitcounter.find(expr.get()); if (it != visitcounter.end()) { ++it->second; } else { using TParent = ExprFunctor; TParent::VisitExpr(expr); visitcounter.insert({expr.get(), 1}); }}
// Shorthand for this functor type, and for its dispatch table: a NodeFunctor
// whose entries receive the node, the functor itself, and the extra arguments.
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
/*!
 * \brief Dispatch \p n to the handler registered for its runtime node type.
 * \param n The expression to visit; must be defined.
 * \param args Extra arguments forwarded to the handler.
 * \return Whatever the selected VisitExpr_ overload returns.
 */
virtual R VisitExpr(const Expr& n, Args... args) {
  // Guard against dispatching an undefined (null) expression.
  CHECK(n.defined());
  // Function-local static: the table is built once per ExprFunctor specialization.
  static FType vtable = InitVTable();
  // The registered entry ends up calling
  // self->VisitExpr_(static_cast<const XxxNode*>(n.get()), args...),
  // e.g. the ExprVisitor::VisitExpr_ overload for the concrete node type.
  return vtable(n, this, std::forward<Args>(args)...);
}
/*!
 * \brief Build the dispatch table: one entry per Relay expression node type,
 *        each forwarding to the matching VisitExpr_ overload.
 * \return The populated table.
 */
static FType InitVTable() {
  FType vtable;
  // Set dispatch
  RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
  RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
  return vtable;
}
// RELAY_EXPR_FUNCTOR_DISPATCH(OP): register a vtable entry for node type OP.
// The captureless lambda converts to the table's function-pointer type
// (set_dispatch takes an FPointer); it downcasts the generic node reference
// to const OP* and calls the matching VisitExpr_ overload on the functor.
// Expects `vtable`, `TSelf` and `Args` in scope at the expansion site, as in InitVTable.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {     \
    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });
//set_dispatch的定义template TSelf& set_dispatch(FPointer f) { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); if (func
.size() <= tindex) { func.resize(tindex + 1, nullptr); } CHECK(func[tindex] == nullptr) << “Dispatch for “ << TNode::type_key << “ is already set”; func[tindex] = f; return *this; }