FoldConstant-常量折叠

常量折叠涉及在程序中评估仅包含常量值的表达式,然后将这些表达式替换为求值结果。
例如函数FoldConstantOpt的就利用了该pass:

  1. static inline Expr FoldConstantOpt(const Expr& expr) {
  2. auto mod = ModuleNode::FromExpr(expr);
  3. mod = transform::FoldConstant()(mod);
  4. auto entry_func = mod->Lookup("main");
  5. return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
  6. }

PartialEval-部分评估

在编译时执行已知的计算。
部分评估器尝试在编译时进行计算,因此它可以生成减少工作量的代码。
此外,它可能会为进一步的优化提供更多的机会,因为代码的高级结构部分 (闭包、引用、控制流) 可能会被部分评估掉,随后的优化 (例如,内核融合) 可以在删除结构代码时对其进行推理。在极端情况下,部分评估甚至可以将整个程序变成没有控制流的纯一阶计算。在这种情况下,我们可以将整个计算编译到SIMD指令/GPU/FPGA上,并获得巨大的加速。
它通过对标准relay解释器进行以下修改来工作:
0: 一些值变成部分静态值。
由于我们不能在编译时知道每个术语的值,术语可能会被部分评估为 “未知值”。因此,每个部分静态值都是可能不存在的静态片段 (部分静态),以及在语义上等同于原始术语的动态片段,因此,未知部分将在运行时使用动态片段进行计算。
1: 解释器保存一个LetList,它保留生成代码的正常形式。
更具体地说,我们要求所有动态都是一个原子。这避免了代码重复 (这是低效和不正确的),因为原子具有常量大小,并允许我们不处理捕获避免替换 (因为原子没有绑定)。
2: 将部分静态值的引用映射具体化,如下所述。
引用仅具有唯一标识符,而不是具有可变字段的引用。将会有一个可变的id映射到部分静态值,称为存储。
这允许我们回滚存储:
当路径可能执行也可能不执行时 (如在条件中),我们复制存储,并与复制一起递归,并在调用返回时恢复原始状态,以便不保留计算效果。
我们在if else、模式匹配和函数中执行此操作,因为当我们看到一个函数时,我们会将所有参数部分评估为动态的,为该函数生成有效的动态。
3: 生成的代码重用绑定 (尽管它们没有阴影),所以我们必须对它们进行重复数据删除。
4: 在生成的代码中,当它调用TypeSubst时,多个VarNode可能具有相同的Id。
虽然允许,但大多数pass对Var使用NodeHash,并且对于相同的Id具有多个VarNode会破坏它们。因此,我们现在将它们重新映射到单个Id。
此外,它还会生成大量死代码,因此在部分评估后通过死代码消除器提供它是一个好主意。

部分评估器做了几个假设,因此还有改进的空间:
0: 每次发生未知影响时,我们都会清理整个存储。
它太保守了: 如果创建了本地引用 (并且没有在外部传递),则未知的全局函数调用/全局引用写入无法对其进行修改。我们可以将PE与转义分析/别名分析配对。
1: 我们假设所有未知代码都有效。进行效果分析可以使商店更加精确。
2: 进行模式匹配时,即使在动态情况下,我们也可以简化匹配。
现在它是全有还是全无: 要么是完全匹配,要么是原始的动态代码。相反,我们可以得到一个匹配树,将其与数据配对,并将其评估为正常形式。然后,我们可以重现结果。
3: 每次调用函数时,其代码将被扩展并进行部分评估。
我们可以进行绑定时间分析以缓存结果并避免重新部分评估。
然而,这些假设不会影响算法的正确性。

DeadCodeElimination

移除不影响计算结果的那些代码,例如没有被引用的节点。

InferType-类型推断

类型推断和检查,做了一些转换,生成最高效的代码。
该Pass是relay IR中最重要的pass之一。
例如,函数ConstEvaluateOpt就用到上述3个pass

  1. static inline Expr ConstEvaluateOpt(Expr expr) {
  2. auto mod = ModuleNode::FromExpr(expr);
  3. Array<transform::Pass> passes = {transform::PartialEval(),
  4. transform::DeadCodeElimination(true),
  5. transform::InferType()};
  6. auto seq_pass = transform::Sequential(passes);
  7. mod = seq_pass(mod);
  8. auto entry_func = mod->Lookup("main");
  9. return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
  10. }