Key classes
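The three most fundamental classes in TVM are Object, ObjectPtr, and ObjectRef.
ObjectPtr contains a field of type Object* (a raw pointer to the Object).
ObjectRef contains a field of type ObjectPtr.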
For details, see the earlier post "TVM Source Code Study 1: Object, ObjectPtr, ObjectRef".
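The following heavily simplified, self-contained model makes the ownership chain concrete. It is a sketch under stated assumptions, not TVM's code. The real Object also stores a type index and a deleter and counts atomically, and the real ObjectPtr/ObjectRef have many more members. What the sketch keeps is the layering: the reference count lives inside Object, ObjectPtr wraps a raw Object pointer, and ObjectRef wraps an ObjectPtr.

```cpp
#include <cassert>
#include <utility>

// Simplified stand-in for Object: the node itself carries its reference count.
// (Real TVM also stores a type index and a deleter, and counts atomically.)
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

// Stand-in for ObjectPtr<T>: its only field is a raw pointer to the Object.
template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

// Stand-in for ObjectRef: its only field is an ObjectPtr<Object>.
class ObjectRef {
 public:
  ObjectRef() = default;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  const Object* get() const { return data_.get(); }

 protected:
  ObjectPtr<Object> data_;
};
```

Copying an ObjectRef copies the inner ObjectPtr, which increments the count on the shared Object. No node data is copied.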
Key functions
make_object, make_node
make_object allocates an object with the default allocator and returns a smart pointer (ObjectPtr) to it:
/*!
 * \brief Allocate an object using default allocator.
 * \param args arguments to the constructor.
 * \tparam T the node type.
 * \return The ObjectPtr to the allocated object.
 */
template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
make_node is built on top of make_object. It creates a node smart pointer (NodePtr) and is simply an alias of make_object:
/*!
 * \brief Allocate a node object.
 * \param args arguments to the constructor.
 * \tparam T the node type.
 * \return The NodePtr to the allocated object.
 * \note This function is an alias of make_object.
 */
template <typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
  return runtime::make_object<T>(std::forward<Args>(args)...);
}
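This is a minimal sketch of what make_object and make_node reduce to. It uses a simplified intrusive-count model, and the Object/ObjectPtr definitions and the IntImmNode type are illustrative stand-ins, not TVM's headers. Real TVM goes through its default allocator, which also records the node's type index and deleter. The sketch only uses plain new and perfect forwarding of the constructor arguments.

```cpp
#include <cassert>
#include <utility>

// Hypothetical minimal Object / ObjectPtr with an intrusive reference count.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }
  T* operator->() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

// Sketch of make_object: construct T on the heap from the forwarded
// arguments and hand ownership to an ObjectPtr<T> (count starts at 1).
template <typename T, typename... Args>
ObjectPtr<T> make_object(Args&&... args) {
  return ObjectPtr<T>(new T(std::forward<Args>(args)...));
}

// make_node is only an alias that forwards to make_object.
template <typename T, typename... Args>
ObjectPtr<T> make_node(Args&&... args) {
  return make_object<T>(std::forward<Args>(args)...);
}

// Hypothetical node type used for illustration.
struct IntImmNode : Object {
  long value;
  explicit IntImmNode(long v) : value(v) {}
};
```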
GetRef
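GetRef obtains a reference (an ObjectRef subclass) from a raw Object pointer. Because the reference count is stored inside the Object itself, the returned reference takes a new reference on the object and keeps it alive: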
/*!
 * \brief Get a reference type from a raw object ptr type
 *
 *  It is always important to get a reference type
 *  if we want to return a value as reference or keep
 *  the object alive beyond the scope of the function.
 *
 * \param ptr The object pointer
 * \tparam RefType The reference type
 * \tparam ObjectType The object type
 * \return The corresponding RefType
 */
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr);

template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
  static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
                "Can only cast to the ref of same container type");
  return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}

// Example
Expr ExprMutator::VisitExpr_(const OpNode* op) {
  return GetRef<Expr>(op);
}
// op is an OpNode pointer; GetRef<Expr>(op) returns the Expr reference to it
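GetRef can turn a bare pointer back into an owning reference because the count is intrusive. The simplified model below reuses the GetRef body from above unchanged and shows the count going up by one. The Object, ObjectPtr, and ObjectRef classes are stand-ins, and the OpNode/Op pair is hypothetical.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical minimal model (not TVM's headers): intrusive reference count.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

class ObjectRef {
 public:
  ObjectRef() = default;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  const Object* get() const { return data_.get(); }

 protected:
  ObjectPtr<Object> data_;
};

// Hypothetical node type and its reference type.
struct OpNode : Object {};
class Op : public ObjectRef {
 public:
  using ContainerType = OpNode;
  explicit Op(ObjectPtr<Object> data) : ObjectRef(std::move(data)) {}
};

// Same body as the GetRef shown above: building an ObjectPtr from the raw
// pointer increments the count stored inside the object.
template <typename RefType, typename ObjType>
RefType GetRef(const ObjType* ptr) {
  static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
                "Can only cast to the ref of same container type");
  return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
```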
Downcast
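Downcast converts a base reference into a more specific subclass reference. Both the source and the target type must be subclasses of ObjectRef. The conversion is checked at runtime: if the underlying object is not an instance of the target's ContainerType, the CHECK fails.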
/*!
 * \brief Downcast a base reference type to a more specific type.
 *
 * \param ref The input reference
 * \return The corresponding SubRef.
 * \tparam SubRef The target specific reference type.
 * \tparam BaseRef the current reference type.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
      << "Downcast from " << ref->GetTypeKey() << " to "
      << SubRef::ContainerType::_type_key << " failed.";
  return SubRef(std::move(ref.data_));
}

// Example
NodeRef tmp = z;
Expr zz = Downcast<Expr>(tmp);
// tmp is a NodeRef and zz is an Expr; Expr is a subclass of NodeRef
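This self-contained sketch shows the checked downcast. It makes three assumptions: the classes are simplified stand-ins, dynamic_cast replaces TVM's type-index-based IsInstance check, and a failed check throws std::runtime_error instead of firing CHECK. As in TVM, Downcast is a friend of ObjectRef, which lets it move data_ out of its by-value argument.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical minimal model with an intrusive count and a type key.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
  virtual std::string GetTypeKey() const { return "Object"; }
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

class ObjectRef {
 public:
  ObjectRef() = default;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  const Object* get() const { return data_.get(); }

 protected:
  ObjectPtr<Object> data_;
  // Downcast moves data_ out of its argument, so it needs friend access.
  template <typename SubRef, typename BaseRef>
  friend SubRef Downcast(BaseRef ref);
};

// Hypothetical node / reference pairs.
struct ExprNode : Object {
  std::string GetTypeKey() const override { return "Expr"; }
};
struct TypeNode : Object {
  std::string GetTypeKey() const override { return "Type"; }
};
class Expr : public ObjectRef {
 public:
  using ContainerType = ExprNode;
  explicit Expr(ObjectPtr<Object> d) : ObjectRef(std::move(d)) {}
};
class Type : public ObjectRef {
 public:
  using ContainerType = TypeNode;
  explicit Type(ObjectPtr<Object> d) : ObjectRef(std::move(d)) {}
};

// Sketch of Downcast: dynamic_cast stands in for IsInstance<ContainerType>().
template <typename SubRef, typename BaseRef>
SubRef Downcast(BaseRef ref) {
  const Object* obj = ref.get();
  if (obj == nullptr ||
      dynamic_cast<const typename SubRef::ContainerType*>(obj) == nullptr) {
    throw std::runtime_error(std::string("Downcast from ") +
                             (obj ? obj->GetTypeKey() : std::string("null")) +
                             " failed");
  }
  return SubRef(std::move(ref.data_));
}
```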
GetObjectPtr
Given a raw pointer to an Object, GetObjectPtr returns an ObjectPtr (smart pointer) associated with it:
/*!
 * \brief Get an object ptr type from a raw object ptr.
 *
 * \param ptr The object pointer
 * \tparam BaseType The reference type
 * \tparam ObjectType The object type
 * \return The corresponding RefType
 */
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);

template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
  static_assert(std::is_base_of<BaseType, ObjType>::value,
                "Can only cast to the ref of same container type");
  return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
}
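A common use of GetObjectPtr is inside a node's own method, to obtain an owning pointer to this. The sketch below uses simplified stand-in classes and a hypothetical FuncNode with a Self() helper. It shows that the resulting pointer shares the intrusive count and can outlive the original owner. This is only safe for objects that are already heap-allocated and owned by an ObjectPtr. Wrapping a stack object would eventually delete memory that was never allocated with new.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical minimal Object / ObjectPtr with an intrusive reference count.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }
  T* operator->() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

// Sketch mirroring GetObjectPtr: wrap a raw pointer in a new owning pointer,
// which increments the count already stored in the object.
template <typename BaseType, typename ObjType>
ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
  static_assert(std::is_base_of<BaseType, ObjType>::value,
                "Can only cast to the ref of same container type");
  return ObjectPtr<BaseType>(static_cast<BaseType*>(ptr));
}

// Hypothetical node type that hands out an owning pointer to itself.
struct FuncNode : Object {
  ObjectPtr<Object> Self() { return GetObjectPtr<Object>(this); }
};
```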
ObjectRef::get
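ObjectRef::get returns the raw (const) Object pointer that an ObjectRef refers to, i.e. it goes from an ObjectRef to an Object pointer. The pointer does not own the object and stays valid only while some reference keeps the object alive: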
/*! \return the internal object pointer */
const Object* get() const {
  return data_.get();
}
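Here is a tiny sketch of what get() implies for ownership, using simplified stand-in classes rather than TVM's. The returned pointer is borrowed, so it does not change the reference count, and an empty ObjectRef yields nullptr.

```cpp
#include <cassert>
#include <utility>

// Hypothetical minimal Object / ObjectPtr with an intrusive reference count.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

class ObjectRef {
 public:
  ObjectRef() = default;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  // Like ObjectRef::get: a const, non-owning view of the referenced object.
  const Object* get() const { return data_.get(); }

 private:
  ObjectPtr<Object> data_;
};
```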
ObjectRef::as
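ObjectRef::as tries to convert the object referenced by an ObjectRef into a (const) pointer of the given Object type. It returns nullptr if the ObjectRef is empty or the object is of a different type: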
template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
  if (data_ != nullptr &&
      data_->IsInstance<ObjectType>()) {
    return static_cast<ObjectType*>(data_.get());
  } else {
    return nullptr;
  }
}

// Example
const CallNode* prev_node = res_call->args[0].as<CallNode>();
// converts the Expr res_call->args[0] into a CallNode pointer (nullptr if it is not a call)
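The following sketch shows the as<T>() pattern on simplified stand-in classes, with hypothetical CallNode and VarNode types. In the sketch, dynamic_cast plays the role of TVM's IsInstance<T>() check. The contract is the same: a typed const pointer on success and nullptr otherwise, so the result should be tested before use.

```cpp
#include <cassert>
#include <utility>

// Hypothetical minimal Object / ObjectPtr with an intrusive reference count.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

class ObjectRef {
 public:
  ObjectRef() = default;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  // Sketch of ObjectRef::as: a typed const pointer, or nullptr on a type
  // mismatch. dynamic_cast of a null pointer also yields nullptr.
  template <typename ObjectType>
  const ObjectType* as() const {
    return dynamic_cast<const ObjectType*>(data_.get());
  }

 private:
  ObjectPtr<Object> data_;
};

// Hypothetical node types.
struct CallNode : Object {
  int num_args = 2;
};
struct VarNode : Object {};
```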
ObjectRef::same_as
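same_as checks whether two ObjectRef objects refer to the same object. It compares the underlying pointers, so it tests identity, not structural equality. For example: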
/*!
 * \brief Comparator
 * \param other Another object ref.
 * \return the compare result.
 */
bool same_as(const ObjectRef& other) const {
  return data_ == other.data_;
}

// Example
static const Op& relu_op = Op::Get("nn.relu");
if (!call->op.same_as(relu_op)) {
  return false;  // not a relu op
}
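Because same_as compares the stored pointers, two separately created nodes with identical contents are not same_as each other. The simplified sketch below shows the difference, using stand-in classes and a hypothetical IntImmNode. Structural comparison is a separate facility in TVM (StructuralEqual in recent versions).

```cpp
#include <cassert>
#include <utility>

// Hypothetical minimal Object / ObjectPtr with an intrusive reference count.
struct Object {
  int ref_count = 0;
  virtual ~Object() = default;
};

template <typename T>
class ObjectPtr {
 public:
  ObjectPtr() = default;
  explicit ObjectPtr(T* p) : ptr_(p) {
    if (ptr_) ++ptr_->ref_count;
  }
  ObjectPtr(const ObjectPtr& o) : ObjectPtr(o.ptr_) {}
  ObjectPtr(ObjectPtr&& o) noexcept : ptr_(o.ptr_) { o.ptr_ = nullptr; }
  ObjectPtr& operator=(ObjectPtr o) noexcept {
    std::swap(ptr_, o.ptr_);
    return *this;
  }
  ~ObjectPtr() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* get() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

class ObjectRef {
 public:
  ObjectRef() = default;
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  // Identity check, as in same_as: equal only if both point to one object.
  bool same_as(const ObjectRef& other) const { return data_.get() == other.data_.get(); }

 private:
  ObjectPtr<Object> data_;
};

// Hypothetical node type.
struct IntImmNode : Object {
  int value;
  explicit IntImmNode(int v) : value(v) {}
};
```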
GetExprRefCount
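GetExprRefCount returns, for each internal expression node in the graph (body), how many times it is referenced, i.e. how many times it is reached when the graph is traversed. The source is as follows:
Source path: src/relay/pass/util.cc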
/*!
 * \brief Get reference counter of each internal ExprNode in body.
 * \param body The body expression.
 * \return The reference count mapping.
 */
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
  class ExprRefCounter : private ExprVisitor {
   public:
    std::unordered_map<const Node*, size_t>
    Get(const Expr& body) {
      this->VisitExpr(body);
      return std::move(this->visit_counter_);
    }
  };
  return ExprRefCounter().Get(body);
}
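The counting comes from ExprVisitor's memoization, as I understand ExprVisitor::VisitExpr. The first time it sees a node, it recurses into the node's children and records 1. Every later encounter only increments the counter. The standalone sketch below reproduces that logic on a hypothetical Node type held by std::shared_ptr (not Relay's Expr). As a result, a node shared by two parents ends up with a count of 2.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical expression node: only a list of child expressions.
struct Node {
  std::vector<std::shared_ptr<Node>> children;
};
using Expr = std::shared_ptr<Node>;

// Mirrors the memoized visit: recurse only on the first encounter of a node,
// afterwards just bump its counter.
class ExprRefCounter {
 public:
  std::unordered_map<const Node*, std::size_t> Get(const Expr& body) {
    VisitExpr(body);
    return std::move(visit_counter_);
  }

 private:
  void VisitExpr(const Expr& e) {
    auto it = visit_counter_.find(e.get());
    if (it != visit_counter_.end()) {
      ++it->second;
      return;
    }
    for (const Expr& child : e->children) VisitExpr(child);
    visit_counter_.insert({e.get(), 1});
  }

  std::unordered_map<const Node*, std::size_t> visit_counter_;
};

std::unordered_map<const Node*, std::size_t> GetExprRefCount(const Expr& body) {
  return ExprRefCounter().Get(body);
}
```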
ExprNode::type_as
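ExprNode::type_as checks that the ExprNode's inferred type (checked_type_) is of the specified type node kind and returns it as a pointer to that type. Otherwise, a CHECK fails: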
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  CHECK(node != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
      << ", but get " << checked_type_->GetTypeKey();
  return node;
}

// Example
Expr data = res_call->args[0];  // input data
const TensorTypeNode* dat_node = data->type_as<TensorTypeNode>();
// checks that the input's type is a TensorTypeNode
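Calling type_as() requires the ExprNode's checked_type_ to be defined(). Otherwise, the type has to be obtained manually first, for example by running type inference: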
// A helper function that infers the type of an Expr
static inline Type Expr_InferType(const Expr& expr) {
  auto mod = relay::ModuleNode::FromExpr(expr);
  mod = relay::transform::InferType()(mod);
  auto type_fx = mod->Lookup("main");
  return type_fx->ret_type;
}

// Example: infer the type only if needed, then call type_as
if (!args[0]->checked_type_.defined()) {
  args[0]->checked_type_ = Expr_InferType(args[0]);
}
auto ttype = args[0]->type_as<TensorTypeNode>();
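The following self-contained sketch shows the same lazy pattern: infer only if checked_type_ is undefined, then call type_as. Everything here is a stand-in. The type classes are hypothetical, InferTypeStub is a placeholder for a real inference step such as the Expr_InferType helper above, and errors are thrown instead of failing a CHECK.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-ins for Relay's TypeNode / TensorTypeNode / ExprNode.
struct TypeNode {
  virtual ~TypeNode() = default;
};
struct TensorTypeNode : TypeNode {
  std::vector<int> shape;
  explicit TensorTypeNode(std::vector<int> s) : shape(std::move(s)) {}
};
struct TupleTypeNode : TypeNode {};

struct ExprNode {
  // Empty until type inference fills it in (like checked_type_).
  mutable std::shared_ptr<TypeNode> checked_type_;

  // Sketch of type_as: fail if inference has not run or the type is of a
  // different kind; otherwise return a typed pointer.
  template <typename TTypeNode>
  const TTypeNode* type_as() const {
    if (!checked_type_) {
      throw std::runtime_error("Type inference for this Expr has not completed.");
    }
    const auto* node = dynamic_cast<const TTypeNode*>(checked_type_.get());
    if (node == nullptr) throw std::runtime_error("Unexpected type kind.");
    return node;
  }
};

// Hypothetical placeholder for real type inference: assigns a fixed type.
std::shared_ptr<TypeNode> InferTypeStub(const ExprNode&) {
  return std::make_shared<TensorTypeNode>(std::vector<int>{1, 3, 224, 224});
}

// The pattern from the text: infer lazily, then call type_as.
const TensorTypeNode* GetTensorType(const ExprNode& expr) {
  if (!expr.checked_type_) expr.checked_type_ = InferTypeStub(expr);
  return expr.type_as<TensorTypeNode>();
}
```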
ExprNode::checked_type()
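Returns the result of type inference (type checking) as a Type. Like type_as, it fails a CHECK if the type checker has not populated checked_type_: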
inline const Type& ExprNode::checked_type() const {
  CHECK(checked_type_.defined())
      << "internal error: the type checker has "
      << "not populated the checked_type "
      << "field for "
      << GetRef<Expr>(this);
  return this->checked_type_;
}
