Whenever possible, get a TensorTypeNode from the node. Once you have it, the shape is available.
CallNode
ExprNode has a member Type checked_type_ that stores the inferred type (the result of type checking):
class ExprNode : public RelayNode {
 public:
  /*!
   * \brief Stores the result of type inference (type checking).
   *
   * \note This can be undefined before type inference.
   *       This value is discarded during serialization.
   */
  mutable Type checked_type_ = Type(nullptr);
  ...
CallNode inherits from ExprNode, so it has this field as well. If checked_type_ can be downcast to a TensorTypeNode, the shape is available, because TensorTypeNode has a shape field:
class TensorType;
/*! \brief TensorType container node */
class TensorTypeNode : public BaseTensorTypeNode {
 public:
  /*!
   * \brief The shape of the tensor,
   *  represented by IndexExpr(tvm::Expr).
   */
  Array<IndexExpr> shape;
  /*! \brief The content data type */
  DataType dtype;
  ...
So the shape is read from the TensorTypeNode. For example, when processing a dense call:
if (new_call->op.same_as(dense_op)) {
  CHECK_EQ(new_call->args.size(), 2) << "The argument size should be equal to 2.";
  const auto* dense_param = new_call->attrs.as<DenseAttrs>();
  CHECK(dense_param != nullptr);
  auto arg_0 = new_call->args[0];  // input data
  auto arg_1 = new_call->args[1];  // weight params
  call_args.push_back(arg_0);      // call_args comes from the enclosing pass (not shown)
  // Downcast the inferred types to TensorTypeNode to read their shapes.
  auto type_0 = arg_0->checked_type_;
  const auto* t_flatten = type_0.as<TensorTypeNode>();
  auto type_1 = arg_1->checked_type_;
  const auto* t_weight = type_1.as<TensorTypeNode>();
  // as<>() returns nullptr if the type is undefined or not a tensor type.
  CHECK(t_flatten != nullptr && t_weight != nullptr);
  CHECK_EQ(t_flatten->shape.size(), 2) << "The size should be equal to 2.";
  CHECK_EQ(t_weight->shape.size(), 2) << "The size should be equal to 2.";
}
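The rank checks above can be taken one step further to derive the output shape of the dense op. The following is a minimal sketch, assuming both shapes have already been converted to plain integers (for example with as_const_int) and assuming Relay's nn.dense layout of data [batch, in_dim] and weight [units, in_dim]; DenseOutputShape is a hypothetical helper, not a TVM API.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper (not part of TVM): output shape of a dense op,
// assuming data is [batch, in_dim] and weight is [units, in_dim],
// i.e. the op computes data * transpose(weight).
Shape DenseOutputShape(const Shape& data, const Shape& weight) {
  // Same rank checks as the CHECK_EQ calls in the pass above.
  if (data.size() != 2 || weight.size() != 2) {
    throw std::invalid_argument("dense expects 2-D data and weight");
  }
  // The reduction (inner) dimensions must agree.
  if (data[1] != weight[1]) {
    throw std::invalid_argument("dense: in_dim of data and weight differ");
  }
  return {data[0], weight[0]};  // [batch, units]
}
```

For example, data of shape [4, 512] with weight [10, 512] gives [4, 10], while a weight of [10, 256] is rejected.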
Of course, checked_type_ is sometimes undefined, and the type must be inferred first:
static inline Type Expr_InferType(const Expr& expr) {
  auto mod = relay::ModuleNode::FromExpr(expr);
  mod = relay::transform::InferType()(mod);
  auto type_fx = mod->Lookup("main");
  return type_fx->ret_type;
}

if (!args[i]->checked_type_.defined()) {
  args[i]->checked_type_ = Expr_InferType(args[i]);
}
auto ttype = args[i]->type_as<TensorTypeNode>();
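Writing the result back into checked_type_ acts as a cache: inference only runs while the field is still empty. The sketch below models that pattern with hypothetical stand-in types (TypeNode, TensorTypeNode, ExprNode and MockInferType are simplified mocks, not the TVM classes) so the control flow can be compiled and checked without TVM.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical mocks of Relay's type classes (NOT the TVM API).
struct TypeNode {
  virtual ~TypeNode() = default;
};
struct TensorTypeNode : TypeNode {
  explicit TensorTypeNode(std::vector<int64_t> s) : shape(std::move(s)) {}
  std::vector<int64_t> shape;
};
using Type = std::shared_ptr<const TypeNode>;  // empty == "not defined"

struct ExprNode {
  mutable Type checked_type_;       // filled in lazily, as in Relay
  std::vector<int64_t> true_shape;  // what a real inferencer would derive
};

int infer_calls = 0;  // counts how often the (expensive) inference runs

// Stand-in for Expr_InferType, which in TVM builds a module and runs InferType.
Type MockInferType(const ExprNode& e) {
  ++infer_calls;
  return std::make_shared<TensorTypeNode>(e.true_shape);
}

// Infer only when checked_type_ is missing, then downcast to a tensor type.
const TensorTypeNode* EnsureTensorType(const ExprNode& e) {
  if (!e.checked_type_) {
    e.checked_type_ = MockInferType(e);  // cache the result on the node
  }
  return dynamic_cast<const TensorTypeNode*>(e.checked_type_.get());
}
```

Calling EnsureTensorType twice on the same node runs the mock inference only once, which is the point of the write-back.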
VarNode
VarNode has a type_annotation field:
/*!
 * \brief type annotation of the variable.
 * This field records user provided type annotation of the Var.
 * This field is optional and can be None.
 */
Type type_annotation;
Get the Type from the VarNode's type_annotation, downcast it to TensorTypeNode, and then read the shape. For example:
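Expr ExprVaccDataAlignment::VisitExpr_(const VarNode* op, std::string layout_name) {
  // type_annotation is optional, so check that it is defined first.
  if (op->type_annotation.defined()) {
    auto type = this->VisitType(op->type_annotation);
    const auto* ttnode = type.as<TensorTypeNode>();
    // as<>() returns nullptr if the annotation is not a tensor type.
    if (ttnode != nullptr) {
      auto shape = ttnode->shape;
      ...
    }
  }
}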
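Because type_annotation is optional, and an annotation need not be a tensor type, both cases have to be handled before reading the shape. Below is a small sketch using hypothetical mock classes (not the TVM types), where an empty shared_ptr plays the role of an undefined Type and TupleTypeNode stands in for a non-tensor annotation; ShapeFromAnnotation is an illustrative name.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical mocks (NOT the TVM classes).
struct TypeNode {
  virtual ~TypeNode() = default;
};
struct TensorTypeNode : TypeNode {
  explicit TensorTypeNode(std::vector<int64_t> s) : shape(std::move(s)) {}
  std::vector<int64_t> shape;
};
struct TupleTypeNode : TypeNode {};            // a non-tensor type
using Type = std::shared_ptr<const TypeNode>;  // empty == undefined

struct VarNode {
  Type type_annotation;  // optional, may be empty
};

// Returns true and fills *shape only if the Var has a tensor annotation.
bool ShapeFromAnnotation(const VarNode& var, std::vector<int64_t>* shape) {
  if (!var.type_annotation) return false;  // no annotation at all
  const auto* tt = dynamic_cast<const TensorTypeNode*>(var.type_annotation.get());
  if (tt == nullptr) return false;         // annotated, but not a tensor
  *shape = tt->shape;
  return true;
}
```

Returning a flag instead of aborting lets the caller fall back to checked_type_ when the Var was created without an annotation.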
ConstantNode
ConstantNode has a tensor_type() member function that returns the TensorType of its data:
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;
  /*! \brief key/value pairs like weight:OIHW64i32o, bias:xx */
  /*! in most cases, there is only one key:value pair */
  std::unordered_map<std::string, std::string> kvs;
  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;
Example:
Expr ExprVaccDataAlignment::AffineConstantAlign(const ConstantNode* op) {
  auto src_ret = op->data;
  auto tensorType = op->tensor_type();
  std::vector<int64_t> src_shape;
  int64_t channel;
  for (size_t i = 0; i < tensorType->shape.size(); ++i) {
    // as_const_int returns nullptr when the dimension is not a constant.
    const int64_t* dim = as_const_int(tensorType->shape[i]);
    CHECK(dim != nullptr) << "Expected a constant dimension.";
    src_shape.push_back(*dim);
  }
  ...
}
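as_const_int returns a null pointer when a dimension is not a compile-time integer (for example a symbolic batch size), so dereferencing it unconditionally is only safe for fully static shapes. The sketch below models IndexExpr as a hypothetical std::variant holding either a constant or a symbolic name, and converts a shape only when every dimension is constant; AsConstInt and ToConstShape are illustrative names, not TVM functions.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for IndexExpr: a constant or a symbolic dim name.
using IndexExpr = std::variant<int64_t, std::string>;

// Mimics the contract of TVM's as_const_int: pointer to the value, or nullptr.
const int64_t* AsConstInt(const IndexExpr& e) {
  return std::get_if<int64_t>(&e);
}

// Convert a shape to integers only if every dimension is constant.
std::optional<std::vector<int64_t>> ToConstShape(const std::vector<IndexExpr>& shape) {
  std::vector<int64_t> out;
  out.reserve(shape.size());
  for (const IndexExpr& dim : shape) {
    const int64_t* v = AsConstInt(dim);
    if (v == nullptr) return std::nullopt;  // symbolic dim: no static shape
    out.push_back(*v);
  }
  return out;
}
```

Returning std::nullopt instead of crashing leaves the caller free to take a dynamic-shape path.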
A combined example, taking the shape from an argument that may be a Var or a Constant:
if (op.same_as(affine_op)) {
  // For an affine op, first check its granularity: scalar or per-channel.
  Expr arg = call->args[1];
  Array<IndexExpr> arg_shape;
  if (arg->IsInstance<VarNode>()) {
    const VarNode* var = arg.as<VarNode>();
    const auto* ttnode = var->type_annotation.as<TensorTypeNode>();
    CHECK(ttnode != nullptr) << "The Var needs a tensor type annotation.";
    arg_shape = ttnode->shape;
  } else if (arg->IsInstance<ConstantNode>()) {
    const ConstantNode* cons = arg.as<ConstantNode>();
    arg_shape = cons->tensor_type()->shape;
  }
  ...
}
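The combined example covers Var and Constant arguments; for any other expression (such as a Call), the shape has to come from checked_type_. The sketch below puts all three sources behind one function, using hypothetical mock classes rather than the TVM types. The fallback order (annotation, then tensor_type(), then checked_type_) is a design choice of this sketch, not something TVM prescribes.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical mocks (NOT the TVM classes).
struct TypeNode {
  virtual ~TypeNode() = default;
};
struct TensorTypeNode : TypeNode {
  explicit TensorTypeNode(std::vector<int64_t> s) : shape(std::move(s)) {}
  std::vector<int64_t> shape;
};
using Type = std::shared_ptr<const TypeNode>;  // empty == undefined

struct ExprNode {
  virtual ~ExprNode() = default;
  mutable Type checked_type_;  // set by type inference
};
struct VarNode : ExprNode {
  Type type_annotation;  // optional
};
struct ConstantNode : ExprNode {
  std::vector<int64_t> data_shape;  // stands in for the NDArray's shape
  Type tensor_type() const { return std::make_shared<TensorTypeNode>(data_shape); }
};

// Pick the cheapest available source of type information for an argument.
std::optional<std::vector<int64_t>> ArgShape(const ExprNode& arg) {
  Type t;
  if (const auto* var = dynamic_cast<const VarNode*>(&arg);
      var != nullptr && var->type_annotation) {
    t = var->type_annotation;  // user-provided annotation
  } else if (const auto* cons = dynamic_cast<const ConstantNode*>(&arg)) {
    t = cons->tensor_type();   // derived from the constant's data
  } else {
    t = arg.checked_type_;     // Call etc.: needs prior type inference
  }
  const auto* tt = dynamic_cast<const TensorTypeNode*>(t.get());
  if (tt == nullptr) return std::nullopt;
  return tt->shape;
}
```

A Var without an annotation falls through to checked_type_, so it still works after type inference has run.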
ExprNode::type_as
ExprNode::type_as checks that the checked type of an ExprNode is the given type node, and returns it downcast to that type:
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  CHECK(node != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
      << ", but get " << checked_type_->GetTypeKey();
  return node;
}
So an Expr can be converted to a specific type node with type_as:
Expr data = res_call->args[0];  // input data
// Check that the input is a TensorTypeNode.
const TensorTypeNode* dat_node = data->type_as<TensorTypeNode>();
CHECK(dat_node);  // redundant: type_as already CHECKs that the result is non-null
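type_as never hands back a null pointer: it either returns the downcast node or aborts through CHECK. The sketch below reproduces that contract with hypothetical mock classes, using exceptions in place of TVM's CHECK (an assumption made so the failure paths can be tested); TypeAs is an illustrative free function, not the real member template.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical mocks (NOT the TVM classes).
struct TypeNode {
  virtual ~TypeNode() = default;
};
struct TensorTypeNode : TypeNode {
  explicit TensorTypeNode(std::vector<int64_t> s) : shape(std::move(s)) {}
  std::vector<int64_t> shape;
};
struct FuncTypeNode : TypeNode {};  // some other kind of type
using Type = std::shared_ptr<const TypeNode>;

struct ExprNode {
  mutable Type checked_type_;
};

// Same two checks as ExprNode::type_as, but throwing instead of aborting.
template <typename TTypeNode>
const TTypeNode* TypeAs(const ExprNode& e) {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TTypeNode must derive from TypeNode");
  if (!e.checked_type_) {
    throw std::logic_error("type inference has not completed");
  }
  const auto* node = dynamic_cast<const TTypeNode*>(e.checked_type_.get());
  if (node == nullptr) {
    throw std::logic_error("checked type has an unexpected kind");
  }
  return node;
}
```

Either failure path ends in an error rather than a null result, so callers never need to null-check the returned pointer.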
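As before, checked_type_ is sometimes undefined and the type must be inferred. Helpers that infer the type and then return the shape or the data type: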
// Infer the type
static inline Type Expr_InferType(const Expr& expr) {
  auto mod = relay::ModuleNode::FromExpr(expr);
  mod = relay::transform::InferType()(mod);
  auto type_fx = mod->Lookup("main");
  return type_fx->ret_type;
}

// Get the shape
static inline Array<Integer> GetShape(const Expr& expr) {
  auto tt = Expr_InferType(expr);
  CHECK(tt.defined());
  auto ttnode = tt.as<TensorTypeNode>();
  CHECK(ttnode != nullptr) << "Expected a tensor type.";
  return ConvertToConstants(ttnode->shape);
}

// Get the DataType
static inline DataType GetDataType(const Expr& expr) {
  auto tt = Expr_InferType(expr);
  CHECK(tt.defined());
  auto ttnode = tt.as<TensorTypeNode>();
  CHECK(ttnode != nullptr) << "Expected a tensor type.";
  return ttnode->dtype;
}
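GetShape and GetDataType each call Expr_InferType, so asking for both runs type inference on a fresh module twice. When both are needed, one option is to infer once and read both fields from the same TensorTypeNode. The sketch below shows this with hypothetical mocks (MockInferType counts its calls; ShapeAndDType and GetShapeAndDType are illustrative names, not TVM API), with the dtype modeled as a string.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mocks (NOT the TVM classes).
struct TypeNode {
  virtual ~TypeNode() = default;
};
struct TensorTypeNode : TypeNode {
  TensorTypeNode(std::vector<int64_t> s, std::string dt)
      : shape(std::move(s)), dtype(std::move(dt)) {}
  std::vector<int64_t> shape;
  std::string dtype;  // e.g. "float32"
};
using Type = std::shared_ptr<const TypeNode>;

// A toy "expression" that already knows the type inference would produce.
struct Expr {
  std::vector<int64_t> shape;
  std::string dtype;
};

int infer_calls = 0;  // counts inference runs

// Stand-in for Expr_InferType: in TVM this runs the InferType pass.
Type MockInferType(const Expr& e) {
  ++infer_calls;
  return std::make_shared<TensorTypeNode>(e.shape, e.dtype);
}

struct ShapeAndDType {
  std::vector<int64_t> shape;
  std::string dtype;
};

// Infer once, then read both fields from the same tensor type.
ShapeAndDType GetShapeAndDType(const Expr& expr) {
  Type t = MockInferType(expr);
  const auto* tt = dynamic_cast<const TensorTypeNode*>(t.get());
  if (tt == nullptr) throw std::runtime_error("expected a tensor type");
  return {tt->shape, tt->dtype};
}
```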
