TVM - ForwardRewriter
| Created | 2020-07-08 11:51 |
|---|---|
| Updated | 2020-07-08 11:57 |
| Tags | TVM, 李国明, 熊选文 |
ForwardRewriter
ForwardRewriter is a subclass of ExprMutator, so implementing a pass directly with ForwardRewriter is essentially no different from writing our own ExprMutator subclass. ForwardRewriter, however, provides a framework that makes passes easier to implement and the code more concise. It is mainly suited to passes that need to distinguish between different ops and handle each one differently. "Forward" here means the same direction as the graph's computation, which is the opposite of the graph's traversal direction. The traversal treats the graph as a tree: the last node is the root, the first nodes are the leaves, and the tree is walked from the root toward the leaves with depth-first search, more precisely post-order DFS. Because of the post-order, rewrites are applied in computation order: for a chain conv2d → relu, conv2d's rewrite runs before relu's.

The framework already implements the graph traversal and automatically dispatches each op to its registered ForwardRewrite function, so we can focus on the per-op rewrites. It also provides several helpers for pass authors: TempExpr, the fcontext context, and fmulti_ref_trigger.
1 Registering a ForwardRewrite Pass
ForwardRewrite is the entry function to ForwardRewriter.
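Its signature, copied from the listing in "The definition of ForwardRewriter" below:

```cpp
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger);
```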
2 Registering an op's ForwardRewrite function
The first argument to set_attr (here, "FTestRewrite") must match the rewrite_map_name argument passed to ForwardRewrite when the pass is registered.
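For example, the simple example below registers conv2d's rewrite and invokes the pass like this:

```cpp
RELAY_REGISTER_OP("nn.conv2d")
    .set_attr<FForwardRewrite>("FTestRewrite", Conv2dTestRewrite);
// ... and inside the pass function:
Expr ret = ForwardRewrite(f, "FTestRewrite", nullptr, nullptr);
```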
Conv2dTestRewrite is a function of type FForwardRewrite.
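A sketch of FForwardRewrite, reconstructed from the rewrite-function signatures used throughout this note (see op_attr_types.h for the authoritative definition):

```cpp
// Sketch: the FForwardRewrite op attribute type (reconstructed, not verbatim).
using FForwardRewrite = runtime::TypedPackedFunc<
    Expr(const Call& ref_call,
         const Array<Expr>& new_args,
         const NodeRef& ctx)>;
```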
- ref_call is the original call that is about to be rewritten.
- new_args holds the already-rewritten arguments; each may be a TempExpr (or a custom subclass of it) carrying user-defined information.
- ctx carries context information about ref_call.
Finally, TestRealizeExprNode::make(ref_call) returns a TestRealizeExpr instance.
TestRealizeExpr is a subclass of TempExpr, and TestRealizeExprNode is a subclass of TempExprNode.
As shown below, TestRealizeExprNode::make() is simple, taking a single parameter, Expr expr. If TestRealizeExprNode carried more member variables (besides expr it may hold any number of extra members with other information), make() could take correspondingly more parameters (see the definition of TestRealizeExprNode).
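A sketch of what make() could look like with the extra member exposed as a parameter (hypothetical; the actual make() in the simple example hardcodes bTest = true):

```cpp
// Hypothetical variant: make() carrying extra information besides expr.
TestRealizeExpr TestRealizeExprNode::make(Expr expr, bool bTest) {
  NodePtr<TestRealizeExprNode> n = make_node<TestRealizeExprNode>();
  n->expr = std::move(expr);
  n->bTest = bTest;  // extra member, carried forward to the next op's rewrite
  return TestRealizeExpr(n);
}
```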
Once conv2d has been wrapped by make into a TestRealizeExpr, it carries two pieces of information, expr and bTest; when it arrives as relu's new_args[0], both can be read back (see the new_args[0].as<TestRealizeExprNode>() block in ReluTestRewrite in the simple example).
Definition of TestRealizeExpr(Node)
The definition of TestRealizeExpr(Node) can be found in the simple example below (class TestRealizeExprNode and the RELAY_DEFINE_NODE_REF line).
A TempExprNode **subclass can be used to pass information forward, but only to the next layer**.
Realize is a virtual function on TempExprNode; TestRealizeExprNode, as a subclass, must override it.
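The override from the simple example, for reference:

```cpp
Expr TestRealizeExprNode::Realize() const {
  Expr expr = this->expr;
  const CallNode* call_node = expr.as<CallNode>();
  const OpNode* call_op = call_node->op.as<OpNode>();
  LOG(INFO) << "Realize:" << call_op->name;
  return expr;
}
```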
How Realize gets called
Let's walk through how Realize is invoked, using as an example a simple graph in which relu's output feeds conv2d. Suppose only relu has a ForwardRewrite function registered, and it returns its result wrapped in a TestRealizeExpr via make.
ExprMutator traverses the graph recursively in post-order. When it visits conv2d, the framework looks up whether a ForwardRewrite function is registered for it; conv2d has none here, so frewrite is nullptr (see the frewrite == nullptr branch in VisitExpr_(const CallNode*) in the listing below). The for loop then walks the call's args; there is one arg, args[0], which is relu. GetTempExpr returns the rewritten relu (a TestRealizeExpr), and because frewrite is nullptr, new_arg is set to the result of realizer_.Realize(new_arg).
Following realizer_.Realize further, it checks whether the expression derives from TempExprNode before calling its Realize:
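The relevant check, from TempRealizer in the listing below:

```cpp
Expr res;
if (const auto* temp = expr.as<TempExprNode>()) {
  res = temp->Realize();  // a TempExpr gets its subclass Realize() called
} else {
  res = ExprFunctor::VisitExpr(expr);
}
```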
So whether a TempExprNode subclass's Realize function is called depends on two conditions:
- an op **with no registered ForwardRewrite function** is encountered, and
- the previous **op's ForwardRewrite function wrapped its result in a TempExpr** subclass.
When both hold, **the Realize function of the TempExprNode subclass produced by the previous op is called. The last op is the exception: as long as it has a ForwardRewrite function registered (and returns a TempExpr subclass), its Realize is called unconditionally.**
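The last-op exception comes from ForwardRewriter::VisitExpr, which always realizes the final result of the traversal (copied from the listing below):

```cpp
Expr VisitExpr(const Expr& expr) final {
  // by default always realize.
  return realizer_.Realize(ExprMutator::VisitExpr(expr));
}
```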
For example (assuming the ops run in computation order x1 → x2 → x3 → x4 → x5):
First case:

| No. | Op/Node | ForwardRewrite registered? | Realize() called? |
|---|---|---|---|
| 1 | x1 | yes | yes |
| 2 | x2 | no | no |
| 3 | x3 | yes | no |
| 4 | x4 | yes | no |
| 5 | x5 | yes | yes |

Second case:

| No. | Op/Node | ForwardRewrite registered? | Realize() called? |
|---|---|---|---|
| 1 | x1 | yes | yes |
| 2 | x2 | no | no |
| 3 | x3 | yes | yes |
| 4 | x4 | no | no |
| 5 | x5 | yes | yes |
The ctx parameter
The ctx parameter passes context information, supplied through the fcontext argument of ForwardRewrite when the pass is registered. A rewrite function such as relu's can then read that context through its ctx argument, as sketched below.
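A minimal sketch of supplying and consuming fcontext, assuming the "FTestRewrite" pass from the simple example; the lambda body and the ctx check are illustrative only:

```cpp
// Hypothetical rewrite that reads ctx (illustrative only).
Expr ReluCtxRewrite(const Call& ref_call, const Array<Expr>& new_args,
                    const NodeRef& ctx) {
  if (ctx.defined()) {
    // ctx is whatever fcontext returned for this particular ref_call.
    LOG(INFO) << "got context for this call";
  }
  return Expr(nullptr);  // an undefined Expr keeps the original call
}

Expr RunWithContext(const Expr& f) {
  // Hypothetical fcontext: build a per-call context object.
  auto fcontext = [](const Call& ref_call) -> NodeRef {
    return NodeRef(nullptr);  // e.g. look up precomputed info about ref_call
  };
  return ForwardRewrite(f, "FTestRewrite", fcontext, nullptr);
}
```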
3 Introducing fmulti_ref_trigger
If an Expr consumed by multiple callers needs special handling, supply an fmulti_ref_trigger function when registering the pass (see "1 Registering a ForwardRewrite Pass" above). In GetTempExpr in the listing below, fmulti_ref_trigger is invoked whenever an expression's reference count is greater than 1.
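A minimal sketch of supplying fmulti_ref_trigger (the trigger body is illustrative):

```cpp
Expr RunWithMultiRefTrigger(const Expr& f) {
  // Hypothetical trigger: invoked on the rewritten value of any expression
  // that has more than one consumer in the original graph.
  auto fmulti_ref = [](const Expr& e) -> Expr {
    return e;  // a real pass might force-realize e so all consumers share it
  };
  return ForwardRewrite(f, "FTestRewrite", nullptr, fmulti_ref);
}
```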
The definition of ForwardRewriter:

```cpp
/*!
*
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include "pass_util.h"
namespace tvm {
namespace relay {
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatedly won't hurt perf.
class TempRealizer : private ExprMutator {
public:
Expr Realize(Expr expr) {
return VisitExpr(expr);
}
private:
Expr VisitExpr(const Expr& expr) final {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Expr res;
if (const auto* temp = expr.as<TempExprNode>()) {
res = temp->Realize();
} else {
res = ExprFunctor::VisitExpr(expr);
}
memo_[res] = res;
return res;
}
}
};
class ForwardRewriter : private ExprMutator {
public:
ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
: rewrite_map_(rewrite_map),
fcontext_(fcontext),
fmulti_ref_trigger_(fmulti_ref_trigger) {}
ForwardRewriter(const FForwardRewrite* rewrite_func,
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
: rewrite_func_(rewrite_func),
fcontext_(fcontext),
fmulti_ref_trigger_(fmulti_ref_trigger) {}
// Transform expression.
Expr Rewrite(Expr expr) {
if (fmulti_ref_trigger_ != nullptr) {
ref_counter_ = GetExprRefCount(expr);
}
return this->VisitExpr(expr);
}
private:
// The rewrite rule.
const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
const FForwardRewrite* rewrite_func_{nullptr};
// The context.
std::function<NodeRef(const Call&)> fcontext_{nullptr};
// The multiple reference trigger
std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
// Internal ref counter
std::unordered_map<const Node*, size_t> ref_counter_;
// internal realizer
TempRealizer realizer_;
Expr VisitExpr(const Expr& expr) final {
// by default always realize.
return realizer_.Realize(ExprMutator::VisitExpr(expr));
}
// Visit and allow non-realized version.
Expr GetTempExpr(const Expr& expr) {
if (fmulti_ref_trigger_ != nullptr) {
Expr ret = ExprMutator::VisitExpr(expr);
auto it = ref_counter_.find(expr.get());
CHECK(it != ref_counter_.end());
if (it->second > 1) {
ret = fmulti_ref_trigger_(ret);
}
return ret;
} else {
return ExprMutator::VisitExpr(expr);
}
}
// Automatic fold TupleGetItem.
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr tuple = this->GetTempExpr(op->tuple);
if (const auto* ptuple = tuple.as<TupleNode>()) {
return ptuple->fields[op->index];
} else {
if (tuple.same_as(op->tuple)) {
return GetRef<Expr>(op);
} else {
return TupleGetItemNode::make(tuple, op->index);
}
}
}
Expr VisitExpr_(const TupleNode* op) final {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (auto field : op->fields) {
auto new_field = this->GetTempExpr(field);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field);
}
if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return TupleNode::make(fields);
}
}
Expr VisitExpr_(const CallNode* call_node) final {
const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite;
if (rewrite_func_) {
frewrite = *rewrite_func_;
} else {
CHECK(rewrite_map_);
frewrite = rewrite_map_->get(call_node->op, nullptr);
}
auto new_op = this->Mutate(call_node->op);
bool unchanged = call_node->op.same_as(new_op);
Array<Expr> call_args;
for (auto arg : call_node->args) {
Expr new_arg = this->GetTempExpr(arg);
if (frewrite == nullptr) {
// frewrite is nullptr only when the current op has no registered
// FForwardRewrite function; only then is realizer_.Realize(new_arg) called.
new_arg = realizer_.Realize(new_arg);
}
unchanged &= new_arg.same_as(arg);
call_args.push_back(new_arg);
}
// try to rewrite.
if (frewrite != nullptr) {
Expr res = frewrite(
ref_call, call_args,
fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
if (res.defined()) return res;
// abort, use old rule
for (size_t i = 0; i < call_args.size(); ++i) {
Expr arg = call_args[i];
Expr new_arg = realizer_.Realize(arg);
if (!arg.same_as(new_arg)) {
call_args.Set(i, new_arg);
unchanged = false;
}
}
}
if (unchanged) return ref_call;
return CallNode::make(
new_op, call_args, call_node->attrs, call_node->type_args);
}
};
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_name,
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}
Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
}
} // namespace relay
} // namespace tvm
```
A simple example

The pass below registers "FTestRewrite" rewrite functions for nn.conv2d, nn.relu, multiply, and add, and exposes them as the TestRewrite function pass:

```cpp
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/transform.h>
#include "../../qnn/util.h"
#include "../pattern_util.h"
#include "util.h"
namespace tvm {
namespace relay {
namespace testrewrite {
class TestRealizeExpr;
class TestRealizeExprNode : public TempExprNode {
public:
Expr expr;
bool bTest;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("expr", &expr);
v->Visit("bTest", &bTest);
}
Expr Realize() const final;
TVM_DLL static TestRealizeExpr make(Expr expr);
static constexpr const char* _type_key = "relay.transform.TestRealizeExpr";
TVM_DECLARE_NODE_TYPE_INFO(TestRealizeExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(TestRealizeExpr, TestRealizeExprNode, TempExpr);
Expr TestRealizeExprNode::Realize() const {
Expr expr = this->expr;
const CallNode* call_node = expr.as<CallNode>();
const OpNode* call_op = call_node->op.as<OpNode>();
LOG(INFO) << "Realize:" << call_op->name;
return expr;
}
TestRealizeExpr TestRealizeExprNode::make(Expr expr) {
NodePtr<TestRealizeExprNode> n = make_node<TestRealizeExprNode>();
n->expr = std::move(expr);
n->bTest = true;
return TestRealizeExpr(n);
}
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
return CallNode::make(ref_call->op, args, ref_call->attrs, ref_call->type_args);
}
/* \brief forward the original operator */
Expr IdentityTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
const CallNode* call_node = ref_call.as<CallNode>();
const OpNode* call_op = call_node->op.as<OpNode>();
LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
return Expr(nullptr);
}
Expr Conv2dTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
const CallNode* call_node = ref_call.as<CallNode>();
const OpNode* call_op = call_node->op.as<OpNode>();
LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
return TestRealizeExprNode::make(ref_call);
}
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FTestRewrite", Conv2dTestRewrite);
Expr ReluTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
const CallNode* call_node = ref_call.as<CallNode>();
const OpNode* call_op = call_node->op.as<OpNode>();
LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
if (const auto* n = new_args[0].as<TestRealizeExprNode>()) {
// The expr and bTest attached by conv2d's rewrite are readable here.
LOG(INFO) << "new_args[0] carries bTest=" << n->bTest;
}
return TestRealizeExprNode::make(ref_call);
}
RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FTestRewrite", ReluTestRewrite);
Expr MultiplyTestRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
const CallNode* call_node = ref_call.as<CallNode>();
const OpNode* call_op = call_node->op.as<OpNode>();
LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
return TestRealizeExprNode::make(ref_call);
}
RELAY_REGISTER_OP("multiply").set_attr<FForwardRewrite>("FTestRewrite", MultiplyTestRewrite);
Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
const CallNode* call_node = ref_call.as<CallNode>();
const OpNode* call_op = call_node->op.as<OpNode>();
LOG(INFO) << call_op->name << " has " << new_args.size() << " arg(s)";
CHECK_EQ(ref_call->type_args.size(), 2);
if (IsElewiseShape(ref_call->type_args[0], ref_call->type_args[1])) {
LOG(INFO) << "same shape";
}
const auto* n = new_args[0].as<TestRealizeExprNode>();
if (!n) {
// Expr ret = ForwardOp(ref_call, {n->data});
return TestRealizeExprNode::make(ref_call);
}
CHECK(!new_args[0]->is_type<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FTestRewrite", AddRealize);
} // namespace testrewrite
namespace transform {
Pass TestRewritePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ForwardRewrite(f, "FTestRewrite", nullptr, nullptr));
};
return CreateFunctionPass(pass_func, 1, "TestRewrite", {});
}
TVM_REGISTER_API("relay._transform.TestRewrite").set_body_typed(TestRewritePass);
} // namespace transform
} // namespace relay
} // namespace tvm
```
A more complex example

```cpp
/*!
* \file convert_strides.cc
* \brief Convert existing 2x2-strides conv2d to
* 1x1 strides conv2d + vacc_dropout.
*/
/*
* The pass does three things:
* 1. Add an operator vacc_dropout to implement dropout with strides=2
* 2. 1x1 kernel size and 2x2 strides conv2d is converted to 1x1 kernel size and 1x1 strides conv2d
* 3. Move vacc_dropout(which is a PEP op) to the end of SEP unless elementwise add or multiply is encountered.
*/
#include <tvm/relay/transform.h>
#include <tvm/relay/attrs/nn.h>
#include "../util.h"
namespace tvm {
namespace relay {
inline Array<Integer> ConvertToConstants(const Array<IndexExpr>& arr) {
Array<Integer> convert_result;
for (size_t i = 0; i < arr.size(); ++i) {
const IntImm* const_elem = arr[i].as<IntImm>();
CHECK(const_elem);
convert_result.push_back(const_elem->value);
}
return std::move(convert_result);
}
bool VaccDropoutRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
auto dshape = data->shape;
auto num_axis = dshape.size();
std::vector<int64_t> stride_vec;
for (size_t i = 0; i < num_axis - 2; ++i) {
stride_vec.push_back(1);
}
stride_vec.push_back(2);
stride_vec.push_back(2);
std::vector<IndexExpr> oshape(dshape.size());
for (size_t i = 0; i < num_axis; ++i) {
int64_t stride_v = stride_vec[i];
const int64_t* p_dim_size = as_const_int(dshape[i]);
CHECK(p_dim_size)
<< "vacc_dropout requires dimension to be concrete int";
int64_t dim_size = p_dim_size[0];
oshape[i] = make_const(dshape[i].dtype(), (dim_size + stride_v - 1) / stride_v);
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Expr MakeVaccDropout(Expr data) {
static const Op& op = Op::Get("nn.vacc_dropout");
return CallNode::make(op, {data}, Attrs{}, {});
}
TVM_REGISTER_API("relay.op._make.vacc_dropout")
.set_body_typed(MakeVaccDropout);
RELAY_REGISTER_OP("nn.vacc_dropout")
.describe(R"code(Applies the dropout with H/W stride = 2 to the input array.
Examples::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
vacc_dropout(x) = [[1., 7.],
[3., 9.]]
x = [[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]],
[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]],
[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]]]
strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 3.],
[ 7., 9.]],
[[ 1., 3.],
[ 7., 9.]],
[[ 1., 3.],
[ 7., 9.]]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("VaccDropout", VaccDropoutRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
class VaccDropoutExpr;
class VaccDropoutExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("expr", &expr);
}
TVM_DLL static VaccDropoutExpr make(Expr expr);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.VaccDropoutExpr";
TVM_DECLARE_NODE_TYPE_INFO(VaccDropoutExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(VaccDropoutExpr, VaccDropoutExprNode, TempExpr);
Expr VaccDropoutExprNode::Realize() const {
// insert vacc_dropout
return MakeVaccDropout(this->expr);
}
VaccDropoutExpr VaccDropoutExprNode::make(Expr expr) {
auto rnode = make_node<VaccDropoutExprNode>();
rnode->expr = expr;
return VaccDropoutExpr(rnode);
}
inline Expr RebuildConv2D(const Call& ref_call, const Array<Expr>& args) {
const auto* attrs = ref_call->attrs.as<Conv2DAttrs>();
const auto new_attrs = make_node<Conv2DAttrs>();
*new_attrs = *attrs;
Array<IndexExpr> new_strides;
for (size_t i = 0; i < attrs->strides.size(); ++i) {
new_strides.push_back(1);
}
new_attrs->strides = std::move(new_strides);
return CallNode::make(Op::Get("nn.conv2d"), args, Attrs(new_attrs), {});
}
Expr VaccDropoutConv2dRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
const Conv2DAttrs* attrs = ref_call->attrs.as<Conv2DAttrs>();
auto const_strides = ConvertToConstants(attrs->strides);
auto const_kernel = ConvertToConstants(attrs->kernel_size);
CHECK(const_kernel.size() == 2 && const_strides.size() == 2);
// only convert 1x1 kernel size and 2x2 strides
if ((const_kernel[0]->value == 1 && const_kernel[1]->value == 1) &&
(const_strides[0]->value == 2 && const_strides[1]->value == 2)) {
Expr new_arg0 = new_args[0];
if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
// A pending VaccDropoutExprNode is realized only when the next conv2d
// is reached; that is where the dropout gets inserted.
new_arg0 = n->Realize();
}
Expr newConv2d = RebuildConv2D(ref_call, {new_arg0, new_args[1]});
return VaccDropoutExprNode::make(newConv2d);
}
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutConv2dRewrite);
Expr VaccDropoutSepBeginRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
CHECK(false) << "VaccConvertStrides must be before VaccSepGrouping";
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.vacc_sep_begin")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutSepBeginRewrite);
Expr MakeVaccDropoutArgsExpr(const Call& ref_call, const Array<Expr>& new_args) {
if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
Expr new_expr = ForwardOp(ref_call, {n->expr, new_args[1]});
return VaccDropoutExprNode::make(new_expr);
}
return Expr(nullptr);
}
// Branches arriving here stop propagating and Realize() is called.
// Ops of this kind: multiply, add
Expr VaccDropoutElemwiseRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
if (IsElemwiseShape(ref_call->args[0], ref_call->args[1])) {
Expr new_arg0 = new_args[0];
if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
// the upstream left branch stops propagating here
new_arg0 = n->Realize();
}
Expr new_arg1 = new_args[1];
if (const auto* n = new_args[1].as<VaccDropoutExprNode>()) {
// the upstream right branch stops propagating here
new_arg1 = n->Realize();
}
return ForwardOp(ref_call, {new_arg0, new_arg1});
}
return MakeVaccDropoutArgsExpr(ref_call, new_args);
}
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutElemwiseRewrite);
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropoutElemwiseRewrite);
// Branches arriving here keep propagating forward
Expr VaccDropout2ArgsRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
return MakeVaccDropoutArgsExpr(ref_call, new_args);
}
RELAY_REGISTER_OP("nn.bias_add")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
RELAY_REGISTER_OP("affine")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
RELAY_REGISTER_OP("nn.prelu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
RELAY_REGISTER_OP("vacc_activation")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout2ArgsRewrite);
// Branches arriving here keep propagating forward
Expr VaccDropout1ArgRewrite(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<VaccDropoutExprNode>()) {
Expr new_expr = ForwardOp(ref_call, {n->expr});
return VaccDropoutExprNode::make(new_expr);
}
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout1ArgRewrite);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardRewrite>("FVaccConvertStrides", VaccDropout1ArgRewrite);
Expr ConvertStrides(const Expr& expr) {
return ForwardRewrite(expr, "FVaccConvertStrides", nullptr);
}
namespace transform {
Pass VaccConvertStrides() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ConvertStrides(f));
};
return CreateFunctionPass(pass_func, 1, "VaccConvertStrides", {});
}
TVM_REGISTER_API("relay._transform.VaccConvertStrides")
.set_body_typed(VaccConvertStrides);
} // namespace transform
} // namespace relay
} // namespace tvm
```
Tests

```python
from tvm import relay
from tvm.relay import transform, analysis
from random import randint
from tvm.relay.testing import run_opt_pass, run_infer_type
def run_convert_strides(expr):
return run_opt_pass(expr, transform.VaccConvertStrides())
def test_conv2d_1x1_s2():
n, c, h, w = 4, 3, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w1 = relay.var("w1")
w2 = relay.var("w2")
def before():
y = relay.nn.conv2d(x, w1, strides=(2,2), kernel_size=(1, 1), channels=8)
y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
y = relay.nn.conv2d(y, w2, strides=(2,2), kernel_size=(1,1), channels=4)
y = relay.nn.relu(y)
return relay.Function([x, w1, w2], y)
def expect():
y = relay.nn.conv2d(x, w1, strides=(1,1), kernel_size=(1, 1), channels=8)
y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
y = relay.vacc_dropout(y)
y = relay.nn.conv2d(y, w2, strides=(1,1), kernel_size=(1,1), channels=4)
y = relay.nn.relu(y)
y = relay.vacc_dropout(y)
return relay.Function([x, w1, w2], y)
converted = run_convert_strides(before())
expected = run_infer_type(expect())
assert analysis.alpha_equal(converted, expected)
def test_conv2d_1x1_s2_elemwise():
n, c, h, w = 4, 3, 224, 224
x1 = relay.var("x1", relay.ty.TensorType((n, c, h, w), "float32"))
x2 = relay.var("x2", relay.ty.TensorType((n, c, 112, 112), "float32"))
w1 = relay.var("w1")
w2 = relay.var("w2")
w3 = relay.var("w3")
def before():
y1 = relay.nn.conv2d(x1, w1, strides=(2,2), kernel_size=(1, 1), channels=8)
y1 = relay.vacc_activation('sigmoid', y1, 1.0, 2.0)
y2 = relay.nn.conv2d(x2, w2, strides=(1,1), kernel_size=(1, 1), channels=8)
y2 = relay.add(y2, relay.const(0.5))
y = relay.add(y1, y2)
y = relay.nn.conv2d(y, w3, strides=(2,2), kernel_size=(1,1), channels=4)
y = relay.nn.relu(y)
return relay.Function([x1, x2, w1, w2, w3], y)
def expect():
y1 = relay.nn.conv2d(x1, w1, strides=(1,1), kernel_size=(1, 1), channels=8)
y1 = relay.vacc_activation('sigmoid', y1, 1.0, 2.0)
y1 = relay.vacc_dropout(y1)
y2 = relay.nn.conv2d(x2, w2, strides=(1,1), kernel_size=(1, 1), channels=8)
y2 = relay.add(y2, relay.const(0.5))
y = relay.add(y1, y2)
y = relay.nn.conv2d(y, w3, strides=(1,1), kernel_size=(1,1), channels=4)
y = relay.nn.relu(y)
y = relay.vacc_dropout(y)
return relay.Function([x1, x2, w1, w2, w3], y)
converted = run_convert_strides(before())
expected = run_infer_type(expect())
assert analysis.alpha_equal(converted, expected)
def test_conv2d_non_1x1_s2():
n, c, h, w = 4, 3, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w1 = relay.var("w1")
w2 = relay.var("w2")
def before():
y = relay.nn.conv2d(x, w1, strides=(2,2), kernel_size=(3, 3), channels=8)
y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
y = relay.nn.conv2d(y, w2, strides=(2,2), kernel_size=(5,5), channels=4)
y = relay.nn.relu(y)
return relay.Function([x, w1, w2], y)
def expect():
y = relay.nn.conv2d(x, w1, strides=(2,2), kernel_size=(3, 3), channels=8)
y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
y = relay.nn.conv2d(y, w2, strides=(2,2), kernel_size=(5,5), channels=4)
y = relay.nn.relu(y)
return relay.Function([x, w1, w2], y)
converted = run_convert_strides(before())
expected = run_infer_type(expect())
assert analysis.alpha_equal(converted, expected)
def test_conv2d_1x1_non_s2():
n, c, h, w = 4, 3, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w1 = relay.var("w1")
w2 = relay.var("w2")
def before():
y = relay.nn.conv2d(x, w1, strides=(1,1), kernel_size=(1, 1), channels=8)
y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
y = relay.nn.conv2d(y, w2, strides=(1,1), kernel_size=(1,1), channels=4)
y = relay.nn.relu(y)
return relay.Function([x, w1, w2], y)
def expect():
y = relay.nn.conv2d(x, w1, strides=(1,1), kernel_size=(1, 1), channels=8)
y = relay.vacc_activation('sigmoid', y, 1.0, 2.0)
y = relay.nn.conv2d(y, w2, strides=(1,1), kernel_size=(1,1), channels=4)
y = relay.nn.relu(y)
return relay.Function([x, w1, w2], y)
converted = run_convert_strides(before())
expected = run_infer_type(expect())
assert analysis.alpha_equal(converted, expected)
def verify(incoming_func):
inferred_func = run_infer_type(incoming_func)
print("inferred_func ---\n", inferred_func)
converted_func = run_convert_strides(inferred_func)
print("converted_func ---\n", converted_func)
assert inferred_func.body.checked_type == converted_func.body.checked_type
def random_strided_conv2d_test():
n, c, h, w = randint(1, 256), randint(1, 256), randint(1, 256), randint(1, 256)
k = randint(1, 5)
kernel_size = (k, k)
p = randint(1, 3)
padding = (p, p)
channels = randint(1, 16)
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = relay.var("w")
y = relay.nn.conv2d(x, w, strides=(2,2), kernel_size=kernel_size,
padding=padding, channels=channels)
y = relay.nn.relu(y)
func = relay.Function([x, w], y)
verify(func)
def test_strided_slice():
for _ in range(10):
random_strided_conv2d_test()
def test_identity():
n, c, h, w = randint(1, 256), randint(1, 256), randint(1, 256), randint(1, 256)
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = relay.var("w")
y = relay.nn.conv2d(x, w, strides=(1, 1), kernel_size=(1, 1), padding=(0, 0), channels=8)
func = relay.Function([x, w], y)
inferred_func = run_infer_type(func)
converted_func = run_convert_strides(inferred_func)
assert analysis.alpha_equal(inferred_func, converted_func) and analysis.structural_hash(inferred_func) == analysis.structural_hash(converted_func)
if __name__ == "__main__":
test_conv2d_1x1_s2()
test_conv2d_1x1_s2_elemwise()
test_conv2d_non_1x1_s2()
test_conv2d_1x1_non_s2()
test_identity()
```