[翻译]在Relay中添加编译器Pass— tvm 0.6.0文档

在Relay中添加编译器Pass

编译器pass是扩展Relay函数集和对Relay程序执行优化的主要接口。通过编写编译器pass,可以根据您的目标来修改AST或收集有关AST的信息。确实,Relay的一些最重要的内置函数(例如,autodiff和类型推断)仅是“标准”的编译器pass。总体而言,编写pass有两个关键部分:

  • 创建一个或多个遍历程序的c++类
  • 在pass管理器API中包装遍历实现及其元数据,以便它可以与Relay Pass Infrastructure巧妙地对接 首先,我们将概述编写编译器pass的关键机制。然后,我们将演示Relay中的一个具体的例子:constant-folding pass。## AST 遍历器(Traversers) 用于遍历Relay程序的基类是ExprFunctor(include/tvm/relay/exprfunctor.h)。它提供的公共接口是VisitExpr函数,该函数接受一个表达式和零个或多个参数,并返回某种类型的实例。virtual R VisitExpr(const Expr& n, Args… args) { CHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)…); }扩展此类时,可以通过覆盖`VisitExpr每种表达式的实现来定义AST遍历模式。例如:Expr VisitExpr_(const CallNode* call) final {...}VisitExpr和VisitExpr_之间的关系与调度有关。每个VisitExpr定义都针对特定的表达式类型,但是你并不知道所有要访问的node类型。为了解决这个问题,ExprFunctor提供了一个VisitExpr函数,该函数将从给定的表达式路由到处理它的VisitExpr具体实例。尽管c++已经提供了动态调度,但ExprFunctor仍定义了自己的vtable来供VisitExpr使用。通过定义自己的vtable,我们可以更好地控制调度。例如,如果我们想定义一个PrintVisitor遍历器,每次访问之前都打印“Here”,可以覆盖VisitExpr:void PrintVisitor::VisitExpr(const Expr& expr) { std::cout << "Here" << std::endl; ExprFunctor::VisitExpr(expr); }ExprFunctor本身是一个非常通用的类,这就是为什么您会经常扩展ExprVisitorExprMutator的原因。这些类(ExprVisitor和ExprMutator)继承并扩展了类ExprFunctor`并提供了VisitExpr_默认实现,它为每种表达式类型捕获通用遍历模式。拥有这些默认实现意味着我们只需要为需要不同行为的表达式类型提供重写实现。在以下各节中,我们将分别描述每个子类。### 表达式访问器(Expression Visitors)

    ExprVisitor适用于不修改程序而是执行程序分析并收集信息的pass。通过此类, VisitExpr和其私有对象什么都不返回。此类提供的VisitExpr_实现只访问作为表达式的所有表达式字段。IfNode的默认实现如下所示。

    void ExprVisitor::VisitExpr(const IfNode* op) { this->VisitExpr(op->cond); this->VisitExpr(op->true_branch); this->VisitExpr(op->false_branch);}请注意,我们在这里调用VisitExpr而不是`VisitExpr,因此我们可以使用ExprFunctor中的vtable进行路由。现在,如果我们想编写一个类CallChecker来检查程序中是否有任何函数调用,则只需扩展ExprVisitor并定义以下VisitExpr`方法:void VisitExpr(const CallNode *n) final{ result = true;//此处没有调用它的args来进一步向上递归}`result是一个字段(成员变量)。在这种情况下,我们不需要在CallNode字段上进一步递归,因为result`已经是true,并且现在我们知道原始表达式包含一个调用。为了使此访问者可用,我们将提供以下公共方法:bool Check(const Expr &expr) final{ result = false; VisitExpr(expr); return result_;}这就是我们所需要的。定义一个公共接口在调用顶级递归之前执行一些 bookkeeping是非常普遍的。当然,我们可以通过创建一个独立的pass来创建CallChecker实例并对其进行Check调用,从而进一步包装API,但值得一提的是,我们花费很少的精力就实现了目标。(以下蓝色字体部分是熊选文添加的)
    ### 表达式增变器(Expression Mutators)

    ExprMutator用于以某种方式来转换程序的pass。有了这个类,VisitExpr和它的私有对象就会返回Expr。此类提供的VisitExpr_默认将访问作为表达式的表达式的所有字段,并将这些字段设置为访问它们的结果。TupleGetItemNode的默认实现如下所示。

    Expr ExprMutator::VisitExpr(const TupleGetItemNode g) { auto t = this->Mutate(g->tuple); if (g->tuple == t) { return GetRef(g); } else { return TupleGetItemNode::make(t, g->index); }} /! \brief Mutate is alias for VisitExpr \return expr. */ Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
    Expr VisitExpr(const Expr& expr) override;
    这里有一些注意事项。首先,Mutate是 VisitExpr在ExprMutator的的别名。其次,如果调用Mutate修改了该tuple字段,则仅返回一个新node。这种更新方法称为函数更新,这样做可以避免不必要的分配。ExprMutator有但ExprVisitor没有的一个功能是一个内置的用于缓存结果的`memo
    字段。std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;<br />ExprMutator有一个记忆器(memo_),这一点很有意义,因为我们知道我们在缓存哪些结果(例如Expr),但是ExprVisitor不返回任何东西。通常,在ExprVisitor的子类中,当我们要缓存结果时,我们需要自己定义缓存。现在,如果我们想写一个类IfCollapser,用它的true分支来替换每个if语句,我们可以为IfNode覆盖VisitExpr`:Expr ExprMutator::VisitExpr(const IfNode *op){ return this->Mutate(op->true_branch);}请注意,返回的表达式不一定是 IfNode,这没有问题,因为它的返回类型是Expr。现在,我们创建公共接口:Expr CollapseIfs(const Expr &expr) final{ return this->Mutate(expr);}使用这个增变器,我们不需要做任何 bookkeeping,但是我们仍然希望遵循这个约定:使用描述性方法作为接口。## 示例:常量折叠Constant Folding

    为了更好地理解编写pass的过程,我们将以 constant folding pass (位于src/relay/pass/fold_constant.cc中)为例,因为这是一种相对简单的pass,其中包含了上面两种遍历。

    constant folding pass在程序中评估只含常量值的表达式,然后将这些表达式替换为求值结果。此pass的目的是尽我们所能进行所有的计算。为了实现这一点,该pass利用了一个vistor(ConstantChecker)和一个mutator(ConstantFolder)。### ConstantChecker访问器(Visitor) 此访问器用于检查指定的表达式是否为常量。在Relay中,如果表达式是ConstantNode,或只包含常量字段的TupleNode,我们将该表达式定义为常量。我们使用一个memo_字段将nodes和它们是否常量一一映射,并缓存这些结果。std::unorderedmap memo;以下是ConstantChecker中的VisitExpr_定义。void VisitExpr(const ConstantNode *n) final{ memo[GetRef(n)] = true;//该Node是ConstantNode,该表达式就是常量}
    void VisitExpr(const TupleNode *n) final{ bool result = true; for (const auto &field : n->fields) { if (!Check(field)) { result = false;//只要有一个字段不是常量,该表达式不是常量 break; } } memo[GetRef(n)] = result;}用于协调这些定义的 bookkeeping是一个Check函数,该函数返回给定表达式是否被视为常量。bool Check(const Expr &expr){ const auto it = memo.find(expr); if (it != memo.end()) return it->second; VisitExpr(expr); return memo[expr];}我们不会为每个node修改memo;相反,我们仅在node可能为常量时修改memo。当memo不包含expr时的默认值为false(bool类型的默认值是false)。### ConstantFolder增变器(Mutator) 该mutator执行大量的constant folding pass,并在内部调用ConstantChecker。在Relay中,有三种涉及constant folding的node类型:LetNodeTupleItemGetNode,和CallNode。在下面的段落中,我们将分别说明在pass中它们各自的角色ConstantChecker。Expr VisitExpr(const LetNode *op) final{ Expr value = this->Mutate(op->value); if (value.as()) { memo[op->var] = value; return this->Mutate(op->body); } else { Var var = Downcast(this->Mutate(op->var)); Expr body = this->Mutate(op->body); if (var.sameas(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(var, value, body); } }}在LetNode情况下,我们首先尝试对表达式中绑定的值转换为ConstantNode。如果可以,那么我们将填充`memo并返回访问body的结果-本质上是将绑定值传播到body中的使用点。如果不能转换为ConstantNode,我们将模仿默认实现。Expr VisitExpr_(const TupleGetItemNode *op) final{ Expr res = ExprMutator::VisitExpr_(op); op = res.as<TupleGetItemNode>(); if (const auto *tuple = op->tuple.as<TupleNode>()) { return tuple->fields[op->index]; } else { return res; }}在TupleItemGetNode情况下,我们检查op->tuple是否为TupleNode。如果是,我们通过op->index用所指向的 tuple字段替换 tuple get。需要检查的原因是因为op->tuple可能被估计为一个 tuple,不管它本身是否为 tuple。Expr VisitExpr_(const CallNode *call) final{ static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful"); Expr res = ExprMutator::VisitExpr_(call); call = res.as<CallNode>(); // We don't constant fold function with zero arguments. // This is a heuristic that is useful. // For example it is harmful to fold ones(shape=(4, 5)). if (call->args.size() == 0) return res; const OpNode *op = call->op.as<OpNode>(); if (op == nullptr) return res; // skip stateful ops. if (op_stateful.get(GetRef<Op>(op), false)) return res; bool all_const_args = true; for (Expr arg : call->args) { if (!checker_.Check(arg)) { all_const_args = false; } } if (all_const_args) { return ConstEvaluate(res); } else { return res; }}在CallNode情况下,我们首先使用 ExprMutator的VisitExpr来访问该调用,该调用将const-folds该调用的所有字段。之所以使用ExprMutator::VisitExpr而不是VisitExpr,是因为我们要绕过vtable(以避免无限循环)并使用ExprMutator提供的默认实现。然后,仅当所有参数都是常量(使用ConstantChecker)时,我们才评估该调用。评估该调用会产生一个**value**,因此我们使用一个辅助方法ValueToExpr来允许我们将评估后的表达式放回AST中。现在,我们为constant folder构造了一个更方便的接口FoldConstant。FoldConstantConstantFolder类外部的一个独立函数,它接受一个表达式并在内部创建并使用一个ConstantFolder`实例(完整的定义可请参考 src/relay/pass/fold_constant.cc)。Expr FoldConstant(const Expr& expr) { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; Target target = Target::Create(“llvm”); // use a fresh build context // in case we are already in a build context. With fresh_build_ctx(BuildConfig::Create());
    return ConstantFolder(CreateInterpreter( Module(nullptr), ctx, target)).Mutate(expr);}###

    用pass管理器注册一个pass

    注意:有关此主题的更多详细信息,请参阅:ref-relay-pass-infra的文档。编写AST遍历器后,可以使用以下代码将该pass注册为TVM API端点:namespace transform {
    Pass FoldConstant(){ runtime::TypedPackedFunc pass_func = = { return Downcast(FoldConstant(f)); }; return CreateFunctionPass(pass_func, 2, “FoldConstant”, {});}
    } // namespace transform如果将上述代码产生的Pass对象提供给pass基础结构,它将确保AST遍历应用于给定Relay模块中的每个函数,这是我们希望constant folding pass的行为(它应该折叠所有常量)。函数CreateFunctionPass允许注册pass的优化级别(在本例中为2),可用于根据pass的通用函数、pass的名称以及pass的任何相关性将pass组合在一起。pass的依赖作为所有pass的列表给出,其结果对于运行当前pass是必需的。FoldConstant没有任何依赖项,但是许多relay pass确实依赖于它具有的类型信息,InferType是常见的依赖项;其它的pass依赖于该程序是否为A-normal正常形式(通过ToANormalForm pass)。请注意,PassContext对象包含pass用于错误报告和配置选项的信息。FoldConstant不需要此信息,但是其它pass可能引用其PassContext对象。现在可以通过pass基础结构调用该pass了,不过最好为该pass添加一个Python绑定,如以下代码片段所示:TVM_REGISTER_API(“relay._transform.FoldConstant”).set_body_typed(FoldConstant);一旦以上述方式定义了Pass对象,就可以使用pass基础结构的Sequential构造来调用它们,该构造将一系列pass依次应用于relay模块,从而获得转换后的模块。例如,以下代码将FoldConstantToANormalForm两个pass(一个接一个)应用于mod 中的每个函数,并获得一个新模块。seq = transform.Sequential([ relay.transform.FoldConstant(), relay.transform.ToANormalForm()])new_mod = seq(mod)注意:关于注册的更多详细信息可以参考TVM Runtime System,有关pass管理器接口的更多信息参考Relay Pass Infrastructure。relay的标准pass在include/tvm/relay/transform.h列出,并在src/relay/pass/中实现。