ExprVisitor之visit_counter_源码学习

创建时间: 2020-06-15 16:00
更新时间: 2020-06-15 21:07
作者: 先之
标签: TVM, 熊选文

// 1 应用
std::unorderedmap ref_counter;

Expr Reorder(const Expr& expr) {
refcounter = GetExprRefCount(expr);
return this->Mutate(expr);
}

// 2 全局函数
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unorderedmap
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map
Get(const Expr& _body
) {
this->VisitExpr(body);
return std::move(this->visitcounter);
}
};
return ExprRefCounter().Get(body);
}

// 3 class ExprVisitor:

std::unorderedmap visit_counter;

void ExprVisitor::VisitExpr(const Expr& expr) {
auto it = visitcounter.find(expr.get());
if (it != visitcounter.end()) {
++it->second;
} else {
using TParent = ExprFunctor;
TParent::VisitExpr(expr);
visitcounter.insert({expr.get(), 1});
}
}

// 4 class ExprFunctor:
// Members of the ExprFunctor class template; R (return type) and Args (extra
// arguments) are its template parameters.
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
// Dispatch table type: maps a node's runtime type to a handler taking
// (node, functor, extra args...).
using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;

/*!
 * \brief Dispatch n to the handler registered for its node type.
 * \param n The expression to visit.
 * \param args Extra arguments forwarded to the handler.
 * \return Whatever the selected handler returns.
 */
virtual R VisitExpr(const Expr& n, Args... args) {
  // Built once, on the first call, by InitVTable().
  static FType vtable = InitVTable();
  // Author's note: this ends up calling the VisitExpr_ overload for n's node
  // type (for example ExprVisitor::VisitExpr_). For a CallNode that overload
  // recurses, so ExprVisitor::VisitExpr gets invoked again and again.
  return vtable(n, this, std::forward<Args>(args)...);
}

// Builds the dispatch table used by VisitExpr(): registers one handler per
// node type listed below and returns the populated table by value.
static FType InitVTable() {
// The local must be named `vtable`: RELAY_EXPR_FUNCTOR_DISPATCH refers to it
// by that exact name.
FType vtable;
// Set dispatch: one registration per supported Relay node type.
RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
return vtable;
}

// RELAY_EXPR_FUNCTOR_DISPATCH的定义
// Registers, on a local named `vtable`, a handler for node type OP. The
// handler downcasts the generic node pointer to `const OP*` and forwards it,
// with any extra arguments, to the functor's VisitExpr_ overload.
// Expects TSelf and Args to be in scope at the expansion site.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                \
  vtable.template set_dispatch<OP>(                                    \
      [](const ObjectRef& n, TSelf* self, Args... args) {              \
        return self->VisitExpr_(static_cast<const OP*>(n.get()),       \
                                std::forward<Args>(args)...);          \
      });

// 5 Class NodeFunctor
// set_dispatch的定义
template
TSelf& set_dispatch(FPointer _f
) { // NOLINT(*)
uint32t tindex = TNode::RuntimeTypeIndex();
if (func
.size() <= tindex) {
func.resize(tindex + 1, nullptr);
}
CHECK(func
[tindex] == nullptr)
<< “Dispatch for “ << TNode::type_key
<< “ is already set”;
func
[tindex] = f;
return *this;
}