1 引言
在tvm中,自定义算子时,算子的参数,除了输入之外,可能还包括其他的属性,这时就需要设置属性的类型。
tvm已经定义了很多常用的属性。
c++端参考include/tvm/relay/attrs/目录下的文件。
python端参考:python/tvm/relay/op/op_attrs.py
例如,case算子的参数,除了输入的tensor之外,还需要指定转换的数据类型(输出的tensor的数据类型),所以,需要一个转换类型属性(CastAttrs),该属性类至少包含一个数据类型参数。
CastAttrs属性定义如下:
c++端:include/tvm/relay/attrs/transform.h
/*! \brief data type cast */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
DataType dtype;
TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
TVM_ATTR_FIELD(dtype)
.describe("Target data type");
}
}; // struct CastAttrs.
python端:python/tvm/relay/op/op_attrs.py
@register_relay_attr_node
class CastAttrs(Attrs):
"""Attributes for transform.cast"""
然后,算子cast的定义就设置了该类型,
TVM_REGISTER_NODE_TYPE(CastAttrs);
RELAY_REGISTER_OP("cast")
.describe(R"code(Cast the data into a new data type.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<CastAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Cast", CastRel)
.set_attr<FTVMCompute>("FTVMCompute", CastCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
如果已定义的属性类型中,没有符合要求的,就需要自定义一个属性类型。一般的,自定义属性类如下:
2 自定义属性
自定义一个属性,以”SaturateCastAttrs”为例。
2.1 c++端定义
一般的,属性类都是在include/tvm/relay/attrs/transform.h在中定义。
定义一个结构体,该结构体的名称以Attrs结尾,例如”SaturateCastAttrs”,该结构体含有三个属性:
DataType dtype;//Target data type
double a_min; //The minimum clip value.
double a_max; //The maximum clip value.
结构体如下:
/*! \brief data type cast */
struct SaturateCastAttrs : public tvm::AttrsNode<SaturateCastAttrs> {
DataType dtype;//Target data type
double a_min; //The minimum clip value.
double a_max; //The maximum clip value.
TVM_DECLARE_ATTRS(SaturateCastAttrs, "relay.attrs.SaturateCastAttrs") {
TVM_ATTR_FIELD(dtype).describe("Target data type");
TVM_ATTR_FIELD(a_min).describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max).describe("The maximum clip value.");
}
}; // struct SaturateCastAttrs.
2.2 在c++端使用
(1)调用宏TVM_REGISTER_NODE_TYPE来注册属性SaturateCastAttrs
TVM_REGISTER_NODE_TYPE``(SaturateCastAttrs);
(2)设置算子的属性
RELAY_REGISTER_OP("vacc_saturate_cast")
...
.set_attrs_type<SaturateCastAttrs>()
...
2.3 python端注册(非必须)
如果要在python端直接使用SaturateCastAttrs属性,那就需要在python/tvm/relay/op/op_attrs.py中注册。
注册如下:
@register_relay_attr_node
class SaturateCastRel(Attrs):
"""Attributes for op `vacc_saturate_cast`"""