TVM - ForwardRewriter
| Created: | 2020-07-08 11:51 |
|---|---|
| Updated: | 2020-07-08 11:57 |
| Tags: | TVM, 李国明, 熊选文 |
ForwardRewriter
ForwardRewriter is a subclass of ExprMutator, so implementing a pass directly with ForwardRewriter is, at bottom, no different from writing our own ExprMutator subclass. What ForwardRewriter adds is a framework that makes such passes easier to implement and the resulting code more concise. It is best suited to passes that need to distinguish between ops and handle each one differently. "Forward" here means the same direction as the graph's compute flow, which is the opposite of the traversal direction: the graph is traversed by treating it as a tree whose root is the last node and whose leaves are the first nodes, walking from root to leaves with depth-first search (DFS), or more precisely post-order DFS.

The framework already implements the graph traversal and automatically dispatches to the ForwardRewrite function registered for each op, so we only have to implement the per-op rewrites. It also provides several helpers for developers: TempExpr, the fcontext context, and fmulti_ref_trigger.
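To make the two directions concrete, here is a minimal illustration (the op names are just for the example):

```cpp
// Compute order of the graph:   x -> conv2d -> relu
// As an AST, relu is the root; its argument is conv2d, whose argument is x.
// Post-order DFS visits arguments before the node itself:
//     visit x, then conv2d, then relu
// so the registered rewrites fire in the forward (compute) order:
// conv2d's rewrite runs first, and anything it stores in a TempExpr is
// visible to relu's rewrite through new_args.
```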
1 Registering a ForwardRewrite Pass

ForwardRewrite is the entry function of ForwardRewriter (the full source appears under "Definition of ForwardRewriter" below). Its two overloads are:
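```cpp
// Look up rewrite functions registered under `rewrite_map_name`
// and apply them in a forward fashion.
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger);

// Variant that applies a single rewrite function to every call node.
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger);
```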
2 Registering an Op's ForwardRewrite Function

The first argument of set_attr (here "FTestRewrite") must match the rewrite_map_attr_name argument given to ForwardRewrite when the pass is registered.
Conv2dTestRewrite is a function of type FForwardRewrite, which is defined as follows:
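The original screenshot is gone; for reference, the definition in the TVM headers this note is based on (tvm/relay/op_attr_types.h) is approximately:

```cpp
/*!
 * \brief Forward rewriting rule for a specific op.
 *
 * \param ref_call The reference old call to be rewritten.
 *                 We can make use of the op and type information.
 * \param new_args The new arguments (some of them could be TempExpr).
 * \param ctx  Optional context information about ref_call.
 * \return The rewritten result call; can also return nullptr,
 *         which tells the rewriter to use the default fallback rule
 *         that realizes all inputs and re-composes the call.
 */
using FForwardRewrite = runtime::TypedPackedFunc<
    Expr(const Call& ref_call,
         const Array<Expr>& new_args,
         const NodeRef& ctx)>;
```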
- The parameter ref_call is the original call about to be rewritten.
- The elements of new_args may be TempExpr (or a custom subclass of it) and can therefore carry user-defined information.
- The parameter ctx carries context information about ref_call.
Finally, TestRealizeExprNode::make(ref_call) returns a TestRealizeExpr instance.
TestRealizeExpr is a subclass of TempExpr, and TestRealizeExprNode is a subclass of TempExprNode.
As shown below, TestRealizeExprNode::make() is simple and takes a single argument, Expr expr. If TestRealizeExprNode carried more state (besides expr it may hold any number of additional member variables with extra information), make() could take correspondingly more arguments (see the definition of TestRealizeExprNode for details).
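Excerpted from the complete example later in this note:

```cpp
TestRealizeExpr TestRealizeExprNode::make(Expr expr) {
  NodePtr<TestRealizeExprNode> n = make_node<TestRealizeExprNode>();
  n->expr = std::move(expr);
  n->bTest = true;  // extra information carried alongside expr
  return TestRealizeExpr(n);
}
```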
Once conv2d has been wrapped via make into a TestRealizeExpr, it carries two pieces of information, expr and bTest. When that object arrives as relu's new_args[0], both can be read back (see the n->expr and n->bTest accesses in ReluTestRewrite in the example below).
Definition of TestRealizeExpr(Node)
Below is the definition of TestRealizeExpr(Node) (excerpted from the complete example later in this note):
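```cpp
class TestRealizeExprNode : public TempExprNode {
 public:
  Expr expr;   // the wrapped original expression
  bool bTest;  // extra information to forward to the next op

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
    v->Visit("bTest", &bTest);
  }

  Expr Realize() const final;

  TVM_DLL static TestRealizeExpr make(Expr expr);

  static constexpr const char* _type_key = "relay.transform.TestRealizeExpr";
  TVM_DECLARE_NODE_TYPE_INFO(TestRealizeExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(TestRealizeExpr, TestRealizeExprNode, TempExpr);
```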
A **TempExprNode subclass can be used to carry information forward, but only to the immediately following layer**.
Realize is a virtual function of TempExprNode; TestRealizeExprNode, being a subclass of TempExprNode, must override it.
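For reference, the base class in the TVM version this note is based on (tvm/relay/expr.h) looks roughly like this:

```cpp
/*!
 * \brief Base class of the temporary expression.
 *
 * TempExprs are pass-specific expressions that can be useful to hold
 * intermediate results in a rewriting pass, such as layout or
 * quantization information.
 */
class TempExprNode : public ExprNode {
 public:
  /*!
   * \brief Convert the expression to a "normal" (non-temp) expression.
   * \return The corresponding normal expression.
   */
  virtual Expr Realize() const = 0;

  static constexpr const char* _type_key = "relay.TempExpr";
  TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);
```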
How Realize gets called
Let us walk through how Realize is invoked, taking the following graph as an example:
Suppose only relu has a ForwardRewrite function registered, and that function wraps its result via make into a TestRealizeExpr and returns it, as below:
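A condensed version of ReluTestRewrite from the complete example:

```cpp
Expr ReluTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  // wrap the call into a TempExpr subclass and pass it forward
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FTestRewrite", ReluTestRewrite);
```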
Following ExprMutator's post-order recursive traversal of the graph, when the rewriter reaches conv2d it checks whether conv2d has a registered ForwardRewrite function. If not, frewrite in the code below is nullptr (the current op is conv2d, which is unregistered, so frewrite is nullptr). The for loop then walks its args; in this example there is a single arg, args[0], which is relu. So arg below is relu, GetTempExpr returns the mutated relu as new_arg, and because frewrite is nullptr, new_arg becomes the result of realizer_.Realize(new_arg).
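The relevant excerpt from ForwardRewriter::VisitExpr_(const CallNode*) (full listing below):

```cpp
PackedFunc frewrite;
if (rewrite_func_) {
  frewrite = *rewrite_func_;
} else {
  CHECK(rewrite_map_);
  // nullptr when the current op (conv2d here) has no registered rewrite
  frewrite = rewrite_map_->get(call_node->op, nullptr);
}
// ...
for (auto arg : call_node->args) {
  Expr new_arg = this->GetTempExpr(arg);   // may still be a TempExpr
  if (frewrite == nullptr) {
    new_arg = realizer_.Realize(new_arg);  // force realization
  }
  unchanged &= new_arg.same_as(arg);
  call_args.push_back(new_arg);
}
```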
Following realizer_.Realize further, it checks whether the expression inherits from TempExprNode before calling its Realize:
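The check, excerpted from TempRealizer (full listing below):

```cpp
Expr res;
if (const auto* temp = expr.as<TempExprNode>()) {
  res = temp->Realize();  // only TempExprNode subclasses get realized
} else {
  res = ExprFunctor::VisitExpr(expr);
}
```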
So the Realize function of a TempExprNode subclass is called when two conditions hold:
- an op **without a registered ForwardRewrite function** is encountered, and
- the previous op's **ForwardRewrite function wrapped its result (via make) into a TempExpr** subclass.

In that case, **the Realize function of the TempExprNode subclass produced by the previous op is invoked. The last op is the exception: as long as it has a registered ForwardRewrite function, the Realize function of the TempExprNode subclass it returns is always called** (the top-level VisitExpr always realizes the final result).
For example:

| No. | Op/Node | ForwardRewrite registered? | Realize() called? |
|---|---|---|---|
| 1 | x1 | Yes | Yes |
| 2 | x2 | No | No |
| 3 | x3 | Yes | No |
| 4 | x4 | Yes | No |
| 5 | x5 | Yes | Yes |

And a second example:

| No. | Op/Node | ForwardRewrite registered? | Realize() called? |
|---|---|---|---|
| 1 | x1 | Yes | Yes |
| 2 | x2 | No | No |
| 3 | x3 | Yes | Yes |
| 4 | x4 | No | No |
| 5 | x5 | Yes | Yes |
The ctx parameter
The ctx parameter can carry context information, supplied via the fcontext argument of ForwardRewrite. Below is another way to register the pass, this time supplying an fcontext.

Relu's ForwardRewrite function then consumes this context information:
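The original screenshots are missing; a minimal sketch of both halves, with a hypothetical choice of context (passing the call's attrs), could look like this:

```cpp
// Registering the pass with an fcontext (hypothetical context contents):
auto fcontext = [](const Call& ref_call) -> NodeRef {
  // hand each rewrite the attrs of the call being rewritten
  return ref_call->attrs;
};
Expr ret = ForwardRewrite(f, "FTestRewrite", fcontext, nullptr);

// Consuming the context inside an op's rewrite function:
Expr ReluTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  if (ctx.defined()) {
    // ctx is whatever fcontext returned for this ref_call
  }
  return TestRealizeExprNode::make(ref_call);
}
```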
3 About fmulti_ref_trigger
If the case of an Expr consumed by multiple callers needs special handling, supply an fmulti_ref_trigger function when registering the ForwardRewrite pass (see "1 Registering a ForwardRewrite Pass"). In GetTempExpr, fmulti_ref_trigger is invoked whenever the expression's reference count is greater than 1:
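```cpp
// excerpt from ForwardRewriter::GetTempExpr (full listing below)
if (fmulti_ref_trigger_ != nullptr) {
  Expr ret = ExprMutator::VisitExpr(expr);
  auto it = ref_counter_.find(expr.get());
  CHECK(it != ref_counter_.end());
  if (it->second > 1) {
    // the expr is consumed by more than one caller
    ret = fmulti_ref_trigger_(ret);
  }
  return ret;
}
```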
Definition of ForwardRewriter:
```cpp
/*!
 * \file forward_rewrite.cc
 * \brief Apply rewriting rules in a forward fashion.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include "pass_util.h"

namespace tvm {
namespace relay {

// Realizer class that realizes the expression.
// Note that we can take benefit of its internal memo
// so that calling realize repeatedly won't hurt perf.
class TempRealizer : private ExprMutator {
 public:
  Expr Realize(Expr expr) {
    return VisitExpr(expr);
  }

 private:
  Expr VisitExpr(const Expr& expr) final {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    } else {
      Expr res;
      if (const auto* temp = expr.as<TempExprNode>()) {
        res = temp->Realize();
      } else {
        res = ExprFunctor::VisitExpr(expr);
      }
      memo_[res] = res;
      return res;
    }
  }
};

class ForwardRewriter : private ExprMutator {
 public:
  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_map_(rewrite_map),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  ForwardRewriter(const FForwardRewrite* rewrite_func,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  // Transform expression.
  Expr Rewrite(Expr expr) {
    if (fmulti_ref_trigger_ != nullptr) {
      ref_counter_ = GetExprRefCount(expr);
    }
    return this->VisitExpr(expr);
  }

 private:
  // The rewrite rule.
  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
  const FForwardRewrite* rewrite_func_{nullptr};
  // The context.
  std::function<NodeRef(const Call&)> fcontext_{nullptr};
  // The multiple reference trigger.
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Internal ref counter.
  std::unordered_map<const Node*, size_t> ref_counter_;
  // Internal realizer.
  TempRealizer realizer_;

  Expr VisitExpr(const Expr& expr) final {
    // by default always realize.
    return realizer_.Realize(ExprMutator::VisitExpr(expr));
  }

  // Visit and allow non-realized version.
  Expr GetTempExpr(const Expr& expr) {
    if (fmulti_ref_trigger_ != nullptr) {
      Expr ret = ExprMutator::VisitExpr(expr);
      auto it = ref_counter_.find(expr.get());
      CHECK(it != ref_counter_.end());
      if (it->second > 1) {
        ret = fmulti_ref_trigger_(ret);
      }
      return ret;
    } else {
      return ExprMutator::VisitExpr(expr);
    }
  }

  // Automatic fold TupleGetItem.
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr tuple = this->GetTempExpr(op->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    } else {
      if (tuple.same_as(op->tuple)) {
        return GetRef<Expr>(op);
      } else {
        return TupleGetItemNode::make(tuple, op->index);
      }
    }
  }

  Expr VisitExpr_(const TupleNode* op) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    for (auto field : op->fields) {
      auto new_field = this->GetTempExpr(field);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(field);
    }
    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    } else {
      return TupleNode::make(fields);
    }
  }

  Expr VisitExpr_(const CallNode* call_node) final {
    const Call& ref_call = GetRef<Call>(call_node);
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
    auto new_op = this->Mutate(call_node->op);
    bool unchanged = call_node->op.same_as(new_op);

    Array<Expr> call_args;
    for (auto arg : call_node->args) {
      Expr new_arg = this->GetTempExpr(arg);
      if (frewrite == nullptr) {
        // frewrite is nullptr only when the current op has no registered
        // FForwardRewrite function; only then is realizer_.Realize(new_arg)
        // called.
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(arg);
      call_args.push_back(new_arg);
    }
    // try to rewrite.
    if (frewrite != nullptr) {
      Expr res = frewrite(
          ref_call, call_args,
          fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
      if (res.defined()) return res;
      // abort, use old rule
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return CallNode::make(new_op, call_args, call_node->attrs, call_node->type_args);
  }
};

Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}

Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
}

}  // namespace relay
}  // namespace tvm
```
A simple example
```cpp
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/transform.h>
#include "../../qnn/util.h"
#include "../pattern_util.h"
#include "util.h"

namespace tvm {
namespace relay {
namespace testrewrite {

class TestRealizeExpr;

class TestRealizeExprNode : public TempExprNode {
 public:
  Expr expr;
  bool bTest;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
    v->Visit("bTest", &bTest);
  }

  Expr Realize() const final;

  TVM_DLL static TestRealizeExpr make(Expr expr);

  static constexpr const char* _type_key = "relay.transform.TestRealizeExpr";
  TVM_DECLARE_NODE_TYPE_INFO(TestRealizeExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(TestRealizeExpr, TestRealizeExprNode, TempExpr);

Expr TestRealizeExprNode::Realize() const {
  Expr expr = this->expr;
  const CallNode* call_node = expr.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << "Realize:" << call_op->name;
  return expr;
}

TestRealizeExpr TestRealizeExprNode::make(Expr expr) {
  NodePtr<TestRealizeExprNode> n = make_node<TestRealizeExprNode>();
  n->expr = std::move(expr);
  n->bTest = true;
  return TestRealizeExpr(n);
}

inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
  return CallNode::make(ref_call->op, args, ref_call->attrs, ref_call->type_args);
}

/* \brief forward the original operator */
Expr IdentityTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  return Expr(nullptr);
}

Expr Conv2dTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FTestRewrite", Conv2dTestRewrite);

Expr ReluTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  if (const auto* n = new_args[0].as<TestRealizeExprNode>()) {
    // the info carried over from the previous op's rewrite is available here
    n->expr;   // the wrapped expression
    n->bTest;  // the extra flag set in make()
  }
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FTestRewrite", ReluTestRewrite);

Expr MultiplyTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  return TestRealizeExprNode::make(ref_call);
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FTestRewrite", MultiplyTestRewrite);

Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  const CallNode* call_node = ref_call.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
  CHECK_EQ(ref_call->type_args.size(), 2);
  if (IsElewiseShape(ref_call->type_args[0], ref_call->type_args[1])) {
    LOG(INFO) << "same shape";
  }
  const auto* n = new_args[0].as<TestRealizeExprNode>();
  if (!n) {
    // Expr ret = ForwardOp(ref_call, {n->data});
    return TestRealizeExprNode::make(ref_call);
  }
  CHECK(!new_args[0]->is_type<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FTestRewrite", AddRealize);

}  // namespace testrewrite

namespace transform {

Pass TestRewritePass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(ForwardRewrite(f, "FTestRewrite", nullptr, nullptr));
      };
  return CreateFunctionPass(pass_func, 1, "TestRewrite", {});
}

TVM_REGISTER_API("relay._transform.TestRewrite")
.set_body_typed(TestRewritePass);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
```
A more complex example
```cpp
/*!
 * \file convert_strides.cc
 * \brief Convert existing 2x2-strides conv2d to
 *        1x1 strides conv2d + vacc_dropout.
 */

/*
 * The pass does three things:
 * 1. Add an operator vacc_dropout to implement dropout with strides=2
 * 2. 1x1 kernel size and 2x2 strides conv2d is converted to 1x1 kernel size
 *    and 1x1 strides conv2d
 * 3. Move vacc_dropout (which is a PEP op) to the end of SEP unless
 *    elementwise add or multiply is encountered.
 */
#include <tvm/relay/transform.h>
#include <tvm/relay/attrs/nn.h>
#include "../util.h"

namespace tvm {
namespace relay {

inline Array<Integer> ConvertToConstants(const Array<IndexExpr>& arr) {
  Array<Integer> convert_result;
  for (size_t i = 0; i < arr.size(); ++i) {
    const IntImm* const_elem = arr[i].as<IntImm>();
    CHECK(const_elem);
    convert_result.push_back(const_elem->value);
  }
  return std::move(convert_result);
}

bool VaccDropoutRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  auto dshape = data->shape;
  auto num_axis = dshape.size();

  std::vector<int64_t> stride_vec;
  for (size_t i = 0; i < num_axis - 2; ++i) {
    stride_vec.push_back(1);
  }
  stride_vec.push_back(2);
  stride_vec.push_back(2);

  std::vector<IndexExpr> oshape(dshape.size());
  for (size_t i = 0; i < num_axis; ++i) {
    int64_t stride_v = stride_vec[i];
    const int64_t* p_dim_size = as_const_int(dshape[i]);
    CHECK(p_dim_size) << "vacc_dropout requires dimension to be concrete int";
    int64_t dim_size = p_dim_size[0];
    oshape[i] = make_const(dshape[i].dtype(), (dim_size + stride_v - 1) / stride_v);
  }
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}

Expr MakeVaccDropout(Expr data) {
  static const Op& op = Op::Get("nn.vacc_dropout");
  return CallNode::make(op, {data}, Attrs{}, {});
}

TVM_REGISTER_API("relay.op._make.vacc_dropout")
.set_body_typed(MakeVaccDropout);

RELAY_REGISTER_OP("nn.vacc_dropout")
.describe(R"code(Applies the dropout with H/W stride = 2 to the input array.

Examples::

  x = [[ 1.,  4.,  7., 10.],
       [ 2.,  5.,  8., 11.],
       [ 3.,  6.,  9., 12.]]

  vacc_dropout(x) = [[1., 7.],
                     [3., 9.]]

  x = [[[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]],
       [[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]],
       [[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]]]

  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 3.],
                                                 [ 7., 9.]],
                                                [[ 1., 3.],
                                                 [ 7., 9.]],
                                                [[ 1., 3.],
                                                 [ 7., 9.]]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("VaccDropout", VaccDropoutRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);

class VaccDropoutExpr;

class VaccDropoutExprNode : public TempExprNode {
 public:
  /*! \brief The original expression */
  Expr expr;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
  }

  TVM_DLL static VaccDropoutExpr make(Expr expr);

  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.VaccDropoutExpr";
  TVM_DECLARE_NODE_TYPE_INFO(VaccDropoutExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(VaccDropoutExpr, VaccDropoutExprNode, TempExpr);

Expr VaccDropoutExprNode::Realize() const {
  // insert vacc_dropout
  return MakeVaccDropout(this->expr);
}

VaccDropoutExpr VaccDropoutExprNode::make(Expr expr) {
  auto rnode = make_node<VaccDropoutExprNode>();
  rnode->expr = expr;
  return VaccDropoutExpr(rnode);
}

inline Expr RebuildConv2D(const Call& ref_call, const Array<Expr>& args) {
  const auto* attrs = ref_call->attrs.as<Conv2DAttrs>();
  const auto new_attrs = make_node<Conv2DAttrs>();
  *new_attrs = *attrs;
  Array<IndexExpr> new_strides;
  for (size_t i = 0; i < attrs->strides.size(); ++i) {
    new_strides.push_back(1);
  }
  new_attrs->strides = std::move(new_strides);
  return CallNode::make(Op::Get("nn.conv2d"), args, Attrs(new_attrs), {});
}

Expr VaccDropoutConv2dRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  const Conv2DAttrs* attrs = ref_call->attrs.as<Conv2DAttrs>();
  auto const_strides = ConvertToConstants(attrs->strides);
  auto const_kernal = ConvertToConstants(attrs->kernel_size);
  CHECK(const_kernal.size() == 2 && const_strides.size() == 2);
  // only convert 1x1 kernel size and 2x2 strides
  if ((const_kernal[0]->value == 1 && const_kernal[1]->value == 1) &&
      (const_strides[0]->value == 2 && const_strides[1]->value == 2)) {
    Expr new_arg0 = new_args[0];
    if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
      // an upstream VaccDropoutExprNode is only realized when a new conv2d
      // is encountered; this is where the dropout gets inserted
      new_arg0 = n->Realize();
    }
    Expr newConv2d = RebuildConv2D(ref_call, {new_arg0, new_args[1]});
    return VaccDropoutExprNode::make(newConv2d);
  }
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutConv2dRewrite);

Expr VaccDropoutSepBeginRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK(false) << "VaccConvertStrides must be before VaccSepGrouping";
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.vacc_sep_begin")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutSepBeginRewrite);

Expr MakeVaccDropoutArgsExpr(const Call& ref_call, const Array<Expr>& new_args) {
  if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
    Expr new_expr = ForwardOp(ref_call, {n->expr, new_args[1]});
    return VaccDropoutExprNode::make(new_expr);
  }
  return Expr(nullptr);
}

// Upstream branches stop propagating here and Realize() is called.
// Ops of this kind: multiply, add
Expr VaccDropoutElemwiseRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  if (IsElemwiseShape(ref_call->args[0], ref_call->args[1])) {
    Expr new_arg0 = new_args[0];
    if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
      // the upstream left branch stops propagating here
      new_arg0 = n->Realize();
    }
    Expr new_arg1 = new_args[1];
    if (const auto* n = new_args[1].as<VaccDropoutExprNode>()) {
      // the upstream right branch stops propagating here
      new_arg1 = n->Realize();
    }
    return ForwardOp(ref_call, {new_arg0, new_arg1});
  }
  return MakeVaccDropoutArgsExpr(ref_call, new_args);
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutElemwiseRewrite);

RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutElemwiseRewrite);

// Upstream branches keep propagating forward through these ops.
Expr VaccDropout2ArgsRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  return MakeVaccDropoutArgsExpr(ref_call, new_args);
}

RELAY_REGISTER_OP("nn.bias_add")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);

RELAY_REGISTER_OP("affine")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);

RELAY_REGISTER_OP("nn.prelu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);

RELAY_REGISTER_OP("vacc_activation")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);

// Upstream branches keep propagating forward through these ops.
Expr VaccDropout1ArgRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
    Expr new_expr = ForwardOp(ref_call, {n->expr});
    return VaccDropoutExprNode::make(new_expr);
  }
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout1ArgRewrite);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout1ArgRewrite);

Expr ConvertStrides(const Expr& expr) {
  return ForwardRewrite(expr, "FVaccConvertStrides", nullptr);
}

namespace transform {

Pass VaccConvertStrides() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(ConvertStrides(f));
      };
  return CreateFunctionPass(pass_func, 1, "VaccConvertStrides", {});
}

TVM_REGISTER_API("relay._transform.VaccConvertStrides")
.set_body_typed(VaccConvertStrides);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
```
Tests
```python
from tvm import relay
from tvm.relay import transform, analysis
from random import randint
from tvm.relay.testing import run_opt_pass, run_infer_type


def run_convert_strides(expr):
    return run_opt_pass(expr, transform.VaccConvertStrides())


def test_conv2d_1x1_s2():
    n, c, h, w = 4, 3, 224, 224
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")

    def before():
        y = relay.nn.conv2d(x, w1, strides=(2, 2), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(2, 2), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    def expect():
        y = relay.nn.conv2d(x, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.vacc_dropout(y)
        y = relay.nn.conv2d(y, w2, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        y = relay.vacc_dropout(y)
        return relay.Function([x, w1, w2], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def test_conv2d_1x1_s2_elemwise():
    n, c, h, w = 4, 3, 224, 224
    x1 = relay.var("x1", relay.ty.TensorType((n, c, h, w), "float32"))
    x2 = relay.var("x2", relay.ty.TensorType((n, c, 112, 112), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")
    w3 = relay.var("w3")

    def before():
        y1 = relay.nn.conv2d(x1, w1, strides=(2, 2), kernel_size=(1, 1), channels=8)
        y1 = relay.vacc_activation('sigmoid', y1, 1.0, 2.0)
        y2 = relay.nn.conv2d(x2, w2, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y2 = relay.add(y2, relay.const(0.5))
        y = relay.add(y1, y2)
        y = relay.nn.conv2d(y, w3, strides=(2, 2), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        # w3 must be a parameter, otherwise the function has a free variable
        return relay.Function([x1, x2, w1, w2, w3], y)

    def expect():
        y1 = relay.nn.conv2d(x1, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y1 = relay.vacc_activation('sigmoid', y1, 1.0, 2.0)
        y1 = relay.vacc_dropout(y1)
        y2 = relay.nn.conv2d(x2, w2, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y2 = relay.add(y2, relay.const(0.5))
        y = relay.add(y1, y2)
        y = relay.nn.conv2d(y, w3, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        y = relay.vacc_dropout(y)
        return relay.Function([x1, x2, w1, w2, w3], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def test_conv2d_non_1x1_s2():
    n, c, h, w = 4, 3, 224, 224
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")

    def before():
        y = relay.nn.conv2d(x, w1, strides=(2, 2), kernel_size=(3, 3), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(2, 2), kernel_size=(5, 5), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    def expect():
        # non-1x1 kernels are left unchanged
        y = relay.nn.conv2d(x, w1, strides=(2, 2), kernel_size=(3, 3), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(2, 2), kernel_size=(5, 5), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def test_conv2d_1x1_non_s2():
    n, c, h, w = 4, 3, 224, 224
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w1 = relay.var("w1")
    w2 = relay.var("w2")

    def before():
        y = relay.nn.conv2d(x, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    def expect():
        # non-2x2 strides are left unchanged
        y = relay.nn.conv2d(x, w1, strides=(1, 1), kernel_size=(1, 1), channels=8)
        y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
        y = relay.nn.conv2d(y, w2, strides=(1, 1), kernel_size=(1, 1), channels=4)
        y = relay.nn.relu(y)
        return relay.Function([x, w1, w2], y)

    converted = run_convert_strides(before())
    expected = run_infer_type(expect())
    assert analysis.alpha_equal(converted, expected)


def verify(incoming_func):
    inferred_func = run_infer_type(incoming_func)
    print("inferred_func ---\n", inferred_func)
    converted_func = run_convert_strides(inferred_func)
    print("converted_func ---\n", converted_func)
    assert inferred_func.body.checked_type == converted_func.body.checked_type


def random_strided_conv2d_test():
    n, c, h, w = randint(1, 256), randint(1, 256), randint(1, 256), randint(1, 256)
    k = randint(1, 5)
    kernel_size = (k, k)
    p = randint(1, 3)
    padding = (p, p)
    channels = randint(1, 16)
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w = relay.var("w")
    y = relay.nn.conv2d(x, w, strides=(2, 2), kernel_size=kernel_size,
                        padding=padding, channels=channels)
    y = relay.nn.relu(y)
    func = relay.Function([x, w], y)
    verify(func)


def test_strided_slice():
    for _ in range(10):
        random_strided_conv2d_test()


def test_identity():
    n, c, h, w = randint(1, 256), randint(1, 256), randint(1, 256), randint(1, 256)
    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
    w = relay.var("w")
    y = relay.nn.conv2d(x, w, strides=(1, 1), kernel_size=(1, 1), padding=(0, 0), channels=8)
    func = relay.Function([x, w], y)
    inferred_func = run_infer_type(func)
    converted_func = run_convert_strides(inferred_func)
    assert analysis.alpha_equal(inferred_func, converted_func) and \
        analysis.structural_hash(inferred_func) == analysis.structural_hash(converted_func)


if __name__ == "__main__":
    test_conv2d_1x1_s2()
    test_conv2d_1x1_s2_elemwise()
    test_conv2d_non_1x1_s2()
    test_conv2d_1x1_non_s2()
    test_identity()
```
