InferTypeOpt-类型推断
/*!
 * \brief Run relay type inference on an expression.
 * \param expr The expression (or function) to type-check.
 * \return When expr is a Function, the type-checked "main" function of the
 *         module built from expr; otherwise the body of that "main" function.
 */
static inline Expr InferTypeOpt(const Expr& expr) {
  const bool input_is_function = expr.as<FunctionNode>() != nullptr;
  auto typed_mod = transform::InferType()(ModuleNode::FromExpr(expr));
  auto main_fn = typed_mod->Lookup("main");
  if (input_is_function) {
    return main_fn;
  }
  return main_fn->body;
}
ConstEvaluateOpt-部分评估,死代码消除,类型推断
/*!
 * \brief Run partial evaluation, dead-code elimination and type inference
 *        (in that order) on an expression.
 * \param expr The expression (or function) to optimize. Taken by const
 *        reference, like the sibling *Opt helpers, to avoid an extra
 *        reference-count copy.
 * \return When expr is a Function, the optimized "main" function of the
 *         module built from expr; otherwise the body of that function.
 */
static inline Expr ConstEvaluateOpt(const Expr& expr) {
  auto mod = ModuleNode::FromExpr(expr);
  // NOTE(review): the `true` flag to DeadCodeElimination presumably selects
  // its inline-once mode — confirm against the TVM version in use.
  Array<transform::Pass> passes = {transform::PartialEval(),
                                   transform::DeadCodeElimination(true),
                                   transform::InferType()};
  auto seq_pass = transform::Sequential(passes);
  mod = seq_pass(mod);
  auto entry_func = mod->Lookup("main");
  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
}
FoldConstantOpt-常量折叠
/*!
 * \brief Apply relay constant folding to an expression.
 * \param expr The expression (or function) to fold.
 * \return When expr is a Function, the folded "main" function of the module
 *         built from expr; otherwise the body of that function.
 */
static inline Expr FoldConstantOpt(const Expr& expr) {
  const bool input_is_function = expr.as<FunctionNode>() != nullptr;
  auto folded_mod = transform::FoldConstant()(ModuleNode::FromExpr(expr));
  auto main_fn = folded_mod->Lookup("main");
  if (!input_is_function) {
    return main_fn->body;
  }
  return main_fn;
}
IsPowerOf2 - 判断一个整数是不是2的幂
// Check whether num is an integral power of two (1, 2, 4, 8, ...).
// Zero and negative values are never powers of two.
static inline bool IsPowerOf2(int num) {
  // A positive power of two has exactly one bit set, so clearing its lowest
  // set bit (num & (num - 1)) leaves zero. The num > 0 guard runs first, so
  // num - 1 cannot overflow.
  return num > 0 && (num & (num - 1)) == 0;
}
IsElemwiseShape / IsElemwise - 判断两个表达式的张量形状(以及数据类型)是否完全相同
static inline bool IsElemwiseShape(const Type& lhs, const Type& rhs) { const auto* lhs_node = lhs.as<TensorTypeNode>(); const auto* rhs_node = rhs.as<TensorTypeNode>(); AttrsEqual attr_equal_; if (lhs_node && rhs_node && attr_equal_(lhs_node->shape, rhs_node->shape)) { return true; } return false;}static inline bool IsElemwiseShape(const Expr& lhs, const Expr& rhs) { const auto lhs_type = lhs->checked_type(); const auto rhs_type = rhs->checked_type(); return IsElemwiseShape(lhs_type, rhs_type);}static inline bool IsElemwise(const Expr& lhs, const Expr& rhs) { const auto* lhs_node = lhs->checked_type().as<TensorTypeNode>(); const auto* rhs_node = rhs->checked_type().as<TensorTypeNode>(); AttrsEqual attr_equal_; if (lhs_node && rhs_node && attr_equal_(lhs_node->shape, rhs_node->shape) && lhs_node->dtype == rhs_node->dtype) { return true; } return false;}
ConvertToConstants - 将 IndexExpr 数组转换为常量整数(Integer)数组
/*!
 * \brief Convert an array of index expressions into an array of constant integers.
 * \param arr Index expressions; every element must be an IntImm.
 * \return The integer values of arr, in the same order.
 * \note Fails a CHECK if any element is not an IntImm.
 */
static Array<Integer> ConvertToConstants(const Array<IndexExpr>& arr) {
  Array<Integer> convert_result;
  for (const auto& elem : arr) {
    const IntImm* const_elem = elem.as<IntImm>();
    CHECK(const_elem) << "ConvertToConstants: expected every element to be a constant IntImm";
    convert_result.push_back(const_elem->value);
  }
  // Plain return: NRVO / implicit move applies. `return std::move(...)` would
  // inhibit NRVO (a pessimizing move).
  return convert_result;
}
/*! * \brief Get x as constant int expression. * \param x The expression * \return the address to the int expression, * return nullptr, if x is not IntImm. */inline const int64_t* as_const_int(const Expr& x) { if (!x.defined()) return nullptr; if (const ir::IntImm* op = x.as<ir::IntImm>()) { return &(op->value); } else { return nullptr; }}static inline int64_t get_const_int(const tvm::Expr& x) { auto* value_ptr = as_const_int(x); CHECK(value_ptr) << "Expr is not a constant int"; return value_ptr[0];}
推断类型,获取shape,获取DataType
//推断类型static inline Type Expr_InferType(const Expr& expr) { auto mod = relay::ModuleNode::FromExpr(expr); mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); return type_fx->ret_type;}//获取shapestatic inline Array<Integer> GetShape(const Expr& expr) { auto tt = Expr_InferType(expr); CHECK(tt.defined()); auto ttnode = tt.as<TensorTypeNode>(); return ConvertToConstants(ttnode->shape);}//获取DataTypestatic inline DataType GetDataType(const Expr& expr) { auto tt = Expr_InferType(expr); CHECK(tt.defined()); auto ttnode = tt.as<TensorTypeNode>(); return ttnode->dtype;}