TVM常用的类和函数


TVM常用的类和函数
GetRef通过GetRef函数可以从Object类指针获取其引用对象ObjectRef
template inline RefType GetRef(const ObjType ptr) { static_assert(std::is_base_of::value, “Can only cast to the ref of same container type”); return RefType(ObjectPtr(const_cast<Object>(staticcast(ptr))));}
例如:Expr ExprMutator::VisitExpr
(const OpNode op) { return GetRef(op);}
Node类转为NodeRef:Node
node;GetRef(node);
该op是OpNode类指针, GetRef(op)可以获取op的引用对象Expr
ObjectRef::get该函数用于获取ObjectRef类对象内部引用的Object类指针(const 指针), 相当于由ObjectRef类对象获取Object类指针 /! \return the internal object pointer / const Object get() const { return data_.get(); }例如:
ObjectRef::as该函数用于尝试将ObjectRef类对象的内部引用对象转换成对应Object类型的指针(const 指针),如果返回nullptr表示失败template inline const ObjectType
ObjectRef::as() const { if (data != nullptr && data->IsInstance()) { return staticcast(data.get()); } else { return nullptr; }}例如:const CallNode prev_node = res_call->args[0].as();//将Expr类对象转换成CallNode类型指针
ObjectRef::same_as用于比较2个ObjectRef类对象是否相同, 例如 static const Op& relu_op = Op::Get(“nn.relu”); if (!call->op.same_as(relu_op)) { return false; // not relu op }
Downcast 该函数将父类对象向下转换成子类对象NodeRef tmp = z;Expr zz = Downcast(tmp);该tmp是NodeRef类对象,zz是Expr类对象,Expr是NodeRef类的子类
GetExprRefCount用于获取graph中的每个内部表达式Node的引用次数,源码如下:代码路径:src/relay/pass/util.cc/
! \brief Get reference counter of each internal ExprNode in body. \param body The body expression. \return The reference count mapping. /std::unorderedmapGetExprRefCount(const Expr& body) { class ExprRefCounter : private ExprVisitor { public: std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter); } }; return ExprRefCounter().Get(body);}