核心思路：尽可能从 Node 得到对应的 TensorTypeNode，然后就可以读取其 shape。
CallNode
ExprNode 有一个成员变量 `Type checked_type_`，它保存了类型推断（类型检查）的结果：
// Excerpt of Relay's ExprNode; the remaining members are elided with `...`.
class ExprNode : public RelayNode {
public:
/*!
* \brief Stores the result of type inference (type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
* Declared `mutable`, so it can be assigned even through a const ExprNode*.
*/
mutable Type checked_type_ = Type(nullptr);
...
CallNode 继承自 ExprNode，因此同样拥有 checked_type_ 字段。如果 checked_type_ 能转换成 TensorTypeNode，就可以获取 shape，
因为 TensorTypeNode 有 shape 字段：
// Forward declaration of the reference type that wraps TensorTypeNode.
class TensorType;
/*! \brief TensorType container node (excerpt; remaining members elided with `...`) */
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by IndexExpr(tvm::Expr).
* NOTE(review): each dimension is an expression, so it is not guaranteed to be
* a constant integer; confirm before converting it to int64_t.
*/
Array<IndexExpr> shape;
/*! \brief The content data type */
DataType dtype;
...
通过TensorTypeNode获取shape。
例如:
if (new_call->op.same_as(dense_op)) {
  // The dense call handled here is expected to have exactly two inputs: data and weight.
  CHECK_EQ(new_call->args.size(), 2) << "The argument size should be equal to 2.";
  const auto* dense_param = new_call->attrs.as<DenseAttrs>();
  CHECK(dense_param != nullptr);
  auto arg_0 = new_call->args[0];  // input data (named "flatten" below; checked to be rank-2)
  auto arg_1 = new_call->args[1];  // weight params
  call_args.push_back(arg_0);
  // checked_type_ is only populated after type inference. An undefined Type, or a
  // non-tensor type, makes .as<TensorTypeNode>() return nullptr, so check before
  // dereferencing instead of crashing on a null pointer.
  const auto* t_flatten = arg_0->checked_type_.as<TensorTypeNode>();
  CHECK(t_flatten != nullptr) << "The data type is not inferred or is not a TensorType.";
  const auto* t_weight = arg_1->checked_type_.as<TensorTypeNode>();
  CHECK(t_weight != nullptr) << "The weight type is not inferred or is not a TensorType.";
  CHECK_EQ(t_flatten->shape.size(), 2) << "The size should be equal to 2.";
  CHECK_EQ(t_weight->shape.size(), 2) << "The size should be equal to 2.";
}
当然，有时 checked_type_ 尚未定义（类型推断之前它可能为空），此时需要先推断该类型。
// Infers the type of `expr`.
// The expression is wrapped in a fresh module (as its "main" function), the
// InferType pass is run on that module, and the return type recorded on "main"
// is returned.
static inline Type Expr_InferType(const Expr& expr) {
  const auto inferred_mod = relay::transform::InferType()(relay::ModuleNode::FromExpr(expr));
  return inferred_mod->Lookup("main")->ret_type;
}
if (!args[i]->checked_type_.defined()) {
args[i]->checked_type_ = Expr_InferType(args[i]);
}
auto ttype = args[i]->type_as<TensorTypeNode>();
VarNode
VarNode有type_annotation字段:
/*!
* \brief type annotation of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None, so check `.defined()` before use.
*/
Type type_annotation;
通过VarNode的type_annotation获取Type, 转换成TensorTypeNode, 然后获取shape。
例如:
// Visits a Var and reads its shape from the (optional) type annotation.
// NOTE(review): in this excerpt, the path where type_annotation is undefined
// falls off the end of a non-void function (undefined behavior); make sure
// every path returns an Expr.
Expr ExprVaccDataAlignment::VisitExpr_(const VarNode* op, std::string layout_name) {
  // type_annotation is optional and may be None, so only use it when present.
  if (op->type_annotation.defined()) {
    auto type = this->VisitType(op->type_annotation);
    // The annotation is not guaranteed to be a TensorType; .as<> then yields nullptr.
    const auto* ttnode = type.as<TensorTypeNode>();
    CHECK(ttnode != nullptr) << "Expected the Var's type annotation to be a TensorType.";
    const Array<IndexExpr>& shape = ttnode->shape;
...
  }
}
ConstantNode
ConstantNode 提供 tensor_type() 成员函数，可以直接得到常量数据对应的 TensorType：
// Excerpt of ConstantNode (remaining members elided in these notes).
class ConstantNode : public ExprNode {
public:
/*! \brief The data of the tensor */
runtime::NDArray data;
/*! \brief key/value pairs like weight:OIHW64i32o, bias:xx*/
/*! in most cases, there is only one key/value pair*/
std::unordered_map<std::string, std::string> kvs;
/*! \return The corresponding tensor type of the data */
TensorType tensor_type() const;
示例:
// Excerpt (truncated in these notes): collects the constant's shape, as
// int64_t values, into src_shape via op->tensor_type().
Expr ExprVaccDataAlignment::AffineConstantAlign(const ConstantNode* op) {
auto src_ret = op->data;
auto tensorType = op->tensor_type();
vector<int64_t> src_shape;
// NOTE(review): `channel` is uninitialized in the visible part; make sure it is assigned before being read.
int64_t channel;
for (size_t i = 0; i < tensorType->shape.size(); ++i) {
// NOTE(review): as_const_int presumably returns nullptr for a non-constant
// dimension, in which case this dereference crashes — confirm or CHECK first.
int64_t shape = *as_const_int(tensorType->shape[i]);
src_shape.push_back(shape);
}
综合例子:
if (op.same_as(affine_op)) {
  // For an affine op, first check its granularity, scalar or by-channel.
  Expr arg = call->args[1];
  // Stays empty when the argument is neither a Var nor a Constant.
  Array<IndexExpr> arg_shape;
  if (const auto* var = arg.as<VarNode>()) {
    // A Var's type_annotation is optional and may not be a TensorType,
    // so verify the cast instead of dereferencing a possible nullptr.
    const auto* anno = var->type_annotation.as<TensorTypeNode>();
    CHECK(anno != nullptr) << "The affine parameter Var has no TensorType annotation.";
    arg_shape = anno->shape;
  } else if (const auto* cons = arg.as<ConstantNode>()) {
    arg_shape = cons->tensor_type()->shape;
  }
ExprNode::type_as
ExprNode::type_as 先检查 checked_type_ 已经定义，再检查它是否为指定的类型节点；两项检查都通过时返回该类型节点的指针，否则直接 CHECK 失败：
// Returns checked_type_ viewed as TTypeNode.
// CHECK-fails if type inference has not run yet (checked_type_ undefined) or if
// the inferred type is not a TTypeNode; it never returns nullptr.
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
  const auto* typed = checked_type_.as<TTypeNode>();
  CHECK(typed != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
      << ", but get " << checked_type_->GetTypeKey();
  return typed;
}
所以，可以通过 type_as 把 Expr 的 checked_type_ 转换成指定的类型节点；如果类型尚未推断或类型不匹配，type_as 会直接 CHECK 失败，而不是返回 nullptr：
Expr data = res_call->args[0];  // input data
// Ensure the input's type is a TensorTypeNode. type_as<> already CHECK-fails when the
// type is undefined or not a TensorType, so dat_node is never null here and the
// CHECK below is only a defensive guard.
const TensorTypeNode* dat_node = data->type_as<TensorTypeNode>();
CHECK(dat_node);
如果 checked_type_ 尚未定义，需要先推断类型。下面是一组辅助函数：
// Infer the type of an expression.
// The expression is wrapped in a fresh module (as its "main" function), the
// InferType pass is run on that module, and the return type recorded on "main"
// is returned.
static inline Type Expr_InferType(const Expr& expr) {
  const auto inferred_mod = relay::transform::InferType()(relay::ModuleNode::FromExpr(expr));
  return inferred_mod->Lookup("main")->ret_type;
}
// Get the shape of an expression by running type inference on it.
// CHECK-fails if inference yields no type or a non-tensor type.
// NOTE(review): ConvertToConstants presumably requires every dimension to be a
// constant; confirm how it handles symbolic dimensions.
static inline Array<Integer> GetShape(const Expr& expr) {
  auto tt = Expr_InferType(expr);
  CHECK(tt.defined());
  // .as<> returns nullptr for non-tensor types; fail loudly instead of dereferencing it.
  const auto* ttnode = tt.as<TensorTypeNode>();
  CHECK(ttnode != nullptr) << "Expected a TensorType, but got " << tt->GetTypeKey();
  return ConvertToConstants(ttnode->shape);
}
// Get the element DataType of an expression by running type inference on it.
// CHECK-fails if inference yields no type or a non-tensor type.
static inline DataType GetDataType(const Expr& expr) {
  auto tt = Expr_InferType(expr);
  CHECK(tt.defined());
  // .as<> returns nullptr for non-tensor types; fail loudly instead of dereferencing it.
  const auto* ttnode = tt.as<TensorTypeNode>();
  CHECK(ttnode != nullptr) << "Expected a TensorType, but got " << tt->GetTypeKey();
  return ttnode->dtype;
}