以Tuple和TupleNode为例(Tuple是NodeRef的子类, TupleNode是Node的子类):
定义Node子类
/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Tuple container */
class TupleNode : public ExprNode {
public:
// 1 自定义的字段
/*! \brief the fields of the tuple */
tvm::Array<relay::Expr> fields;
// 2 访问属性(按格式模仿即可)
// 注意: fields是自定义的,而span和checked_type_是继承来的
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
// 3 声明make函数
TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);
// 4 定义个唯一的字符串类型的类型关键字字段:_type_key
static constexpr const char* _type_key = "relay.Tuple";
// 5 声明节点类型信息,TupleNode是ExprNode的子类,且TupleNode是终端节点
TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
};
// 6 定义节点引用,Tuple是TupleNode的引用类型,Tuple是Expr的子类
RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
// 7 定义make函数
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
NodePtr<TupleNode> n = make_node<TupleNode>();
n->fields = std::move(fields);
return Tuple(n);
}
// 8 将节点类型TupleNode注册到对象注册表和反射注册表。
TVM_REGISTER_NODE_TYPE(TupleNode);
2 定义NodeRef子类
2.1 通过RELAY_DEFINE_NODE_REF宏快速定义
如果没有特殊的要求(扩展成员变量或函数),直接通过RELAY_DEFINE_NODE_REF可以快速定义NodeRef子类,简单方便。
Tuple是通过RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr)定义的,Tuple是Expr的子类,Tuple是TupleNode的引用类型。
- (1) 在TupleNode定义之前先声明Tuple类
- (2) 在TupleNode定义之后使用RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);来定义Tupe类。
RELAY_DEFINE_NODE_REF宏的定义如下:
/*!
* \brief Macro to make it easy to define node ref type given node
* \param TypeName The name of the reference type.
* \param NodeName The internal container name.
* \param NodeRefBase The base type.
*/
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \
public: \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: NodeRefBase(n) { \
} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(get()); \
} \
operator bool() { return this->defined(); } \
using ContainerType = NodeName; \
};
可见,该宏其实实现了以下几个功能:
- TypeName类继承NodeRefBase
- 定义TypeName类的构造函数(一个不带参,一个带参)
- 定义operator->(),可使用-> 操作符直接访问Node类指针
- 定义bool()函数
- 添加using ContainerType = NodeName;
假设t是一个Tuple类对象,那么t->将会指向一个TupleNode *指针。
2.2 自定义NodeRef子类
如果有特殊的要求(扩展成员变量或函数),通过RELAY_DEFINE_NODE_REF定义的NodeRef子类不满足要求,就需要自定义。
- 自定义类继承自NodeRef
- 定义构造函数(一个不带参,一个带参)
- 定义operator->(),可使用-> 操作符直接访问Node类指针
- 在类定义的结尾处添加using ContainerType = NodeName; 注意:NodeName是自定义的类名。
- 添加扩展成员变量或函数…
例如,QConfig的定义如下:
/*!
* \brief Container for build configuration options
*/
class QConfig : public NodeRef {
public:
QConfig() {}
explicit QConfig(ObjectPtr<Object> n) : NodeRef(n) {}
const QConfigNode* operator->() const {
return static_cast<const QConfigNode*>(get());
}
QConfigNode* operator->() {
return static_cast<QConfigNode*>(get_mutable());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
static void EnterQConfigScope(const QConfig& qconfig);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
static void ExitQConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
static QConfig& Current();
using ContainerType = QConfigNode;
};