Key classes

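The three most basic classes in TVM are Object, ObjectPtr, and ObjectRef.
ObjectPtr holds a field that is a raw pointer to an Object.
ObjectRef holds a field of type ObjectPtr.
For details, see "TVM source code study 1: Object, ObjectPtr, ObjectRef".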
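
The containment relationship described above can be sketched with a small intrusive reference-counting design. This is only an illustration: MiniObject, MiniObjectPtr, and MiniObjectRef are hypothetical stand-ins, not TVM's definitions, and the real classes carry more machinery (type indices, custom deleters, atomic counters).

```cpp
#include <cstddef>
#include <utility>

// Hypothetical stand-ins for illustration only; not TVM's classes.
struct MiniObject {
  std::size_t ref_count = 0;  // intrusive counter stored inside the object
  virtual ~MiniObject() = default;
};

// Smart pointer: holds a raw object pointer and maintains the counter.
template <typename T>
class MiniObjectPtr {
 public:
  MiniObjectPtr() = default;
  explicit MiniObjectPtr(T* p) : ptr_(p) { Inc(); }
  MiniObjectPtr(const MiniObjectPtr& other) : ptr_(other.ptr_) { Inc(); }
  MiniObjectPtr(MiniObjectPtr&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }
  MiniObjectPtr& operator=(MiniObjectPtr other) noexcept {
    std::swap(ptr_, other.ptr_);  // copy-and-swap; the old value is released by `other`
    return *this;
  }
  ~MiniObjectPtr() { Dec(); }

  T* get() const { return ptr_; }
  std::size_t use_count() const { return ptr_ ? ptr_->ref_count : 0; }

 private:
  void Inc() {
    if (ptr_) ++ptr_->ref_count;
  }
  void Dec() {
    if (ptr_ && --ptr_->ref_count == 0) delete ptr_;
  }
  T* ptr_ = nullptr;
};

// Reference class: holds a MiniObjectPtr; this is what user code passes around.
class MiniObjectRef {
 public:
  MiniObjectRef() = default;
  explicit MiniObjectRef(MiniObjectPtr<MiniObject> data) : data_(std::move(data)) {}

  const MiniObject* get() const { return data_.get(); }
  bool defined() const { return data_.get() != nullptr; }
  std::size_t use_count() const { return data_.use_count(); }

 private:
  MiniObjectPtr<MiniObject> data_;
};
```

Copying a MiniObjectRef copies only the pointer and bumps the counter, which is why reference objects in this style are cheap to pass by value.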

Key functions

make_object, make_node

make_object allocates an object with the default allocator and returns a smart pointer (ObjectPtr) to it:

    /*!
     * \brief Allocate an object using default allocator.
     * \param args arguments to the constructor.
     * \tparam T the node type.
     * \return The ObjectPtr to the allocated object.
     */
    template<typename T, typename... Args>
    inline ObjectPtr<T> make_object(Args&&... args);

Built on top of make_object, make_node creates a Node smart pointer (NodePtr):

    /*!
     * \brief Allocate a node object.
     * \param args arguments to the constructor.
     * \tparam T the node type.
     * \return The NodePtr to the allocated object.
     * \note This function is an alias of make_object.
     */
    template<typename T, typename... Args>
    inline NodePtr<T> make_node(Args&&... args) {
      return runtime::make_object<T>(std::forward<Args>(args)...);
    }
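
The forwarding pattern behind make_object and make_node can be tried outside TVM. The sketch below assumes std::shared_ptr as a stand-in for ObjectPtr; VarNodeSketch, make_object_sketch, and make_node_sketch are hypothetical names introduced only for this illustration.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical node type for illustration; not a TVM class.
struct VarNodeSketch {
  std::string name;
  int bits;
  VarNodeSketch(std::string n, int b) : name(std::move(n)), bits(b) {}
};

// Sketch of the make_object idea: perfectly forward the constructor
// arguments to T and return an owning smart pointer.
// std::shared_ptr stands in for ObjectPtr here.
template <typename T, typename... Args>
inline std::shared_ptr<T> make_object_sketch(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// An alias in the spirit of make_node: same behavior under another name.
template <typename T, typename... Args>
inline std::shared_ptr<T> make_node_sketch(Args&&... args) {
  return make_object_sketch<T>(std::forward<Args>(args)...);
}
```

A call such as make_object_sketch<VarNodeSketch>("x", 32) constructs the node in place and hands back the only owner of it.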

GetRef

GetRef obtains the reference object (an ObjectRef subclass) that corresponds to a raw pointer to an Object:

    /*!
     * \brief Get a reference type from a raw object ptr type
     *
     * It is always important to get a reference type
     * if we want to return a value as reference or keep
     * the object alive beyond the scope of the function.
     *
     * \param ptr The object pointer
     * \tparam RefType The reference type
     * \tparam ObjectType The object type
     * \return The corresponding RefType
     */
    // Declaration
    template <typename RefType, typename ObjectType>
    inline RefType GetRef(const ObjectType* ptr);

    // Definition
    template <typename RefType, typename ObjType>
    inline RefType GetRef(const ObjType* ptr) {
      static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
                    "Can only cast to the ref of same container type");
      return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
    }

    // Example
    Expr ExprMutator::VisitExpr_(const OpNode* op) {
      return GetRef<Expr>(op);
    }
    // Here op is a pointer to an OpNode; GetRef<Expr>(op) returns the Expr reference object for op.
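
The reason GetRef matters in visitor code is that a visit function receives only a raw node pointer but often has to return something that owns the node. A rough standard-library analogue is std::enable_shared_from_this. The sketch below uses it as a substitute technique, not as TVM's mechanism (TVM uses an intrusive counter inside Object); ExprNodeSketch, ExprSketch, GetRefSketch, and VisitSketch are hypothetical names.

```cpp
#include <memory>

// Hypothetical node type; enable_shared_from_this lets a raw pointer
// recover an owning pointer, loosely mirroring what GetRef provides.
struct ExprNodeSketch : std::enable_shared_from_this<ExprNodeSketch> {
  int value = 0;
};
using ExprSketch = std::shared_ptr<const ExprNodeSketch>;

// Analogue of GetRef<Expr>(op): raw node pointer in, owning reference out.
// Precondition: the node is already owned by some std::shared_ptr.
inline ExprSketch GetRefSketch(const ExprNodeSketch* ptr) {
  return ptr->shared_from_this();
}

// A visitor-style function that only sees a raw pointer but must
// return a reference that keeps the node alive.
inline ExprSketch VisitSketch(const ExprNodeSketch* op) {
  return GetRefSketch(op);
}
```

The returned reference shares ownership with the existing owners, so the node stays alive even if the original owner is released afterwards.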

Downcast

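Downcast converts a base-class reference down to a more specific subclass reference. Both the base class and the subclass must derive from ObjectRef.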

    /*!
     * \brief Downcast a base reference type to a more specific type.
     *
     * \param ref The input reference
     * \return The corresponding SubRef.
     * \tparam SubRef The target specific reference type.
     * \tparam BaseRef the current reference type.
     */
    // Declaration
    template <typename SubRef, typename BaseRef>
    inline SubRef Downcast(BaseRef ref);

    // Definition: the CHECK aborts with a message if the runtime type does not match.
    template <typename SubRef, typename BaseRef>
    inline SubRef Downcast(BaseRef ref) {
      CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
          << "Downcast from " << ref->GetTypeKey() << " to "
          << SubRef::ContainerType::_type_key << " failed.";
      return SubRef(std::move(ref.data_));
    }

    // Example
    NodeRef tmp = z;
    Expr zz = Downcast<Expr>(tmp);
    // tmp is a NodeRef and zz is an Expr; Expr is a subclass of NodeRef.
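
The "check the runtime type, then convert, and fail loudly otherwise" behavior can be sketched with standard C++ RTTI. This is an illustration under assumptions: std::shared_ptr and dynamic_pointer_cast replace TVM's own type-index check, an exception replaces the CHECK failure, and all type and function names are hypothetical.

```cpp
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical node hierarchy for illustration.
struct BaseNodeSketch {
  virtual ~BaseNodeSketch() = default;
};
struct ExprNodeSketch : BaseNodeSketch {
  int value = 0;
};
struct TypeNodeSketch : BaseNodeSketch {};

// Checked downcast: succeeds only if the object really is a SubNode,
// otherwise it throws instead of returning an unusable reference.
template <typename SubNode, typename BaseNode>
inline std::shared_ptr<SubNode> DowncastSketch(std::shared_ptr<BaseNode> ref) {
  auto sub = std::dynamic_pointer_cast<SubNode>(std::move(ref));
  if (!sub) {
    throw std::runtime_error("Downcast failed: runtime type mismatch");
  }
  return sub;
}
```

Use this style when a mismatch indicates a bug; when a mismatch is an expected outcome, a cast that returns a null pointer is the better fit.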

GetObjectPtr

Given a raw pointer to an Object, GetObjectPtr returns the smart pointer (ObjectPtr) associated with it:

    /*!
     * \brief Get an object ptr type from a raw object ptr.
     *
     * \param ptr The object pointer
     * \tparam BaseType The reference type
     * \tparam ObjectType The object type
     * \return The corresponding RefType
     */
    // Declaration
    template <typename BaseType, typename ObjectType>
    inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);

    // Definition
    template <typename BaseType, typename ObjType>
    inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
      static_assert(std::is_base_of<BaseType, ObjType>::value,
                    "Can only cast to the ref of same container type");
      return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
    }

ObjectRef::get

ObjectRef::get returns the Object pointer (a pointer to const) held inside the ObjectRef. In other words, it goes from an ObjectRef to the underlying Object pointer.

    /*! \return the internal object pointer */
    const Object* get() const {
      return data_.get();
    }

ObjectRef::as

ObjectRef::as tries to convert the object held by the ObjectRef into a pointer (to const) of the requested Object type. It returns nullptr if the conversion fails.

    template <typename ObjectType>
    inline const ObjectType* ObjectRef::as() const {
      if (data_ != nullptr &&
          data_->IsInstance<ObjectType>()) {
        return static_cast<ObjectType*>(data_.get());
      } else {
        return nullptr;
      }
    }

    // Example: convert an Expr object into a CallNode pointer
    const CallNode* prev_node = res_call->args[0].as<CallNode>();
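
The get and as accessors above follow a "probe, then branch" style: the caller asks for a specific node type and handles nullptr when the guess is wrong. A minimal sketch of that style, assuming dynamic_cast in place of TVM's IsInstance check and using hypothetical names throughout:

```cpp
#include <memory>
#include <utility>

// Hypothetical node hierarchy for illustration.
struct NodeSketch {
  virtual ~NodeSketch() = default;
};
struct CallNodeSketch : NodeSketch {
  int num_args = 0;
};
struct VarNodeSketch : NodeSketch {};

// Hypothetical reference class exposing get() and as<T>().
class RefSketch {
 public:
  RefSketch() = default;
  explicit RefSketch(std::shared_ptr<NodeSketch> data) : data_(std::move(data)) {}

  // Raw, non-owning pointer to the held node (may be nullptr).
  const NodeSketch* get() const { return data_.get(); }

  // Returns a typed pointer if the held node is a NodeType, else nullptr.
  template <typename NodeType>
  const NodeType* as() const {
    return dynamic_cast<const NodeType*>(data_.get());
  }

 private:
  std::shared_ptr<NodeSketch> data_;
};

// Typical call site: branch on the result instead of asserting.
inline int NumArgsOrMinusOne(const RefSketch& ref) {
  if (const auto* call = ref.as<CallNodeSketch>()) {
    return call->num_args;
  }
  return -1;  // not a call node, or an undefined reference
}
```

Both pointers returned here are non-owning, so they must not outlive the reference they came from.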

ObjectRef::same_as

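same_as checks whether two ObjectRef objects refer to the same underlying object. It compares the held pointers, not the contents. For example: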

    /*!
     * \brief Comparator
     * \param other Another object ref.
     * \return the compare result.
     */
    bool same_as(const ObjectRef& other) const {
      return data_ == other.data_;
    }

    // Example
    static const Op& relu_op = Op::Get("nn.relu");
    if (!call->op.same_as(relu_op)) {
      return false;  // not relu op
    }
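
Because same_as compares the held pointers, two separately created objects with identical contents are not "the same". The sketch below illustrates this with hypothetical OpNodeSketch, OpSketch, and MakeOp names and std::shared_ptr standing in for ObjectPtr. The relu check above presumably relies on Op::Get returning a handle to one registered object per operator name.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical operator node for illustration.
struct OpNodeSketch {
  std::string name;
};

// Hypothetical reference class with pointer-identity comparison.
class OpSketch {
 public:
  explicit OpSketch(std::shared_ptr<OpNodeSketch> data) : data_(std::move(data)) {}

  // Identity comparison: true only if both refer to the same node object.
  bool same_as(const OpSketch& other) const { return data_ == other.data_; }

  const std::string& name() const { return data_->name; }

 private:
  std::shared_ptr<OpNodeSketch> data_;
};

// Creates a fresh node on every call (no registry in this sketch).
inline OpSketch MakeOp(const std::string& name) {
  return OpSketch(std::make_shared<OpNodeSketch>(OpNodeSketch{name}));
}
```

A copy of a reference is same_as the original, while a second MakeOp call with the same name yields a different object.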

GetExprRefCount

GetExprRefCount returns how many times each internal expression node in the graph is referenced. The source is shown below.
Source path: src/relay/pass/util.cc

    /*!
     * \brief Get reference counter of each internal ExprNode in body.
     * \param body The body expression.
     * \return The reference count mapping.
     */
    std::unordered_map<const Node*, size_t>
    GetExprRefCount(const Expr& body) {
      class ExprRefCounter : private ExprVisitor {
       public:
        std::unordered_map<const Node*, size_t>
        Get(const Expr& body) {
          this->VisitExpr(body);
          // visit_counter_ is the per-node counter maintained by ExprVisitor
          return std::move(this->visit_counter_);
        }
      };
      return ExprRefCounter().Get(body);
    }
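
The counting itself happens inside the visitor, which the listing above does not show. A minimal sketch of one way such a counter can work on a DAG follows, under the assumption that a node's children are traversed only on the first visit and later visits just increment the count. NodeSketch, RefCounterSketch, and GetRefCountSketch are hypothetical names, not TVM code.

```cpp
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical expression node: only its argument edges matter here.
struct NodeSketch {
  std::vector<std::shared_ptr<NodeSketch>> args;
};
using ExprSketch = std::shared_ptr<NodeSketch>;

class RefCounterSketch {
 public:
  std::unordered_map<const NodeSketch*, std::size_t> Get(const ExprSketch& body) {
    Visit(body);
    return std::move(counter_);
  }

 private:
  void Visit(const ExprSketch& e) {
    auto it = counter_.find(e.get());
    if (it != counter_.end()) {
      ++it->second;  // seen before: count the extra reference, do not recurse
      return;
    }
    for (const auto& arg : e->args) {
      Visit(arg);  // first visit: traverse the children once
    }
    counter_.emplace(e.get(), 1);
  }

  std::unordered_map<const NodeSketch*, std::size_t> counter_;
};

inline std::unordered_map<const NodeSketch*, std::size_t>
GetRefCountSketch(const ExprSketch& body) {
  return RefCounterSketch().Get(body);
}
```

For a graph where x feeds an add node twice and also feeds the root directly, x ends up with a count of 3 while add and the root each have a count of 1.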

ExprNode::type_as

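ExprNode::type_as checks that the ExprNode's checked type is an instance of the requested type node class and returns a pointer to it. If the type has not been inferred yet, or is of a different class, the CHECK fails.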

    template<typename TTypeNode>
    inline const TTypeNode* ExprNode::type_as() const {
      static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                    "TType must be a special case of type");
      CHECK(checked_type_.defined())
          << "Type inference for this Expr has not completed. Try to call infer_type pass.";
      const TTypeNode* node = checked_type_.as<TTypeNode>();
      CHECK(node != nullptr)
          << "Expected type to be " << TTypeNode::_type_key
          << ", but get " << checked_type_->GetTypeKey();
      return node;
    }

    // Example:
    Expr data = res_call->args[0];  // input data
    // Check that the input's type is a TensorTypeNode. type_as returns a pointer to const.
    const TensorTypeNode* dat_node = data->type_as<TensorTypeNode>();

Calling type_as() requires that the ExprNode's checked_type_ is defined. If it is not, the type has to be obtained manually by running type inference first.

    // A helper that infers the type of an Expr
    static inline Type Expr_InferType(const Expr& expr) {
      auto mod = relay::ModuleNode::FromExpr(expr);
      mod = relay::transform::InferType()(mod);
      auto type_fx = mod->Lookup("main");
      return type_fx->ret_type;
    }

    // An example that combines Expr_InferType with type_as
    if (!args[0]->checked_type_.defined()) {
      args[0]->checked_type_ = Expr_InferType(args[0]);
    }
    auto ttype = args[0]->type_as<TensorTypeNode>();

ExprNode::checked_type()

checked_type() returns the result of type inference (type checking) as a Type:

    inline const Type& ExprNode::checked_type() const {
      CHECK(checked_type_.defined())
          << "internal error: the type checker has "
          << "not populated the checked_type "
          << "field for "
          << GetRef<Expr>(this);
      return this->checked_type_;
    }
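
The interplay between checked_type(), type_as(), and the "infer first if undefined" guard can be sketched without TVM. The following is an illustration under assumptions: exceptions replace CHECK failures, dynamic_cast replaces the type-key test, InferTypeSketch is a hypothetical placeholder that always yields a tensor type rather than a real inference pass, and every name is invented for the example.

```cpp
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical type hierarchy for illustration.
struct TypeNodeSketch {
  virtual ~TypeNodeSketch() = default;
};
struct TensorTypeNodeSketch : TypeNodeSketch {
  std::vector<int> shape;
};
struct FuncTypeNodeSketch : TypeNodeSketch {};

struct ExprNodeSketch {
  // Cached inference result; empty until inference has run.
  mutable std::shared_ptr<TypeNodeSketch> checked_type_;

  // Accessor that refuses to return an unpopulated type.
  const std::shared_ptr<TypeNodeSketch>& checked_type() const {
    if (!checked_type_) throw std::logic_error("checked_type_ not populated");
    return checked_type_;
  }

  // Typed accessor: fails if the type is missing or of another class.
  template <typename T>
  const T* type_as() const {
    const T* node = dynamic_cast<const T*>(checked_type().get());
    if (node == nullptr) throw std::logic_error("unexpected type class");
    return node;
  }
};

// Placeholder for an inference pass (assumption: always a 1x3 tensor).
inline std::shared_ptr<TypeNodeSketch> InferTypeSketch(const ExprNodeSketch&) {
  auto t = std::make_shared<TensorTypeNodeSketch>();
  t->shape = {1, 3};
  return t;
}

// Guarded access: populate the cache if needed, then use type_as.
inline const TensorTypeNodeSketch* TensorTypeOf(const ExprNodeSketch& e) {
  if (!e.checked_type_) e.checked_type_ = InferTypeSketch(e);
  return e.type_as<TensorTypeNodeSketch>();
}
```

Keeping the guard in one helper avoids scattering the "is it defined yet" test across call sites.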