在relay中添加自定义Pass: VaccReorderConcatRelu


在relay中添加自定义Pass:以VaccReorderConcatRelu为例
1 声明 pass VaccReorderConcatRelu

在 include/tvm/relay/transform.h 中（namespace transform 内，与第 2 步中的定义保持一致）添加声明:
/! \brief reorder “concat + relu” to “relu + concat” VaccReorderConcatRelu \return the pass. /TVMDLL Pass VaccReorderConcatRelu();
2 定义 pass VaccReorderConcatRelu

在 TVM 的 src 目录下添加对应的源文件，例如：src/relay/pass/vacc/reorder_concat_relu.cc
/*!
 * \file reorder_concat_relu.cc
 * \brief This pass transforms "concat + relu" into "relu + concat".
 *
 * For example:
 *
 *     a     b                  a       b
 *      \   /                   |       |
 *     concat        ==>      relu    relu
 *        |                      \    /
 *      relu                     concat
 *        |                        |
 *       ...                      ...
 */
// NOTE(review): the original include list was lost; these are the headers
// providing the names this file uses (TVM 0.6-era layout) -- adjust to your TVM version.
#include <tvm/api_registry.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <unordered_map>

#include "../pass_util.h"
namespace tvm {namespace relay {
class ConcatReluReorder : public ExprMutator { public: Expr Reorder(const Expr& expr) { ref_counter
= GetExprRefCount(expr); return this->Mutate(expr); }
Expr VisitExpr(const CallNode* call) final { static const Op& relu_op = Op::Get(“nn.relu”); static const Op& concat_op = Op::Get(“concatenate”); Expr res = ExprMutator::VisitExpr(call); // res is a Call,调用ExprMutator::VisitExpr(call)可向上层递归调用直到到达最上层才返回 const CallNode res_call = res.as(); // Call.as return CallNode CHECK(res_call != nullptr);.//注意:此处必须使用call, 而不能使用res_call,具体原因见IsReluAfterConcat函数 if (IsReluAfterConcat(call)) { const CallNode prev_node = res_call->args[0].as(); const TupleNode concat_nodes = prev_node->args[0].as(); if (concat_nodes != nullptr) { Array fields; for (unsigned int i = 0; i < concat_nodes->fields.size(); i++) { const CallNode args_node = concat_nodes->fields[i].as(); CHECK(args_node); // make a new CallNode for relu op Expr new_relu_node = CallNode::make(relu_op, {concat_nodes->fields[i]}, res_call->attrs, res_call->type_args); fields.push_back(new_relu_node); } // make a new CallNode for concatenate op Expr new_concat_node = CallNode::make(concat_op, {TupleNode::make(fields)}, prev_node->attrs, prev_node->type_args); return new_concat_node;// relu -> concatenate } } return res; }
private: // Is relu after concatenate op ? / \ / concat | relu /// 注意:此处的call参数必须是原始图中的Node, 不能是更新后的图中的Node bool IsReluAfterConcat(const CallNode* call) { static const Op& relu_op = Op::Get(“nn.relu”); static const Op& concat_op = Op::Get(“concatenate”);
if (call == nullptr) { return false; } if (!call->op.same_as(relu_op)) { return false; // not relu op }
// get pre node auto prev_node = call->args[0].as(); if (prev_node == nullptr) { return false; } // target op if (prev_node->op.same_as(concat_op)) {// 此prev_node必须是原始图中的Node,所以要求call也必须是原始图中的Node,而不能是更新后的图中的Node// 否则, 在ref_counter
就肯定找不到prevnode auto it = ref_counter.find(prevnode); if (it != ref_counter.end() && it->second == 1) { return true; // concat op , and ref count is 1 } } return false; }
//refcounter保存的是原始图中的各Node的引用次数,不包括更新后的图中的Node std::unorderedmap ref_counter;};
/*!
 * \brief Rewrite every eligible "concat -> relu" pair in \p expr into
 *        "relu -> concat".
 * \param expr The expression to transform.
 * \return The transformed expression.
 */
Expr VaccReorderConcatRelu(const Expr& expr) {
  ConcatReluReorder reorderer;
  return reorderer.Reorder(expr);
}
namespace transform {

/*!
 * \brief Function-level pass that applies the concat/relu reordering to each
 *        function of the module.
 * \return The pass (opt_level 1, no required passes).
 */
Pass VaccReorderConcatRelu() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        // Qualified explicitly: the zero-argument transform::VaccReorderConcatRelu
        // would otherwise hide relay::VaccReorderConcatRelu(const Expr&).
        return Downcast<Function>(relay::VaccReorderConcatRelu(f));
      };
  return CreateFunctionPass(pass_func, 1, "VaccReorderConcatRelu", {});
}

// Exposed to Python as relay._transform.VaccReorderConcatRelu.
TVM_REGISTER_API("relay._transform.VaccReorderConcatRelu")
.set_body_typed(VaccReorderConcatRelu);

}  // namespace transform
}  // namespace relay
}  // namespace tvm

3 将pass VaccReorderConcatRelu添加到Build
一般在 src/relay/backend/build_module.cc 的 RelayBuildModule::Optimize(…) 中，将 pass 加入 pass_seqs，例如：

