[翻译]编写自定义pass— tvm 0.7.dev0文档


# 编写自定义pass 作者:Jian WengTVM是一个框架,它抽象了机器学习加速器的异质性。有时,用户可能需要自定义一些分析和IR转换,以使TVM适应自己的专用硬件。本教程可帮助用户在TVM中编写自定义pass。## 前提条件 在阅读本教程之前,我们假设读者已经熟悉以下主题:

  • 在TVM中编写算法并进行调度。否则,请参见示例教程,例如如何在CPU上优化GEMM
  • HalideIR的基本结构。否则,请参阅HalideIR/src/ir/IR.h以了解定义了IR节点的哪些属性。
  • 访问者设计模式。否则,请检查Python AST模块以了解如何实现AST访问者。
  • HalideIR / Schedule如何降级为LoweredFunc类或LLVM模块。否则,请看python/tvm/build_module.py以获得一些基础知识。 fromfutureimportabsolute_import,print_functionimporttvmimportnumpyasnp我们首先编写一个非常简单的向量加法,并使用默认schedule来构建它。然后,我们使用自定义的降级pass来直接操纵IR,而不是使用schedule原语。n=tvm.const(128,”int32”)a=tvm.placeholder((n,),name=”a”)b=tvm.placeholder((n,),name=”b”)c=tvm.compute((n,),lambdai:a[i]+b[i],name=’c’)
    sch=tvm.create_schedule(c.op)ir=tvm.lower(sch,[a,b,c],simple_mode=True)print(ir)输出:produce c {for (i, 0, 128) {c[i] = (a[i] + b[i])}}## 写pass 本质上,“IR转换pass” 是将语句映射到新语句的函数。因此,我们定义此向量化函数并逐步实现它。TVM已经为用户提供了两个类来分析和转换IR。### IR访问器 我们可以用tvm.ir_pass.PostOrderVisit(stmt, func)来从Halide IR收集信息。func是一个函数回调。在退出当前IR node之前调用此函数。然后,我们想办法存储IR访问的结果,因为func的返回值将被忽略。注意您必须使用一些数组来存储IR访问的结果。即使该值仅仅是一个变量。这主要是由于Python-C运行时中的限制。每次递归都会刷新变量值,但会保留数组值。loops=[]deffind_width8(op):””” Find all the ‘For’ nodes whose extent can be divided by 8. “””ifisinstance(op,tvm.stmt.For):ifisinstance(op.extent,tvm.expr.IntImm):ifop.extent.value%8==0:loops.append(op)### IR转换 转换接口与访问者接口略有不同。访问者中只有一个后序回调,但是转换访问者既支持前序回调也支持后序回调。如果要保留原始IR节点,只需返回None。如果要将当前节点更改为某个节点,请使用TVM IR maker接口进行构建并返回该值。注意如果调用了前序函数并返回了非None值,则将跳过后序函数。defvectorize8(op):””” Split can vectorize the loops found in find_width8. “””ifopinloops:extent=op.extent.valuename=op.loop_var.namelo,li=tvm.var(name+’.outer’),tvm.var(name+’.inner’)body=tvm.ir_pass.Substitute(op.body,{op.loop_var:lo8+li})body=tvm.make.For(li,0,8,tvm.stmt.For.Vectorized,0,body)body=tvm.make.For(lo,0,extent//8,tvm.stmt.For.Serial,0,body)returnbodyreturnNone
    defvectorize(stmt):globalloops
    tvm.ir_pass.PostOrderVisit(stmt,find_width8)
    ifnotloops:returnstmt
    # The last list arugment indicates what kinds of nodes will be transformed.# Thus, in this case only For nodes will call vectorize8stmt=tvm.ir_pass.IRTransform(stmt,None,vectorize8,[‘For’])
    returnstmt## pass降级 到目前为止,我们已经编写完成IR转换pass。接下来,我们需要将此pass 关联到TVM的下层pass上。首先,可以直接调用此函数作为健全性检查。print(vectorize(ir))输出:produce c {for (i.outer, 0, 16) {vectorized (i.inner, 0, 8) {c[((i.outer
    8) + i.inner)] = (a[((i.outer8) + i.inner)] + b[((i.outer8) + i.inner)])}}}在TVM中,有一个名为BuildConfig的属性。您可以使用此属性来自定义您自己的pass降级选项。在这种情况下,我们通过将元组列表作为参数提供给add_lower_pass, 将上面所写的pass注入TVM标准降级pass中。“元组”表示降级的不同阶段。在TVM中,降级分为四个阶段,每个阶段完成后将调用用户自定义的阶段。注意以下是每个阶段完成的基本转换:
  • 阶段0生成原始IR和循环级别。
  • 阶段1将数组存储扁平化。
  • 阶段2转换循环,例如展开,向量化和线程绑定。
  • 阶段3进行一些清理工作。 因此,放置此转换pass的好地方就在阶段1之后。withtvm.build_config(add_lower_pass=[(1,vectorize)])ascfg:print(tvm.lower(sch,[a,b,c],simple_mode=True))输出:
    produce c {for (i.outer, 0, 16) {c[ramp((i.outer8), 1, 8)] = (a[ramp((i.outer8), 1, 8)] + b[ramp((i.outer*8), 1, 8)])}}## 快速查看 本教程提供了编写自定义IR转换pass的快速视图:tvm.ir_pass.PostOrderVisit用于收集每个IR node上的信息。 tvm.ir_pass.IRTransform用于转换IR node。封装以上两者以编写IR转换函数,使用tvm.build_config将此函数用于TVM下层的pass。