TVM - ForwardRewriter

Created: 2020-07-08 11:51
Updated: 2020-07-08 11:57
Tags: TVM, 李国明, 熊选文

ForwardRewriter

ForwardRewriter is a subclass of ExprMutator, so implementing a Pass directly on top of ForwardRewriter is essentially no different from writing our own ExprMutator subclass. What ForwardRewriter adds is a framework that makes Passes easier to implement and the code more concise. It is mainly suited to Passes that must distinguish between different Ops and handle each one differently.

"Forward" here means the direction of the graph's computation, which is the opposite of the traversal direction. The traversal treats the graph as a tree: the last node is the root and the first nodes are the leaves, and the tree is walked from the root toward the leaves by depth-first search (DFS), more precisely post-order DFS.

![image.png](https://cdn.nlark.com/yuque/0/2020/png/644279/1594482594588-55532841-2508-4d3d-ac02-a4a3ff7a02f1.png)

The framework already implements the graph traversal and automatically dispatches to whichever ForwardRewrite function is registered for each Op, so we can concentrate on the per-Op rewrites. It also provides some helpers for developers: TempExpr, the fcontext context, and fmulti_ref_trigger.
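To make the direction concrete, here is a minimal sketch built on the `ForwardRewrite` entry function listed at the end of this note (the attr name "FTestRewrite" comes from the simple example below):

```cpp
// For a graph  x -> conv2d -> relu  (relu is the root of the traversal tree),
// VisitExpr starts at relu, but each call's arguments are mutated before the
// call itself is rewritten, so the registered rewrite callbacks fire in
// computation ("forward") order: conv2d first, then relu.
Expr RunTestRewrite(const Expr& expr) {
  return ForwardRewrite(expr, "FTestRewrite",
                        /*fcontext=*/nullptr,
                        /*fmulti_ref_trigger=*/nullptr);
}
```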

1 Registering the ForwardRewrite Pass

[Figure 1: registering the ForwardRewrite Pass]
At line 161, ForwardRewrite is the entry function of ForwardRewriter, shown below:
[Figure 2: the ForwardRewrite entry function]
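Since the screenshot itself is not preserved, here is the corresponding entry function, reproduced from the full listing of forward_rewrite.cc at the end of this note:

```cpp
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // Collect the per-op rewrite functions registered under rewrite_map_name,
  // then run the rewriter over the whole expression.
  auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}
```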

2 Registering an Op's ForwardRewrite Function

[Figure 3: registering a ForwardRewrite function for an Op]
At line 99, the first argument of set_attr ("FTestRewrite" here) must match the rewrite_map_attr_name argument passed to ForwardRewrite when the Pass is registered.
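This is the registration as it appears in the simple example below:

```cpp
// The attr name "FTestRewrite" ties this per-op rewrite function to the pass
// that later calls ForwardRewrite(f, "FTestRewrite", ...).
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FTestRewrite", Conv2dTestRewrite);
```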
Conv2dTestRewrite is a function of type FForwardRewrite, whose definition is shown below:
[Figure 4: the definition of FForwardRewrite]
The parameter ref_call is the original call about to be rewritten.
The parameter new_args may contain TempExpr instances (or custom subclasses of it), which can carry user-defined information.
The parameter ctx is the context information of ref_call.
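For reference, FForwardRewrite is declared in tvm/relay/op_attr_types.h; a sketch of the declaration from TVM of this era (reconstructed, since the screenshot is missing):

```cpp
// Forward rewriting rule for a specific op. Returning a null Expr means
// "no rewrite"; the framework then falls back to the default behavior.
using FForwardRewrite = runtime::TypedPackedFunc<
    Expr(const Call& ref_call,
         const Array<Expr>& new_args,
         const NodeRef& ctx)>;
```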
Finally, TestRealizeExprNode::make(ref_call) returns a TestRealizeExpr instance.
TestRealizeExpr is a subclass of TempExpr, and TestRealizeExprNode is a subclass of TempExprNode.
As shown below, TestRealizeExprNode::make() is quite simple, taking a single parameter: Expr expr. If TestRealizeExprNode carried more member variables (besides expr it may hold additional members with other information), TestRealizeExprNode::make() could take correspondingly more parameters (see the definition of TestRealizeExprNode for details).
[Figure 5: TestRealizeExprNode::make()]
After conv2d is wrapped via make into a TestRealizeExpr, it carries two pieces of information, expr and bTest; when it becomes relu's new_args[0], both can be retrieved there (lines 111 and 112 in the figure; see the snippet below).
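This is the corresponding access in ReluTestRewrite from the simple example (the local variables here are illustrative additions):

```cpp
// new_args[0] is the still-unrealized TempExpr produced by Conv2dTestRewrite.
if (const auto* n = new_args[0].as<TestRealizeExprNode>()) {
  Expr prev_expr = n->expr;   // the original conv2d call
  bool prev_flag = n->bTest;  // the flag set in TestRealizeExprNode::make()
}
```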

The definition of TestRealizeExpr(Node)

The figure below shows the definition of TestRealizeExpr(Node) (the full code appears in the simple example later in this note):
**A TempExprNode subclass can be used to pass information forward, but only to the next layer.**
[Figure 6: the definition of TestRealizeExpr(Node)]
At line 49, Realize is a virtual function of TempExprNode; TestRealizeExprNode, being a subclass of TempExprNode, must override it.
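For context, the base class in tvm/relay/expr.h of this era looks roughly as follows (a sketch, not an exact copy):

```cpp
// Base class of temporary expressions: a placeholder that rewrite functions
// can thread through the traversal and that must eventually be converted
// back into a real Expr.
class TempExprNode : public ExprNode {
 public:
  // Convert the temporary expression back to a normal (non-temporary) Expr.
  virtual Expr Realize() const = 0;

  static constexpr const char* _type_key = "relay.TempExpr";
  TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};
```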

How Realize gets called

Let us walk through how Realize gets called, using the following graph as an example:
[Figure 7: the example graph (relu feeding conv2d)]
Suppose only relu has a registered ForwardRewrite function, and it returns a TestRealizeExpr created via make, as shown below:
[Figure 8: only relu's ForwardRewrite function is registered]
Following ExprMutator's post-order recursive traversal of the graph, when conv2d is visited the framework looks up whether conv2d has a registered ForwardRewrite function. If not, frewrite at line 166 in the figure below is nullptr (the current op here is conv2d, which is unregistered, so frewrite is nullptr). The for loop then iterates over conv2d's args; there is exactly one, and args[0] is relu. So arg in the figure below is relu, new_arg at line 165 is the rewritten relu (the TestRealizeExpr), and because frewrite is nullptr, new_arg becomes the result of realizer_.Realize(new_arg).
[Figure 9: ForwardRewriter::VisitExpr_ for CallNode]
Continuing into realizer_.Realize, the code checks whether the expression derives from TempExprNode and, if so, calls its Realize:
[Figure 10: TempRealizer::VisitExpr]
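The relevant check, reproduced from TempRealizer in the full listing below:

```cpp
Expr VisitExpr(const Expr& expr) final {
  auto it = memo_.find(expr);
  if (it != memo_.end()) {
    return it->second;                     // already realized: reuse the memo
  } else {
    Expr res;
    if (const auto* temp = expr.as<TempExprNode>()) {
      res = temp->Realize();               // TempExpr: call the subclass's Realize()
    } else {
      res = ExprFunctor::VisitExpr(expr);  // ordinary Expr: recurse as usual
    }
    memo_[res] = res;
    return res;
  }
}
```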
So, for a TempExprNode subclass's Realize function to be called, two conditions must hold:

1. an Op without a registered ForwardRewrite function is encountered, and
2. the preceding Op's ForwardRewrite function returned (via make) a TempExpr subclass.

When both hold, the Realize function of the TempExprNode subclass produced by the preceding op is invoked. The last op is an exception: as long as it has a registered ForwardRewrite function that returns a TempExpr subclass, that subclass's Realize is always called, because the top-level VisitExpr always realizes its result.

For example, take a chain of ops x1 → x2 → x3 → x4 → x5, executed in that order (x5 is the last op):

| No. | Op/Node | ForwardRewrite function registered? | Realize() called? |
| --- | ------- | ----------------------------------- | ----------------- |
| 1 | x1 | Yes | Yes |
| 2 | x2 | No | No |
| 3 | x3 | Yes | No |
| 4 | x4 | Yes | No |
| 5 | x5 | Yes | Yes |

x1's TempExpr is realized because x2 has no registered rewrite; x3's and x4's TempExprs are consumed by the registered rewrites of x4 and x5; x5's TempExpr is realized by the top-level VisitExpr because it is the last op.

| No. | Op/Node | ForwardRewrite function registered? | Realize() called? |
| --- | ------- | ----------------------------------- | ----------------- |
| 1 | x1 | Yes | Yes |
| 2 | x2 | No | No |
| 3 | x3 | Yes | Yes |
| 4 | x4 | No | No |
| 5 | x5 | Yes | Yes |

Here x1's and x3's TempExprs are realized because their consumers x2 and x4 have no registered rewrite; x5's is realized because it is the last op.


The ctx parameter

The ctx parameter can carry context information, supplied through the fcontext argument of the ForwardRewrite function. The figure below shows another way of registering the Pass; at line 345 an fcontext is provided.

[Figure 11: registering the Pass with an fcontext]
As shown in the figure below, relu's ForwardRewrite function then uses this context information:
[Figure 12: relu's ForwardRewrite function using ctx]
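The screenshots are not preserved, so here is a minimal hypothetical sketch of the same idea (the lambda body and MyReluRewrite are illustrative, not the original code):

```cpp
// Supply per-call context to the rewrite functions.
auto fcontext = [](const Call& call) -> NodeRef {
  // Return any NodeRef describing this call; each rewrite function receives
  // it as its `ctx` argument. NodeRef(nullptr) means "no context".
  return NodeRef(nullptr);
};
Expr out = ForwardRewrite(expr, "FTestRewrite", fcontext, nullptr);

// Inside a rewrite function, ctx is whatever fcontext returned for ref_call:
Expr MyReluRewrite(const Call& ref_call, const Array<Expr>& new_args,
                   const NodeRef& ctx) {
  if (ctx.defined()) {
    // ... use the context ...
  }
  return Expr(nullptr);  // null: keep the original call
}
```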

3 Introducing fmulti_ref_trigger

A separate screenshot is shown below (see the section "1 Registering the ForwardRewrite Pass" for the full context):
[Figure 13: the fmulti_ref_trigger parameter of ForwardRewrite]
If an Expr consumed by multiple callers needs special handling, an fmulti_ref_trigger function can be supplied when registering the ForwardRewrite Pass. In the code below, fmulti_ref_trigger is invoked whenever ref_count > 1 (line 111); a hypothetical usage sketch follows the figure.
[Figure 14: ForwardRewriter::GetTempExpr]
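The check lives in ForwardRewriter::GetTempExpr (see the full listing below). A minimal hypothetical sketch of supplying the trigger:

```cpp
// When an expression has more than one consumer, force it back to a concrete
// Expr so that every consumer sees the same realized node rather than its own
// unrealized TempExpr.
auto fmulti_ref_trigger = [](const Expr& e) -> Expr {
  if (const auto* temp = e.as<TempExprNode>()) {
    return temp->Realize();
  }
  return e;
};
Expr out = ForwardRewrite(expr, "FTestRewrite", nullptr, fmulti_ref_trigger);
```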

The definition of ForwardRewriter:

```cpp
/*!
 * \file forward_rewrite.cc
 * \brief Apply rewriting rules in a forward fashion.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include "pass_util.h"

namespace tvm {
namespace relay {

// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repetitively won't hurt perf.
class TempRealizer : private ExprMutator {
 public:
  Expr Realize(Expr expr) {
    return VisitExpr(expr);
  }

 private:
  Expr VisitExpr(const Expr& expr) final {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    } else {
      Expr res;
      if (const auto* temp = expr.as<TempExprNode>()) {
        res = temp->Realize();
      } else {
        res = ExprFunctor::VisitExpr(expr);
      }
      memo_[res] = res;
      return res;
    }
  }
};

class ForwardRewriter : private ExprMutator {
 public:
  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_map_(rewrite_map),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  ForwardRewriter(const FForwardRewrite* rewrite_func,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  // Transform expression.
  Expr Rewrite(Expr expr) {
    if (fmulti_ref_trigger_ != nullptr) {
      ref_counter_ = GetExprRefCount(expr);
    }
    return this->VisitExpr(expr);
  }

 private:
  // The rewrite rule.
  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
  const FForwardRewrite* rewrite_func_{nullptr};
  // The context.
  std::function<NodeRef(const Call&)> fcontext_{nullptr};
  // The multiple reference trigger
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Internal ref counter
  std::unordered_map<const Node*, size_t> ref_counter_;
  // internal realizer
  TempRealizer realizer_;

  Expr VisitExpr(const Expr& expr) final {
    // by default always realize.
    return realizer_.Realize(ExprMutator::VisitExpr(expr));
  }

  // Visit and allow non-realized version.
  Expr GetTempExpr(const Expr& expr) {
    if (fmulti_ref_trigger_ != nullptr) {
      Expr ret = ExprMutator::VisitExpr(expr);
      auto it = ref_counter_.find(expr.get());
      CHECK(it != ref_counter_.end());
      if (it->second > 1) {
        ret = fmulti_ref_trigger_(ret);
      }
      return ret;
    } else {
      return ExprMutator::VisitExpr(expr);
    }
  }

  // Automatic fold TupleGetItem.
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr tuple = this->GetTempExpr(op->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    } else {
      if (tuple.same_as(op->tuple)) {
        return GetRef<Expr>(op);
      } else {
        return TupleGetItemNode::make(tuple, op->index);
      }
    }
  }

  Expr VisitExpr_(const TupleNode* op) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    for (auto field : op->fields) {
      auto new_field = this->GetTempExpr(field);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(field);
    }
    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    } else {
      return TupleNode::make(fields);
    }
  }

  Expr VisitExpr_(const CallNode* call_node) final {
    const Call& ref_call = GetRef<Call>(call_node);
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
    auto new_op = this->Mutate(call_node->op);
    bool unchanged = call_node->op.same_as(new_op);

    Array<Expr> call_args;
    for (auto arg : call_node->args) {
      Expr new_arg = this->GetTempExpr(arg);
      if (frewrite == nullptr) {
        // frewrite is nullptr only when the current op has no registered
        // FForwardRewrite function; only then is realizer_.Realize(new_arg)
        // called here.
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(arg);
      call_args.push_back(new_arg);
    }
    // try to rewrite.
    if (frewrite != nullptr) {
      Expr res = frewrite(
          ref_call, call_args,
          fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
      if (res.defined()) return res;
      // abort, use old rule
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return CallNode::make(
        new_op, call_args, call_node->attrs, call_node->type_args);
  }
};

Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}

Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
}

}  // namespace relay
}  // namespace tvm
```

A simple example

```cpp
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/transform.h>
#include "../../qnn/util.h"
#include "../pattern_util.h"
#include "util.h"

namespace tvm {
namespace relay {
namespace testrewrite {

class TestRealizeExpr;

class TestRealizeExprNode : public TempExprNode {
 public:
  Expr expr;
  bool bTest;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
    v->Visit("bTest", &bTest);
  }

  Expr Realize() const final;

  TVM_DLL static TestRealizeExpr make(Expr expr);

  static constexpr const char* _type_key = "relay.transform.TestRealizeExpr";
  TVM_DECLARE_NODE_TYPE_INFO(TestRealizeExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(TestRealizeExpr, TestRealizeExprNode, TempExpr);

Expr TestRealizeExprNode::Realize() const {
  Expr expr = this->expr;
  const CallNode* call_node = expr.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << "Realize:" << call_op->name;
  return expr;
}

TestRealizeExpr TestRealizeExprNode::make(Expr expr) {
  NodePtr<TestRealizeExprNode> n = make_node<TestRealizeExprNode>();
  n->expr = std::move(expr);
  n->bTest = true;
  return TestRealizeExpr(n);
}

inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
  return CallNode::make(ref_call->op, args, ref_call->attrs, ref_call->type_args);
}

/* \brief forward the original operator */
Expr IdentityTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  // Returning a null Expr keeps the original call.
  return Expr(nullptr);
}

Expr Conv2dTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FTestRewrite", Conv2dTestRewrite);

Expr ReluTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  if (const auto* n = new_args[0].as<TestRealizeExprNode>()) {
    // The information carried forward by the previous op's TempExpr
    // is available here via n->expr and n->bTest.
    LOG(INFO) << "carried bTest = " << n->bTest;
  }
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FTestRewrite", ReluTestRewrite);

Expr MultiplyTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("multiply").set_attr<FForwardRewrite>("FTestRewrite", MultiplyTestRewrite);

Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  CHECK_EQ(ref_call->type_args.size(), 2);
  if (IsElewiseShape(ref_call->type_args[0], ref_call->type_args[1])) {
    LOG(INFO) << "same shape";
  }
  const auto* n = new_args[0].as<TestRealizeExprNode>();
  if (!n) {
    // Expr ret = ForwardOp(ref_call, {n->data});
    return TestRealizeExprNode::make(ref_call);
  }
  CHECK(!new_args[0]->is_type<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FTestRewrite", AddRealize);

}  // namespace testrewrite

namespace transform {

Pass TestRewritePass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(ForwardRewrite(f, "FTestRewrite", nullptr, nullptr));
      };
  return CreateFunctionPass(pass_func, 1, "TestRewrite", {});
}

TVM_REGISTER_API("relay._transform.TestRewrite").set_body_typed(TestRewritePass);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
```


A more complex example

```cpp
/*!
 * \file convert_strides.cc
 * \brief Convert existing 2x2-strides conv2d to
 *  1x1 strides conv2d + vacc_dropout.
 */
/*
 * The pass does three things:
 * 1. Add an operator vacc_dropout to implement dropout with strides=2
 * 2. 1x1 kernel size and 2x2 strides conv2d is converted to 1x1 kernel size and 1x1 strides conv2d
 * 3. Move vacc_dropout (which is a PEP op) to the end of SEP unless an elementwise add or multiply is encountered.
 */
#include <tvm/relay/transform.h>
#include <tvm/relay/attrs/nn.h>
#include "../util.h"

namespace tvm {
namespace relay {

inline Array<Integer> ConvertToConstants(const Array<IndexExpr>& arr) {
  Array<Integer> convert_result;
  for (size_t i = 0; i < arr.size(); ++i) {
    const IntImm* const_elem = arr[i].as<IntImm>();
    CHECK(const_elem);
    convert_result.push_back(const_elem->value);
  }
  return std::move(convert_result);
}

bool VaccDropoutRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  auto dshape = data->shape;
  auto num_axis = dshape.size();
  std::vector<int64_t> stride_vec;
  for (size_t i = 0; i < num_axis - 2; ++i) {
    stride_vec.push_back(1);
  }
  stride_vec.push_back(2);
  stride_vec.push_back(2);
  std::vector<IndexExpr> oshape(dshape.size());
  for (size_t i = 0; i < num_axis; ++i) {
    int64_t stride_v = stride_vec[i];
    const int64_t* p_dim_size = as_const_int(dshape[i]);
    CHECK(p_dim_size)
        << "vacc_dropout requires dimension to be concrete int";
    int64_t dim_size = p_dim_size[0];
    oshape[i] = make_const(dshape[i].dtype(), (dim_size + stride_v - 1) / stride_v);
  }
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}

Expr MakeVaccDropout(Expr data) {
  static const Op& op = Op::Get("nn.vacc_dropout");
  return CallNode::make(op, {data}, Attrs{}, {});
}

TVM_REGISTER_API("relay.op._make.vacc_dropout")
.set_body_typed(MakeVaccDropout);

RELAY_REGISTER_OP("nn.vacc_dropout")
.describe(R"code(Applies the dropout with H/W stride = 2 to the input array.
Examples::
  x = [[ 1.,  4.,  7., 10.],
       [ 2.,  5.,  8., 11.],
       [ 3.,  6.,  9., 12.]]
  vacc_dropout(x) = [[1., 7.],
                     [3., 9.]]
  x = [[[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]],
       [[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]],
       [[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]]]
  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 3.],
                                                 [ 7., 9.]],
                                                [[ 1., 3.],
                                                 [ 7., 9.]],
                                                [[ 1., 3.],
                                                 [ 7., 9.]]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("VaccDropout", VaccDropoutRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);

class VaccDropoutExpr;

class VaccDropoutExprNode : public TempExprNode {
 public:
  /*! \brief The original expression */
  Expr expr;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
  }

  TVM_DLL static VaccDropoutExpr make(Expr expr);

  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.VaccDropoutExpr";
  TVM_DECLARE_NODE_TYPE_INFO(VaccDropoutExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(VaccDropoutExpr, VaccDropoutExprNode, TempExpr);

Expr VaccDropoutExprNode::Realize() const {
  // insert vacc_dropout
  return MakeVaccDropout(this->expr);
}

VaccDropoutExpr VaccDropoutExprNode::make(Expr expr) {
  auto rnode = make_node<VaccDropoutExprNode>();
  rnode->expr = expr;
  return VaccDropoutExpr(rnode);
}

inline Expr RebuildConv2D(const Call& ref_call, const Array<Expr>& args) {
  const auto* attrs = ref_call->attrs.as<Conv2DAttrs>();
  const auto new_attrs = make_node<Conv2DAttrs>();
  *new_attrs = *attrs;
  Array<IndexExpr> new_strides;
  for (size_t i = 0; i < attrs->strides.size(); ++i) {
    new_strides.push_back(1);
  }
  new_attrs->strides = std::move(new_strides);
  return CallNode::make(Op::Get("nn.conv2d"), args, Attrs(new_attrs), {});
}

Expr VaccDropoutConv2dRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  const Conv2DAttrs* attrs = ref_call->attrs.as<Conv2DAttrs>();
  auto const_strides = ConvertToConstants(attrs->strides);
  auto const_kernel = ConvertToConstants(attrs->kernel_size);
  CHECK(const_kernel.size() == 2 && const_strides.size() == 2);
  // only convert 1x1 kernel size and 2x2 strides
  if ((const_kernel[0]->value == 1 && const_kernel[1]->value == 1) &&
      (const_strides[0]->value == 2 && const_strides[1]->value == 2)) {
    Expr new_arg0 = new_args[0];
    if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
      // Realize the upstream VaccDropoutExprNode only when a new conv2d is
      // encountered, inserting the vacc_dropout there.
      new_arg0 = n->Realize();
    }
    Expr newConv2d = RebuildConv2D(ref_call, {new_arg0, new_args[1]});
    return VaccDropoutExprNode::make(newConv2d);
  }
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutConv2dRewrite);

Expr VaccDropoutSepBeginRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK(false) << "VaccConvertStrides must be before VaccSepGrouping";
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.vacc_sep_begin")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutSepBeginRewrite);

Expr MakeVaccDropoutArgsExpr(const Call& ref_call, const Array<Expr>& new_args) {
  if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
    Expr new_expr = ForwardOp(ref_call, {n->expr, new_args[1]});
    return VaccDropoutExprNode::make(new_expr);
  }
  return Expr(nullptr);
}

// Upstream branches stop propagating here and Realize() is called.
// Ops of this kind: multiply, add.
Expr VaccDropoutElemwiseRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  if (IsElemwiseShape(ref_call->args[0], ref_call->args[1])) {
    Expr new_arg0 = new_args[0];
    if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
      // the left upstream branch stops propagating here
      new_arg0 = n->Realize();
    }
    Expr new_arg1 = new_args[1];
    if (const auto* n = new_args[1].as<VaccDropoutExprNode>()) {
      // the right upstream branch stops propagating here
      new_arg1 = n->Realize();
    }
    return ForwardOp(ref_call, {new_arg0, new_arg1});
  }
  return MakeVaccDropoutArgsExpr(ref_call, new_args);
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutElemwiseRewrite);
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutElemwiseRewrite);

// Upstream branches keep propagating past these two-argument ops.
Expr VaccDropout2ArgsRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  return MakeVaccDropoutArgsExpr(ref_call, new_args);
}

RELAY_REGISTER_OP("nn.bias_add")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
RELAY_REGISTER_OP("affine")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
RELAY_REGISTER_OP("nn.prelu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
RELAY_REGISTER_OP("vacc_activation")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);

// Upstream branches keep propagating past these one-argument ops.
Expr VaccDropout1ArgRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
    Expr new_expr = ForwardOp(ref_call, {n->expr});
    return VaccDropoutExprNode::make(new_expr);
  }
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout1ArgRewrite);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout1ArgRewrite);

Expr ConvertStrides(const Expr& expr) {
  return ForwardRewrite(expr, "FVaccConvertStrides", nullptr);
}

namespace transform {

Pass VaccConvertStrides() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(ConvertStrides(f));
      };
  return CreateFunctionPass(pass_func, 1, "VaccConvertStrides", {});
}

TVM_REGISTER_API("relay._transform.VaccConvertStrides")
.set_body_typed(VaccConvertStrides);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
```

Tests

```python
from tvm import relay
from tvm.relay import transform, analysis
from random import randint
from tvm.relay.testing import run_opt_pass, run_infer_type


def run_convert_strides(expr):
    return run_opt_pass(expr, transform.VaccConvertStrides())


def test_conv2d_1x1_s2():
    n, c, h, w = 4, 3, 224, 224
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")

    def before():
        y = relay.nn.conv2d(x, w1, strides=(2, 2), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(2, 2), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    def expect():
        y = relay.nn.conv2d(x, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.vacc_dropout(y)
        y = relay.nn.conv2d(y, w2, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        y = relay.vacc_dropout(y)
        return relay.Function([x, w1, w2], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def test_conv2d_1x1_s2_elemwise():
    n, c, h, w = 4, 3, 224, 224
    x1 = relay.var("x1", relay.ty.TensorType((n, c, h, w), "float32"))
    x2 = relay.var("x2", relay.ty.TensorType((n, c, 112, 112), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")
    w3 = relay.var("w3")

    def before():
        y1 = relay.nn.conv2d(x1, w1, strides=(2, 2), kernel_size=(1, 1), channels=8)
        y1 = relay.vacc_activation('sigmoid', y1, 1.0, 2.0)
        y2 = relay.nn.conv2d(x2, w2, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y2 = relay.add(y2, relay.const(0.5))
        y = relay.add(y1, y2)
        y = relay.nn.conv2d(y, w3, strides=(2, 2), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        # w3 is used above, so it must appear in the parameter list as well
        return relay.Function([x1, x2, w1, w2, w3], y)

    def expect():
        y1 = relay.nn.conv2d(x1, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y1 = relay.vacc_activation('sigmoid', y1, 1.0, 2.0)
        y1 = relay.vacc_dropout(y1)
        y2 = relay.nn.conv2d(x2, w2, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y2 = relay.add(y2, relay.const(0.5))
        y = relay.add(y1, y2)
        y = relay.nn.conv2d(y, w3, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        y = relay.vacc_dropout(y)
        return relay.Function([x1, x2, w1, w2, w3], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def test_conv2d_non_1x1_s2():
    n, c, h, w = 4, 3, 224, 224
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")

    def before():
        y = relay.nn.conv2d(x, w1, strides=(2, 2), kernel_size=(3, 3), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(2, 2), kernel_size=(5, 5), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    def expect():
        # non-1x1 kernels are left untouched
        y = relay.nn.conv2d(x, w1, strides=(2, 2), kernel_size=(3, 3), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(2, 2), kernel_size=(5, 5), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def test_conv2d_1x1_non_s2():
    n, c, h, w = 4, 3, 224, 224
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")

    def before():
        y = relay.nn.conv2d(x, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    def expect():
        # non-2x2 strides are left untouched
        y = relay.nn.conv2d(x, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def verify(incoming_func):
    inferred_func = run_infer_type(incoming_func)
    print("inferred_func ---\n", inferred_func)
    converted_func = run_convert_strides(inferred_func)
    print("converted_func ---\n", converted_func)
    assert inferred_func.body.checked_type == converted_func.body.checked_type


def random_strided_conv2d_test():
    n, c, h, w = randint(1, 256), randint(1, 256), randint(1, 256), randint(1, 256)
    k = randint(1, 5)
    kernel_size = (k, k)
    p = randint(1, 3)
    padding = (p, p)
    channels = randint(1, 16)
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w = relay.var("w")
    y = relay.nn.conv2d(x, w, strides=(2, 2), kernel_size=kernel_size,
                        padding=padding, channels=channels)
    y = relay.nn.relu(y)
    func = relay.Function([x, w], y)
    verify(func)


def test_strided_slice():
    for _ in range(10):
        random_strided_conv2d_test()


def test_identity():
    n, c, h, w = randint(1, 256), randint(1, 256), randint(1, 256), randint(1, 256)
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w = relay.var("w")
    y = relay.nn.conv2d(x, w, strides=(1, 1), kernel_size=(1, 1), padding=(0, 0), channels=8)
    func = relay.Function([x, w], y)
    inferred_func = run_infer_type(func)
    converted_func = run_convert_strides(inferred_func)
    assert analysis.alpha_equal(inferred_func, converted_func) and \
        analysis.structural_hash(inferred_func) == analysis.structural_hash(converted_func)


if __name__ == "__main__":
    test_conv2d_1x1_s2()
    test_conv2d_1x1_s2_elemwise()
    test_conv2d_non_1x1_s2()
    test_conv2d_1x1_non_s2()
    test_identity()
```