    pass_seqs.push_back(transform::VaccReorderConcatRelu());        // Kevin's pass
    pass_seqs.push_back(transform::VaccAlter2RepeatsToUpsample());  // Kevin's pass（另一个自定义 pass，与本例无关，仅作示意）
4 在 Python 端暴露 pass VaccReorderConcatRelu

在 python/tvm/relay/transform.py 中添加:
def VaccReorderConcatRelu():
    """Reorder concat+relu to relu+concat in a Relay program.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for VaccReorderConcatRelu.
    """
    return _transform.VaccReorderConcatRelu()
5 添加单元测试，测试 pass VaccReorderConcatRelu（可选）

一般在 tests/python 目录下添加对应的单元测试文件，推荐以 test_vacc_pass_ 开头，例如：tests/python/relay/test_vacc_pass_reorder_concat_relu.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

def test_only_one_relu_after_concatenate():
    """A relu that is the sole user of a concatenate is pushed into each input."""
    shape = (1, 3, 24, 24)
    dtype = "float32"

    def before():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        # A list (not a set literal) keeps the concatenation order deterministic.
        y = relay.concatenate([y0, y1], axis=0)
        z = relay.nn.relu(y)
        return relay.Function([x0, x1], z)

    def expected():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        y3 = relay.nn.relu(y0)
        y4 = relay.nn.relu(y1)
        z = relay.concatenate([y3, y4], axis=0)
        return relay.Function([x0, x1], z)

    r1 = run_opt_pass(before(), transform.VaccReorderConcatRelu())
    r_expected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(r1, r_expected)

def test_after_reorder_one_branch_after_concatenate():
    """The rewritten concat may itself feed several users (relu output used twice)."""
    shape = (1, 3, 24, 24)
    dtype = "float32"

    def before():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        y = relay.concatenate([y0, y1], axis=0)
        y2 = relay.nn.relu(y)
        z = relay.add(y2, y2)
        return relay.Function([x0, x1], z)

    def expected():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        y2 = relay.nn.relu(y0)
        y3 = relay.nn.relu(y1)
        y = relay.concatenate([y2, y3], axis=0)
        z = relay.add(y, y)
        return relay.Function([x0, x1], z)

    r1 = run_opt_pass(before(), transform.VaccReorderConcatRelu())
    r_expected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(r1, r_expected)

def test_two_relu_after_concatenate():
    """A concatenate consumed by two relus is referenced twice and is left unchanged."""
    shape = (1, 3, 24, 24)
    dtype = "float32"

