重点类
TVM中3个最基本的类:Object、ObjectPtr、ObjectRef
ObjectPtr包含一个Object指针类型的字段
ObjectRef包含一个ObjectPtr类型的字段(即下文代码中出现的data_)
详细参考tvm源码学习1:Object, ObjectPtr, ObjectRef
重点函数
make_object, make_node
make_object使用默认的分配器分配一个T类型的对象,并返回指向该对象的智能指针ObjectPtr<T>:
/*!
* \brief Allocate an object of type T using the default allocator.
*
* Only the declaration is shown here; the definition is elsewhere.
* \param args Arguments passed to the constructor of T.
* \tparam T The node (object) type to allocate.
* \tparam Args Types of the constructor arguments.
* \return The ObjectPtr<T> pointing to the newly allocated object.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
make_node是make_object的别名:它把参数原样转发给runtime::make_object<T>,并返回一个Node智能指针NodePtr<T>。
/*!
* \brief Create a new node object of type T.
*
* Alias of runtime::make_object: every argument is perfectly forwarded
* to it and its result is handed back as a NodePtr<T>.
* \param args Arguments passed on to the constructor of T.
* \tparam T The node type to allocate.
* \tparam Args Types of the constructor arguments.
* \return A NodePtr<T> referring to the freshly allocated node.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
  NodePtr<T> node = runtime::make_object<T>(std::forward<Args>(args)...);
  return node;
}
GetRef
通过GetRef函数可以从Object(子类)的原始指针获取对应的引用对象(ObjectRef的子类RefType);要求RefType::ContainerType是ObjType的基类,否则编译期static_assert报错。
/*!
* \brief Get a reference type from a raw object ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the object alive beyond the scope of the function.
*
* \param ptr The object pointer
* \tparam RefType The reference type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr);
template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type");
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
// Example
Expr ExprMutator::VisitExpr_(const OpNode* op) {
  // op is a raw OpNode pointer; GetRef<Expr> wraps it into the Expr
  // reference that refers to the same node.
  Expr op_ref = GetRef<Expr>(op);
  return op_ref;
}
Downcast
该函数将父类引用向下转换成具体的子类引用,父类和子类都必须是ObjectRef的子类;若引用所指对象并不是SubRef::ContainerType的实例,则CHECK失败。
/*!
* \brief Downcast a base reference type to a more specific type.
*
* A defined reference must point to an instance of SubRef::ContainerType,
* otherwise a CHECK failure reports both type keys. An undefined (null)
* reference is passed through as an undefined SubRef.
*
* \param ref The input reference.
* \return The corresponding SubRef.
* \tparam SubRef The target specific reference type.
* \tparam BaseRef The current reference type.
*/
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  // Only inspect the node when there is one: for a null ref both
  // IsInstance and the GetTypeKey() call in the failure message would be
  // invoked through a null pointer.
  if (ref.defined()) {
    CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
        << "Downcast from " << ref->GetTypeKey() << " to "
        << SubRef::ContainerType::_type_key << " failed.";
  }
  return SubRef(std::move(ref.data_));
}
// 例
NodeRef tmp = z;
Expr zz = Downcast<Expr>(tmp);
// 该tmp是NodeRef类对象,zz是Expr类对象,Expr是NodeRef类的子类
GetObjectPtr
由一个Object(子类)的原始指针构造与之关联的智能指针ObjectPtr<BaseType>,要求BaseType是ObjType的基类。
/*!
* \brief Build an ObjectPtr from a raw pointer to an object.
*
* \param ptr The raw object pointer.
* \tparam BaseType The pointee type of the returned ObjectPtr; it must be a
*         base of ObjectType (enforced by static_assert).
* \tparam ObjectType The type of the pointed-to object.
* \return An ObjectPtr<BaseType> referring to the same object as ptr.
*/
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);

template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
  static_assert(std::is_base_of<BaseType, ObjType>::value,
                "Can only cast to the ref of same container type");
  // The ObjectPtr is constructed from the Object base sub-object of ptr.
  Object* base_ptr = static_cast<Object*>(ptr);
  return ObjectPtr<BaseType>(base_ptr);
}
// 例
ObjectRef::get
该函数用于获取ObjectRef类对象内部引用的Object类指针(const 指针), 相当于由ObjectRef类对象获取Object类指针
/*!
 * \brief Access the object this reference points to.
 * \return The raw, read-only pointer held by data_.
 */
