获取shape的基本思路:尽可能地从Node得到对应的TensorTypeNode,然后读取它的shape字段。

CallNode

ExprNode有一个成员变量`Type checked_type_`,它保存了类型推断(类型检查)的结果;在类型推断之前,它可能是未定义的:

  1. class ExprNode : public RelayNode {
  2. public:
  3. /*!
  4. * \brief Stores the result of type inference(type checking).
  5. *
  6. * \note This can be undefined before type inference.
  7. * This value is discarded during serialization.
  8. */
  9. mutable Type checked_type_ = Type(nullptr);
  10. ...

CallNode继承自ExprNode,因此同样拥有checked_type_字段。如果该字段能转换成TensorTypeNode,就可以获取shape,
因为TensorTypeNode有shape字段:

  1. class TensorType;
  2. /*! \brief TensorType container node */
  3. class TensorTypeNode : public BaseTensorTypeNode {
  4. public:
  5. /*!
  6. * \brief The shape of the tensor,
  7. * represented by IndexExpr(tvm::Expr).
  8. */
  9. Array<IndexExpr> shape;
  10. /*! \brief The content data type */
  11. DataType dtype;
  12. ...

通过TensorTypeNode获取shape。
例如:

  1. if (new_call->op.same_as(dense_op)) {
  2. CHECK_EQ(new_call->args.size(), 2) << "The argument size should be equal to 2.";
  3. const auto* dense_param = new_call->attrs.as<DenseAttrs>();
  4. CHECK(dense_param != nullptr);
  5. auto arg_0 = new_call->args[0];
  6. auto arg_1 = new_call->args[1]; //weight params
  7. call_args.push_back(arg_0);
  8. auto type_0 = arg_0->checked_type_;
  9. const auto* t_flatten = type_0.as<TensorTypeNode>();
  10. auto type_1 = arg_1->checked_type_;
  11. const auto* t_weight = type_1.as<TensorTypeNode>();
  12. CHECK_EQ(t_flatten->shape.size(), 2) << "The size should be equal to 2.";
  13. CHECK_EQ(t_weight->shape.size(), 2) << "The size should be equal to 2.";
  14. }

当然,在运行类型推断之前,checked_type_可能还没有定义,这时需要先推断类型,例如:

  1. static inline Type Expr_InferType(const Expr& expr) {
  2. auto mod = relay::ModuleNode::FromExpr(expr);
  3. mod = relay::transform::InferType()(mod);
  4. auto type_fx = mod->Lookup("main");
  5. return type_fx->ret_type;
  6. }
  7. if (!args[i]->checked_type_.defined()) {
  8. args[i]->checked_type_ = Expr_InferType(args[i]);
  9. }
  10. auto ttype = args[i]->type_as<TensorTypeNode>();

VarNode

VarNode有type_annotation字段:

  1. /*!
  2. * \brief type annotaion of the variable.
  3. * This field records user provided type annotation of the Var.
  4. * This field is optional and can be None.
  5. */
  6. Type type_annotation;

通过VarNode的type_annotation获取Type,转换成TensorTypeNode,然后获取shape。由于type_annotation是可选的(可能为空),使用前要先用defined()检查。
例如:

  1. Expr ExprVaccDataAlignment::VisitExpr_(const VarNode* op, std::string layout_name) {
  2. if (op->type_annotation.defined()) {
  3. auto type = this->VisitType(op->type_annotation);
  4. const auto* ttnode = type.as<TensorTypeNode>();
  5. ttnode->shape;
  6. ...
  7. }
  8. }

ConstantNode

ConstantNode有一个tensor_type()成员函数,返回其数据(data)对应的TensorType,从中可以获取shape:

  1. class ConstantNode : public ExprNode {
  2. public:
  3. /*! \brief The data of the tensor */
  4. runtime::NDArray data;
  5. /*! \brief key/value pairs like weight:OIHW64i32o, bias:xx*/
  6. /*! in most case, only there is one pair of key:value*/
  7. std::unordered_map<std::string, std::string> kvs;
  8. /*! \return The corresponding tensor type of the data */
  9. TensorType tensor_type() const;

示例:

  1. Expr ExprVaccDataAlignment::AffineConstantAlign(const ConstantNode* op) {
  2. auto src_ret = op->data;
  3. auto tensorType = op->tensor_type();
  4. vector<int64_t> src_shape;
  5. int64_t channel;
  6. for (size_t i = 0; i < tensorType->shape.size(); ++i) {
  7. int64_t shape = *as_const_int(tensorType->shape[i]);
  8. src_shape.push_back(shape);
  9. }

综合例子(为简洁起见,这里没有检查type_annotation是否已定义,也没有检查as<TensorTypeNode>()的返回值是否为空,实际代码中应补上这些检查):

  1. if (op.same_as(affine_op)) {
  2. // For an affine op, first check its granularity, scalar or by-channel.
  3. Expr arg = call->args[1];
  4. Array<IndexExpr> arg_shape;
  5. if (arg->IsInstance<VarNode>()) {
  6. const VarNode* var = arg.as<VarNode>();
  7. arg_shape = var->type_annotation.as<TensorTypeNode>()->shape;
  8. } else if (arg->IsInstance<ConstantNode>()) {
  9. const ConstantNode* cons = arg.as<ConstantNode>();
  10. arg_shape = cons->tensor_type()->shape;
  11. }

ExprNode::type_as

ExprNode::type_as<T>()把ExprNode的checked_type_转换成指定的类型节点T并返回。注意它不只是“尝试”转换:如果checked_type_尚未定义,或者类型不是T,它会直接CHECK失败报错,而不会返回nullptr。

  1. template<typename TTypeNode>
  2. inline const TTypeNode* ExprNode::type_as() const {
  3. static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
  4. "TType must be a special case of type");
  5. CHECK(checked_type_.defined())
  6. << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  7. const TTypeNode* node = checked_type_.as<TTypeNode>();
  8. CHECK(node != nullptr)
  9. << "Expected type to be " << TTypeNode::_type_key
  10. << ", but get " << checked_type_->GetTypeKey();
  11. return node;
  12. }

所以,可以通过type_as把一个Expr的类型直接转换成指定的类型节点:

  1. Expr data = res_call->args[0]; // input data
  2. const TensorTypeNode* dat_node = data->type_as<TensorTypeNode>();//检查输入是否是TensorTypeNode类型
  3. CHECK(dat_node)

有时候checked_type_还没有定义,这时需要先推断类型。下面把类型推断,以及获取shape和DataType的操作封装成工具函数:

  1. //推断类型
  2. static inline Type Expr_InferType(const Expr& expr) {
  3. auto mod = relay::ModuleNode::FromExpr(expr);
  4. mod = relay::transform::InferType()(mod);
  5. auto type_fx = mod->Lookup("main");
  6. return type_fx->ret_type;
  7. }
  8. //获取shape
  9. static inline Array<Integer> GetShape(const Expr& expr) {
  10. auto tt = Expr_InferType(expr);
  11. CHECK(tt.defined());
  12. auto ttnode = tt.as<TensorTypeNode>();
  13. return ConvertToConstants(ttnode->shape);
  14. }
  15. //获取DataType
  16. static inline DataType GetDataType(const Expr& expr) {
  17. auto tt = Expr_InferType(expr);
  18. CHECK(tt.defined());
  19. auto ttnode = tt.as<TensorTypeNode>();
  20. return ttnode->dtype;
  21. }