    def before():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        y = relay.concatenate([y0, y1], axis=0)
        y2 = relay.nn.relu(y)
        y3 = relay.nn.relu(y)
        z = relay.add(y2, y3)
        return relay.Function([x0, x1], z)

    r1 = run_opt_pass(before(), transform.VaccReorderConcatRelu())
    r_expected = run_opt_pass(before(), transform.InferType())
    assert relay.analysis.alpha_equal(r1, r_expected)

def test_relu_and_add_after_concatenate():
    """A concatenate with another (non-relu) user is left unchanged (relu first)."""
    shape = (1, 3, 24, 24)
    dtype = "float32"

    def before():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        y = relay.concatenate([y0, y1], axis=0)
        y2 = relay.nn.relu(y)
        y3 = relay.add(y, y)
        z = relay.add(y2, y3)
        return relay.Function([x0, x1], z)

    r1 = run_opt_pass(before(), transform.VaccReorderConcatRelu())
    r_expected = run_opt_pass(before(), transform.InferType())
    assert relay.analysis.alpha_equal(r1, r_expected)

def test_add_and_relu_after_concatenate():
    """A concatenate with another (non-relu) user is left unchanged (add first)."""
    shape = (1, 3, 24, 24)
    dtype = "float32"

    def before():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.add(x0, x0)
        y1 = relay.add(x1, x1)
        y = relay.concatenate([y0, y1], axis=0)
        y2 = relay.add(y, y)
        y3 = relay.nn.relu(y)
        z = relay.add(y2, y3)
        return relay.Function([x0, x1], z)

    r1 = run_opt_pass(before(), transform.VaccReorderConcatRelu())
    r_expected = run_opt_pass(before(), transform.InferType())
    assert relay.analysis.alpha_equal(r1, r_expected)

def test_2_continuous_concatenates():
    """Two chained concat+relu pairs are both rewritten.

    The first relu's output feeds two branches; only the concatenate's own
    reference count decides whether the rewrite applies.
    """
    shape = (1, 3, 24, 24)
    dtype = "float32"

    def before():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.abs(x0)
        y1 = relay.abs(x1)
        y2 = relay.concatenate([y0, y1], axis=0)
        y3 = relay.nn.relu(y2)
        y4 = relay.abs(y3)
        y5 = relay.ceil(y4)
        y6 = relay.abs(y3)
        y7 = relay.ceil(y6)
        y8 = relay.concatenate([y5, y7], axis=0)
        y9 = relay.nn.relu(y8)
        z = relay.abs(y9)
        return relay.Function([x0, x1], z)

    def expected():
        x0 = relay.var("x0", relay.TensorType(shape, dtype))
        x1 = relay.var("x1", relay.TensorType(shape, dtype))
        y0 = relay.abs(x0)
        y1 = relay.nn.relu(y0)
        y2 = relay.abs(x1)
        y3 = relay.nn.relu(y2)
        y4 = relay.concatenate([y1, y3], axis=0)
        y5 = relay.abs(y4)
        y6 = relay.ceil(y5)
        y7 = relay.nn.relu(y6)
        y8 = relay.abs(y4)
        y9 = relay.ceil(y8)
        y10 = relay.nn.relu(y9)
        y11 = relay.concatenate([y7, y10], axis=0)
        z = relay.abs(y11)
        return relay.Function([x0, x1], z)

    r1 = run_opt_pass(before(), transform.VaccReorderConcatRelu())
    r_expected = run_opt_pass(expected(), transform.InferType())
    assert relay.analysis.alpha_equal(r1, r_expected)

if __name__ == "__main__":
    # Allow running the file directly without pytest.
    test_only_one_relu_after_concatenate()
    test_after_reorder_one_branch_after_concatenate()
    test_two_relu_after_concatenate()
    test_relu_and_add_after_concatenate()
    test_add_and_relu_after_concatenate()
    test_2_continuous_concatenates()