const Object* get() const {
  const Object* raw = data_.get();
  return raw;
}
// 例
ObjectRef::as
该函数用于尝试将ObjectRef类对象的内部引用对象转换成对应Object类型的指针(const 指针),如果返回nullptr表示失败
/*!
 * \brief Try to view the referenced object as an ObjectType node.
 * \tparam ObjectType The node type to test for.
 * \return A const ObjectType pointer when data_ is non-null and the object
 *         IsInstance<ObjectType>(); nullptr otherwise.
 */
template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
  // Short-circuit: IsInstance is only queried on a non-null data_.
  const bool matches = data_ != nullptr && data_->IsInstance<ObjectType>();
  if (!matches) {
    return nullptr;
  }
  return static_cast<ObjectType*>(data_.get());
}
// 例
const CallNode* prev_node = res_call->args[0].as<CallNode>();//将Expr类对象转换成CallNode类型指针
ObjectRef::same_as
用于比较2个ObjectRef对象是否引用同一个Object:它比较的是内部指针data_是否相等(引用相等),而不是比较两者的内容。源码如下:
/*!
 * \brief Identity comparison between two references.
 * \param other Another object ref.
 * \return true when both references hold the same data_ pointer,
 *         false otherwise (object contents are not compared).
 */
bool same_as(const ObjectRef& other) const {
  const bool identical = (data_ == other.data_);
  return identical;
}
// 例
static const Op& relu_op = Op::Get("nn.relu");
if (!call->op.same_as(relu_op)) {
return false; // not relu op
}
GetExprRefCount
用于获取graph中的每个内部表达式Node的引用次数,源码如下:
代码路径:src/relay/pass/util.cc
*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}
ExprNode::type_as
ExprNode::type_as将ExprNode的推断类型checked_type_转换为指定的TypeNode子类指针(const指针);若类型推断尚未完成,或类型不是指定类型,则CHECK失败。
/*!
 * \brief Return checked_type_ viewed as a TTypeNode.
 *
 * CHECK-fails if checked_type_ has not been populated yet, or if the
 * checked type is not a TTypeNode.
 * \tparam TTypeNode The expected type node (must derive from TypeNode).
 * \return Non-null pointer to the checked type as TTypeNode.
 */
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  const TTypeNode* typed = checked_type_.as<TTypeNode>();
  CHECK(typed != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
      << ", but get " << checked_type_->GetTypeKey();
  return typed;
}
// Example:
Expr data = res_call->args[0];  // input data
// type_as CHECK-fails unless data's checked type is a TensorTypeNode, and it
// returns a const pointer, so the result must be held through a const pointer.
const TensorTypeNode* dat_node = data->type_as<TensorTypeNode>();
调用type_as()的前提是该ExprNode的checked_type_已经defined()(即已完成类型推断),否则需要先手动推断出它的类型。
// Infer the type of an arbitrary Expr by wrapping it into a module and
// running the InferType pass on that module.
static inline Type Expr_InferType(const Expr& expr) {
  auto mod = relay::ModuleNode::FromExpr(expr);
  mod = relay::transform::InferType()(mod);
  auto type_fx = mod->Lookup("main");
  // ModuleNode::FromExpr installs a Function argument itself as "main"
  // (other exprs become main's body). For a Function, the type of expr is
  // therefore main's FuncType, not main's return type.
  if (expr.as<relay::FunctionNode>() != nullptr) {
    return type_fx->checked_type();
  }
  return type_fx->ret_type;
}
// Example: make sure args[0] has a checked type before calling type_as.
if (!args[0]->checked_type_.defined()) {
  // Infer the type of the same argument whose checked_type_ is being filled.
  args[0]->checked_type_ = Expr_InferType(args[0]);
}
auto ttype = args[0]->type_as<TensorTypeNode>();
ExprNode::checked_type()
返回类型推断(检查)的结果(Type类型);若checked_type_尚未被类型检查器填充,则CHECK失败。
/*!
 * \brief Access the result of type inference for this expression.
 *
 * CHECK-fails (printing this expression) if checked_type_ has not been
 * populated by the type checker yet.
 * \return Reference to checked_type_.
 */
inline const Type& ExprNode::checked_type() const {
  const bool populated = checked_type_.defined();
  CHECK(populated)
      << "internal error: the type checker has "
      << "not populated the checked_type "
      << "field for "
      << GetRef<Expr>(this);
  return checked_type_;
}