// InferTypeOpt - type inference
/*!
 * \brief Run the Relay InferType pass on an expression.
 * \param expr The expression to type-check; it is wrapped into a module with
 *        ModuleNode::FromExpr, and the module's "main" function is read back.
 * \return The type-inferred "main" function when expr is itself a Function,
 *         otherwise the body of that type-inferred "main" function.
 */
static inline Expr InferTypeOpt(const Expr& expr) {
  auto typed_mod = ModuleNode::FromExpr(expr);
  typed_mod = transform::InferType()(typed_mod);
  auto main_fn = typed_mod->Lookup("main");
  const bool input_is_function = expr.as<FunctionNode>() != nullptr;
  if (input_is_function) {
    return main_fn;
  }
  return main_fn->body;
}
// ConstEvaluateOpt - partial evaluation, dead code elimination, type inference
/*!
 * \brief Run PartialEval, DeadCodeElimination and InferType (in that order)
 *        on an expression.
 * \param expr The expression to optimize; it is wrapped into a module with
 *        ModuleNode::FromExpr, and the module's "main" function is read back.
 *        Taken by const reference for consistency with the sibling helpers
 *        (callers passing an Expr by value or a temporary are unaffected).
 * \return The optimized "main" function when expr is itself a Function,
 *         otherwise the body of the optimized "main" function.
 */
static inline Expr ConstEvaluateOpt(const Expr& expr) {
  auto mod = ModuleNode::FromExpr(expr);
  // NOTE(review): the `true` argument to DeadCodeElimination is presumably
  // its `inline_once` flag — confirm against the TVM version in use.
  Array<transform::Pass> passes = {transform::PartialEval(),
                                   transform::DeadCodeElimination(true),
                                   transform::InferType()};
  auto seq_pass = transform::Sequential(passes);
  mod = seq_pass(mod);
  auto entry_func = mod->Lookup("main");
  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
}
// FoldConstantOpt - constant folding
/*!
 * \brief Run the Relay FoldConstant pass on an expression.
 * \param expr The expression to fold; it is wrapped into a module with
 *        ModuleNode::FromExpr, and the module's "main" function is read back.
 * \return The folded "main" function when expr is itself a Function,
 *         otherwise the body of the folded "main" function.
 */
static inline Expr FoldConstantOpt(const Expr& expr) {
  auto folded_mod = ModuleNode::FromExpr(expr);
  folded_mod = transform::FoldConstant()(folded_mod);
  auto main_fn = folded_mod->Lookup("main");
  if (expr.as<FunctionNode>() != nullptr) {
    return main_fn;
  }
  return main_fn->body;
}
// IsPowerOf2 - check whether an integer is a power of 2
// check if num is the nth power of 2
static inline bool IsPowerOf2(int num) { return (num > 0 && 0 == (num & (num - 1))); }
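// IsElemwiseShape / IsElemwise - check whether two tensor types have the same shape (and, for IsElemwise, the same dtype)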
static inline bool IsElemwiseShape(const Type& lhs, const Type& rhs) {
const auto* lhs_node = lhs.as<TensorTypeNode>();
const auto* rhs_node = rhs.as<TensorTypeNode>();
AttrsEqual attr_equal_;
if (lhs_node && rhs_node &&
attr_equal_(lhs_node->shape, rhs_node->shape)) {
return true;
}
return false;
}
/*!
 * \brief Expression overload: compare the checked types of lhs and rhs with
 *        the Type overload of IsElemwiseShape.
 * \note Reads checked_type() of both expressions, so they are expected to
 *       have been type-inferred already.
 */
static inline bool IsElemwiseShape(const Expr& lhs, const Expr& rhs) {
  return IsElemwiseShape(lhs->checked_type(), rhs->checked_type());
}
static inline bool IsElemwise(const Expr& lhs, const Expr& rhs) {
const auto* lhs_node = lhs->checked_type().as<TensorTypeNode>();
const auto* rhs_node = rhs->checked_type().as<TensorTypeNode>();
AttrsEqual attr_equal_;
if (lhs_node && rhs_node &&
attr_equal_(lhs_node->shape, rhs_node->shape) &&
lhs_node->dtype == rhs_node->dtype) {
return true;
}
return false;
}
// ConvertToConstants - convert an array of constant index expressions to integers
/*!
 * \brief Convert an array of index expressions into an array of integers.
 * \param arr Index expressions (e.g. a tensor shape); every element must be
 *        an IntImm.
 * \return The integer values, in the same order as arr.
 * \note Fails a CHECK, reporting the offending position, if any element is
 *       not a constant integer (e.g. a symbolic dimension).
 */
static inline Array<Integer> ConvertToConstants(const Array<IndexExpr>& arr) {
  Array<Integer> convert_result;
  for (size_t i = 0; i < arr.size(); ++i) {
    const IntImm* const_elem = arr[i].as<IntImm>();
    CHECK(const_elem) << "ConvertToConstants: element " << i
                      << " is not a constant integer";
    convert_result.push_back(const_elem->value);
  }
  // Return the local by value: NRVO / implicit move applies, whereas an
  // explicit std::move would only inhibit copy elision.
  return convert_result;
}
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
 * \brief Extract the value of a constant integer expression.
 * \param x The expression; it must be a defined IntImm.
 * \return The integer value stored in x.
 * \note Fails a CHECK with "Expr is not a constant int" otherwise.
 */
static inline int64_t get_const_int(const tvm::Expr& x) {
  const int64_t* value = as_const_int(x);
  CHECK(value) << "Expr is not a constant int";
  return *value;
}
// Type helpers: infer the type of an expression, get its shape, get its DataType
//推断类型
static inline Type Expr_InferType(const Expr& expr) {
auto mod = relay::ModuleNode::FromExpr(expr);
mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main");
return type_fx->ret_type;
}
// Get shape
/*!
 * \brief Infer the type of expr and return its tensor shape as integers.
 * \param expr The expression whose shape is requested.
 * \return The shape dimensions, converted with ConvertToConstants.
 * \note Fails a CHECK (instead of dereferencing a null pointer) when the
 *       inferred type is undefined or is not a TensorType, e.g. a tuple.
 *       ConvertToConstants additionally requires every dimension to be a
 *       constant integer.
 */
static inline Array<Integer> GetShape(const Expr& expr) {
  auto tt = Expr_InferType(expr);
  CHECK(tt.defined()) << "GetShape: type inference produced no type";
  const auto* ttnode = tt.as<TensorTypeNode>();
  CHECK(ttnode != nullptr) << "GetShape: expected an expression of TensorType";
  return ConvertToConstants(ttnode->shape);
}
// Get DataType
/*!
 * \brief Infer the type of expr and return its tensor element dtype.
 * \param expr The expression whose dtype is requested.
 * \return The dtype of the inferred TensorType.
 * \note Fails a CHECK (instead of dereferencing a null pointer) when the
 *       inferred type is undefined or is not a TensorType, e.g. a tuple.
 */
static inline DataType GetDataType(const Expr& expr) {
  auto tt = Expr_InferType(expr);
  CHECK(tt.defined()) << "GetDataType: type inference produced no type";
  const auto* ttnode = tt.as<TensorTypeNode>();
  CHECK(ttnode != nullptr) << "GetDataType: expected an expression of TensorType";
  return ttnode->dtype;
}