以Tuple和TupleNode为例(Tuple是NodeRef的子类, TupleNode是Node的子类):

定义Node子类

  1. /*! \brief Tuple of multiple Exprs */
  2. class Tuple;
  3. /*! \brief Tuple container */
  4. class TupleNode : public ExprNode {
  5. public:
  6. // 1 自定义的字段
  7. /*! \brief the fields of the tuple */
  8. tvm::Array<relay::Expr> fields;
  9. // 2 访问属性(按格式模仿即可)
  10. // 注意: fields是自定义的,而span和checked_type_是继承来的
  11. void VisitAttrs(tvm::AttrVisitor* v) {
  12. v->Visit("fields", &fields);
  13. v->Visit("span", &span);
  14. v->Visit("_checked_type_", &checked_type_);
  15. }
  16. // 3 声明make函数
  17. TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);
  18. // 4 定义个唯一的字符串类型的类型关键字字段:_type_key
  19. static constexpr const char* _type_key = "relay.Tuple";
  20. // 5 声明节点类型信息,TupleNode是ExprNode的子类,且TupleNode是终端节点
  21. TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
  22. };
  23. // 6 定义节点引用,Tuple是TupleNode的引用类型,Tuple是Expr的子类
  24. RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
  25. // 7 定义make函数
  26. Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
  27. NodePtr<TupleNode> n = make_node<TupleNode>();
  28. n->fields = std::move(fields);
  29. return Tuple(n);
  30. }
  31. // 8 将节点类型TupleNode注册到对象注册表和反射注册表。
  32. TVM_REGISTER_NODE_TYPE(TupleNode);

2 定义NodeRef子类

2.1 通过RELAY_DEFINE_NODE_REF宏快速定义

如果没有特殊的要求(扩展成员变量或函数),直接通过RELAY_DEFINE_NODE_REF可以快速定义NodeRef子类,简单方便。
Tuple是通过RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr)定义的,Tuple是Expr的子类,Tuple是TupleNode的引用类型。

  • (1) 在TupleNode定义之前先声明Tuple类
  • (2) 在TupleNode定义之后使用RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);来定义Tupe类。

RELAY_DEFINE_NODE_REF宏的定义如下:

  1. /*!
  2. * \brief Macro to make it easy to define node ref type given node
  3. * \param TypeName The name of the reference type.
  4. * \param NodeName The internal container name.
  5. * \param NodeRefBase The base type.
  6. */
  7. #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
  8. class TypeName : public NodeRefBase { \
  9. public: \
  10. TypeName() {} \
  11. explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
  12. : NodeRefBase(n) { \
  13. } \
  14. const NodeName* operator->() const { \
  15. return static_cast<const NodeName*>(get()); \
  16. } \
  17. operator bool() { return this->defined(); } \
  18. using ContainerType = NodeName; \
  19. };

可见,该宏其实实现了以下几个功能:

  • TypeName类继承NodeRefBase
  • 定义TypeName类的构造函数(一个不带参,一个带参)
  • 定义operator->(),可使用-> 操作符直接访问Node类指针
  • 定义bool()函数
  • 添加using ContainerType = NodeName;

假设t是一个Tuple类对象,那么t->将会指向一个TupleNode *指针。

2.2 自定义NodeRef子类

如果有特殊的要求(扩展成员变量或函数),通过RELAY_DEFINE_NODE_REF定义的NodeRef子类不满足要求,就需要自定义。

  • 自定义类继承自NodeRef
  • 定义构造函数(一个不带参,一个带参)
  • 定义operator->(),可使用-> 操作符直接访问Node类指针
  • 在类定义的结尾处添加using ContainerType = NodeName; 注意:NodeName是自定义的类名。
  • 添加扩展成员变量或函数…

例如,QConfig的定义如下:

  1. /*!
  2. * \brief Container for build configuration options
  3. */
  4. class QConfig : public NodeRef {
  5. public:
  6. QConfig() {}
  7. explicit QConfig(ObjectPtr<Object> n) : NodeRef(n) {}
  8. const QConfigNode* operator->() const {
  9. return static_cast<const QConfigNode*>(get());
  10. }
  11. QConfigNode* operator->() {
  12. return static_cast<QConfigNode*>(get_mutable());
  13. }
  14. /*!
  15. * \brief Push a new BuildConfig context onto the thread local stack.
  16. * \param build_config The configuration to set as the current context.
  17. */
  18. static void EnterQConfigScope(const QConfig& qconfig);
  19. /*!
  20. * \brief Pop a build config off the thread local context stack, restoring the previous
  21. * configuration as the current context.
  22. */
  23. static void ExitQConfigScope();
  24. /*!
  25. * \brief Get the current BuildConfig context from thread local storage, or a default
  26. * configuration if a BuildConfig scope has not been entered.
  27. * \return The configuration that is the current context.
  28. */
  29. static QConfig& Current();
  30. using ContainerType = QConfigNode;
  31. };