1 引言
在tvm中,自定义算子时,算子的参数,除了输入之外,可能还包括其他的属性,这时就需要设置属性的类型。
tvm已经定义了很多常用的属性。
c++端参考include/tvm/relay/attrs/目录下的文件。
python端参考:python/tvm/relay/op/op_attrs.py
例如,case算子的参数,除了输入的tensor之外,还需要指定转换的数据类型(输出的tensor的数据类型),所以,需要一个转换类型属性(CastAttrs),该属性类至少包含一个数据类型参数。
CastAttrs属性定义如下:
c++端:include/tvm/relay/attrs/transform.h
/*! \brief data type cast */struct CastAttrs : public tvm::AttrsNode<CastAttrs> {DataType dtype;TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {TVM_ATTR_FIELD(dtype).describe("Target data type");}}; // struct CastAttrs.
python端:python/tvm/relay/op/op_attrs.py
@register_relay_attr_nodeclass CastAttrs(Attrs):"""Attributes for transform.cast"""
然后,算子cast的定义就设置了该类型,
TVM_REGISTER_NODE_TYPE(CastAttrs);RELAY_REGISTER_OP("cast").describe(R"code(Cast the data into a new data type.)code" TVM_ADD_FILELINE).set_num_inputs(1).set_attrs_type<CastAttrs>().add_argument("data", "Tensor", "The input tensor.").set_support_level(3).add_type_rel("Cast", CastRel).set_attr<FTVMCompute>("FTVMCompute", CastCompute).set_attr<TOpPattern>("TOpPattern", kElemWise).set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
如果已定义的属性类型中,没有符合要求的,就需要自定义一个属性类型。一般的,自定义属性类如下:
2 自定义属性
自定义一个属性,以”SaturateCastAttrs”为例。
2.1 c++端定义
一般的,属性类都是在include/tvm/relay/attrs/transform.h在中定义。
定义一个结构体,该结构体的名称以Attrs结尾,例如”SaturateCastAttrs”,该结构体含有三个属性:
DataType dtype;//Target data type
double a_min; //The minimum clip value.
double a_max; //The maximum clip value.
结构体如下:
/*! \brief data type cast */struct SaturateCastAttrs : public tvm::AttrsNode<SaturateCastAttrs> {DataType dtype;//Target data typedouble a_min; //The minimum clip value.double a_max; //The maximum clip value.TVM_DECLARE_ATTRS(SaturateCastAttrs, "relay.attrs.SaturateCastAttrs") {TVM_ATTR_FIELD(dtype).describe("Target data type");TVM_ATTR_FIELD(a_min).describe("The minimum clip value.");TVM_ATTR_FIELD(a_max).describe("The maximum clip value.");}}; // struct SaturateCastAttrs.
2.2 在c++端使用
(1)调用宏TVM_REGISTER_NODE_TYPE来注册属性SaturateCastAttrs
TVM_REGISTER_NODE_TYPE``(SaturateCastAttrs);
(2)设置算子的属性
RELAY_REGISTER_OP("vacc_saturate_cast")....set_attrs_type<SaturateCastAttrs>()...
2.3 python端注册(非必须)
如果要在python端直接使用SaturateCastAttrs属性,那就需要在python/tvm/relay/op/op_attrs.py中注册。
注册如下:
@register_relay_attr_nodeclass SaturateCastRel(Attrs):"""Attributes for op `vacc_saturate_cast`"""
