在relay中添加自定义Pass: VaccReorderConcatRelu
在relay中添加自定义Pass: VaccReorderConcatRelu
1 声明pass VaccReorderConcatRelu
在include/tvm/relay/transform.h中添加声明:
/*! \brief reorder “concat + relu” to “relu + concat”
* VaccReorderConcatRelu
* \return the pass.
*/
TVM_DLL Pass VaccReorderConcatRelu();
2 定义pass VaccReorderConcatRelu
在tvm的src目录下添加对应的源文件,例如: src/relay/pass/vacc/reorder_concat_relu.cc
* \file reorder_concat_relu.cc
*
* This pass will transform “concat + relu” to “relu + concat”.
* For example:
* xxx xxx
* \ /
* concat
* |
* relu
* / \
* xxx xxx
*
* Would become:
*
* xxx xxx
* | |
* relu relu
* \ /
* concat
* / \
* xxx xxx
*
*/
#include
#include
#include
#include
#include
#include
include “../pass_util.h”
namespace tvm {
namespace relay {
class ConcatReluReorder : public ExprMutator {
public:
Expr Reorder(const Expr& expr) {
refcounter = GetExprRefCount(expr);
return this->Mutate(expr);
}
Expr VisitExpr(const CallNode* _call) final {
static const Op& reluop = Op::Get(“nn.relu”);
static const Op& concat_op = Op::Get(“concatenate”);
Expr res = ExprMutator::VisitExpr(call); // res is a Call,调用__ExprMutator::VisitExpr(call)可向上层递归调用直到到达最上层才返回
const CallNode* res_call = res.as
CHECK(res_call != nullptr);
.
//注意:此处必须使用call, 而不能使用res_call,具体原因见IsReluAfterConcat函数
if (IsReluAfterConcat(call)) {
const CallNode prev_node = res_call->args[0].as
const TupleNode
if (concat_nodes != nullptr) {
Array
for (unsigned int i = 0; i < concat_nodes->fields.size(); i++) {
const CallNode* args_node = concat_nodes->fields[i].as
CHECK(args_node);
Expr new_relu_node =
CallNode::make(relu_op, {concat_nodes->fields[i]}, res_call->attrs,
res_call->type_args);
fields.push_back(new_relu_node);
}
// make a new CallNode for concatenate op
Expr new_concat_node =
CallNode::make(concat_op, {TupleNode::make(fields)},
prev_node->attrs, prev_node->type_args);
return new_concat_node;// relu -> concatenate_
}
}
return res;
}
private:
// Is relu after concatenate op ?
/*
* \ /
* concat
* |
* relu
*/
// 注意:此处的call参数必须是原始图中的Node, 不能是更新后的图中的Node
bool IsReluAfterConcat(const CallNode* call) {
static const Op& relu_op = Op::Get(“nn.relu”);
static const Op& concat_op = Op::Get(“concatenate”);
if (call == nullptr) {<br /> return false;<br /> }<br /> if (!call->op.same_as(relu_op)) {<br /> return false;_ // not relu op_<br /> }
// get pre node
auto prevnode = call->args[0].as
if (prev_node == nullptr) {
return false;
}
if (prev_node->op.same_as(concat_op)) {
// 此prev_node必须是原始图中的Node,所以要求call也必须是原始图中的Node,而不能是更新后的图中的Node
// 否则, 在ref_counter就肯定找不到prevnode
auto it = ref_counter.find(prevnode);
if (it != ref_counter.end() && it->second == 1) {
return true; // concat op , and ref count is 1
}
}
return false;
}
// refcounter保存的是原始图中的各Node的引用次数,不包括更新后的图中的Node
std::unorderedmap
};
Expr VaccReorderConcatRelu(const Expr& expr) {
return ConcatReluReorder().Reorder(expr);
}
namespace transform {
Pass VaccReorderConcatRelu() {
runtime::TypedPackedFunc
[=](Function _f, Module m, PassContext pc) {
return Downcast
};
return CreateFunctionPass(pass_func, 1, “VaccReorderConcatRelu”, {});
}
TVM_REGISTER_API(“relay._transform.VaccReorderConcatRelu”)
.set_body_typed(VaccReorderConcatRelu);
} // namespace transform
} // namespace relay
} // namespace tvm
3 添加单元测试,测试pass VaccReorderConcatRelu (可选)
一般在tests/python目录下添加对应的单元测试文件,推荐以testvacc_pass开头,例如:
tests/python/relay/test_vacc_pass_reorder_concat_relu.py
4 调用pass VaccReorderConcatRelu
一般在src/relay/backend/build_module.cc的RelayBuildModule::
2 定义pass VaccReorderConcatRelu
在tvm的src目录下添加对应的源文件,例如: src/relay/pass/vacc/reorder_concat_relu.cc
2