InferTypeOpt-类型推断

  1. static inline Expr InferTypeOpt(const Expr& expr) {
  2. auto mod = ModuleNode::FromExpr(expr);
  3. mod = transform::InferType()(mod);
  4. auto entry_func = mod->Lookup("main");
  5. return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
  6. }

ConstEvaluateOpt-部分评估,死代码消除,类型推断

  1. static inline Expr ConstEvaluateOpt(Expr expr) {
  2. auto mod = ModuleNode::FromExpr(expr);
  3. Array<transform::Pass> passes = {transform::PartialEval(),
  4. transform::DeadCodeElimination(true),
  5. transform::InferType()};
  6. auto seq_pass = transform::Sequential(passes);
  7. mod = seq_pass(mod);
  8. auto entry_func = mod->Lookup("main");
  9. return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
  10. }

FoldConstantOpt-常量折叠

  1. static inline Expr FoldConstantOpt(const Expr& expr) {
  2. auto mod = ModuleNode::FromExpr(expr);
  3. mod = transform::FoldConstant()(mod);
  4. auto entry_func = mod->Lookup("main");
  5. return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
  6. }

IsPowerOf2 - 判断一个整数是不是2的幂

  /*!
   * \brief Check whether num is a positive integral power of two (1, 2, 4, ...).
   * \param num The value to test.
   * \return true iff num > 0 and exactly one bit of num is set.
   */
  static inline bool IsPowerOf2(int num) {
    if (num <= 0) {
      return false;
    }
    // Clearing the lowest set bit leaves zero only when a single bit was set.
    return (num & (num - 1)) == 0;
  }

IsElemwiseShape / IsElemwise - 判断两个张量表达式的形状（IsElemwise 还要求数据类型）是否相同

  1. static inline bool IsElemwiseShape(const Type& lhs, const Type& rhs) {
  2. const auto* lhs_node = lhs.as<TensorTypeNode>();
  3. const auto* rhs_node = rhs.as<TensorTypeNode>();
  4. AttrsEqual attr_equal_;
  5. if (lhs_node && rhs_node &&
  6. attr_equal_(lhs_node->shape, rhs_node->shape)) {
  7. return true;
  8. }
  9. return false;
  10. }
  11. static inline bool IsElemwiseShape(const Expr& lhs, const Expr& rhs) {
  12. const auto lhs_type = lhs->checked_type();
  13. const auto rhs_type = rhs->checked_type();
  14. return IsElemwiseShape(lhs_type, rhs_type);
  15. }
  16. static inline bool IsElemwise(const Expr& lhs, const Expr& rhs) {
  17. const auto* lhs_node = lhs->checked_type().as<TensorTypeNode>();
  18. const auto* rhs_node = rhs->checked_type().as<TensorTypeNode>();
  19. AttrsEqual attr_equal_;
  20. if (lhs_node && rhs_node &&
  21. attr_equal_(lhs_node->shape, rhs_node->shape) &&
  22. lhs_node->dtype == rhs_node->dtype) {
  23. return true;
  24. }
  25. return false;
  26. }

ConvertToConstants / as_const_int / get_const_int - 将索引表达式转换为常量整数

  1. static Array<Integer> ConvertToConstants(const Array<IndexExpr>& arr) {
  2. Array<Integer> convert_result;
  3. for (size_t i = 0; i < arr.size(); ++i) {
  4. const IntImm* const_elem = arr[i].as<IntImm>();
  5. CHECK(const_elem);
  6. convert_result.push_back(const_elem->value);
  7. }
  8. return std::move(convert_result);
  9. }
  1. /*!
  2. * \brief Get x as constant int expression.
  3. * \param x The expression
  4. * \return the address to the int expression,
  5. * return nullptr, if x is not IntImm.
  6. */
  7. inline const int64_t* as_const_int(const Expr& x) {
  8. if (!x.defined()) return nullptr;
  9. if (const ir::IntImm* op = x.as<ir::IntImm>()) {
  10. return &(op->value);
  11. } else {
  12. return nullptr;
  13. }
  14. }
  15. static inline int64_t get_const_int(const tvm::Expr& x) {
  16. auto* value_ptr = as_const_int(x);
  17. CHECK(value_ptr) << "Expr is not a constant int";
  18. return value_ptr[0];
  19. }

Expr_InferType / GetShape / GetDataType - 推断类型、获取shape、获取DataType

  1. //推断类型
  2. static inline Type Expr_InferType(const Expr& expr) {
  3. auto mod = relay::ModuleNode::FromExpr(expr);
  4. mod = relay::transform::InferType()(mod);
  5. auto type_fx = mod->Lookup("main");
  6. return type_fx->ret_type;
  7. }
  8. //获取shape
  9. static inline Array<Integer> GetShape(const Expr& expr) {
  10. auto tt = Expr_InferType(expr);
  11. CHECK(tt.defined());
  12. auto ttnode = tt.as<TensorTypeNode>();
  13. return ConvertToConstants(ttnode->shape);
  14. }
  15. //获取DataType
  16. static inline DataType GetDataType(const Expr& expr) {
  17. auto tt = Expr_InferType(expr);
  18. CHECK(tt.defined());
  19. auto ttnode = tt.as<TensorTypeNode>();
  20. return ttnode->dtype;
  21. }