1 引言

在tvm中,自定义算子时,算子的参数,除了输入之外,可能还包括其他的属性,这时就需要设置属性的类型。

tvm已经定义了很多常用的属性。
c++端参考include/tvm/relay/attrs/目录下的文件。
python端参考:python/tvm/relay/op/op_attrs.py

例如,case算子的参数,除了输入的tensor之外,还需要指定转换的数据类型(输出的tensor的数据类型),所以,需要一个转换类型属性(CastAttrs),该属性类至少包含一个数据类型参数。

CastAttrs属性定义如下:
c++端:include/tvm/relay/attrs/transform.h

  1. /*! \brief data type cast */
  2. struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
  3. DataType dtype;
  4. TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
  5. TVM_ATTR_FIELD(dtype)
  6. .describe("Target data type");
  7. }
  8. }; // struct CastAttrs.

python端:python/tvm/relay/op/op_attrs.py

  1. @register_relay_attr_node
  2. class CastAttrs(Attrs):
  3. """Attributes for transform.cast"""

然后,算子cast的定义就设置了该类型,

  1. TVM_REGISTER_NODE_TYPE(CastAttrs);
  2. RELAY_REGISTER_OP("cast")
  3. .describe(R"code(Cast the data into a new data type.
  4. )code" TVM_ADD_FILELINE)
  5. .set_num_inputs(1)
  6. .set_attrs_type<CastAttrs>()
  7. .add_argument("data", "Tensor", "The input tensor.")
  8. .set_support_level(3)
  9. .add_type_rel("Cast", CastRel)
  10. .set_attr<FTVMCompute>("FTVMCompute", CastCompute)
  11. .set_attr<TOpPattern>("TOpPattern", kElemWise)
  12. .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

如果已定义的属性类型中,没有符合要求的,就需要自定义一个属性类型。一般的,自定义属性类如下:

2 自定义属性

自定义一个属性,以”SaturateCastAttrs”为例。

2.1 c++端定义

一般的,属性类都是在include/tvm/relay/attrs/transform.h在中定义。
定义一个结构体,该结构体的名称以Attrs结尾,例如”SaturateCastAttrs”,该结构体含有三个属性:
DataType dtype;//Target data type
double a_min; //The minimum clip value.
double a_max; //The maximum clip value.
结构体如下:

  1. /*! \brief data type cast */
  2. struct SaturateCastAttrs : public tvm::AttrsNode<SaturateCastAttrs> {
  3. DataType dtype;//Target data type
  4. double a_min; //The minimum clip value.
  5. double a_max; //The maximum clip value.
  6. TVM_DECLARE_ATTRS(SaturateCastAttrs, "relay.attrs.SaturateCastAttrs") {
  7. TVM_ATTR_FIELD(dtype).describe("Target data type");
  8. TVM_ATTR_FIELD(a_min).describe("The minimum clip value.");
  9. TVM_ATTR_FIELD(a_max).describe("The maximum clip value.");
  10. }
  11. }; // struct SaturateCastAttrs.

2.2 在c++端使用

(1)调用宏TVM_REGISTER_NODE_TYPE来注册属性SaturateCastAttrs

TVM_REGISTER_NODE_TYPE``(SaturateCastAttrs);

(2)设置算子的属性

  1. RELAY_REGISTER_OP("vacc_saturate_cast")
  2. ...
  3. .set_attrs_type<SaturateCastAttrs>()
  4. ...

2.3 python端注册(非必须)

如果要在python端直接使用SaturateCastAttrs属性,那就需要在python/tvm/relay/op/op_attrs.py中注册。
注册如下:

  1. @register_relay_attr_node
  2. class SaturateCastRel(Attrs):
  3. """Attributes for op `vacc_saturate_cast`"""