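expr.const is the function that creates a Constant object from a given value.
Note that value may be a bool, int, float, numpy.ndarray or tvm.nd.NDArray (the isinstance check in the code also accepts a Python list).
Excerpt from python/tvm/relay/expr.py (the module-level imports behind _np, _nd and _base are omitted):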
def const(value, dtype=None):
    """Create a constant value.

    Parameters
    ----------
    value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
        The constant value.

    dtype: str, optional
        The data type of the value.

    Note
    ----
    When dtype is None, we use the following rule:

    - int maps to "int32"
    - float maps to "float32"
    - bool maps to "bool"
    - other using the same default rule as numpy.
    """
    if isinstance(value, (_base.numeric_types, (bool, list))):
        value = _np.array(value, dtype=dtype)
        if not dtype:
            # when dtype is None: int maps to "int32", float maps to "float32"
            map_dtype = {
                _np.dtype('int64'): _np.int32,
                _np.dtype('float64'): _np.float32
            }.get(value.dtype, None)
            if map_dtype:
                value = value.astype(map_dtype)
    if isinstance(value, (_np.ndarray, _np.generic)):
        value = _nd.array(value)
    if not isinstance(value, _nd.NDArray):
        raise ValueError("value has to be scalar or NDArray")
    return Constant(value)
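To check the default dtype rule without building TVM, the numpy-side steps of the function can be mirrored in a few lines. This is a minimal sketch, not TVM code: the helper name mirror_const_conversion is hypothetical, the tuple (int, float, bool, list) is an assumed stand-in for TVM's _base.numeric_types plus bool/list, and tvm.nd.NDArray inputs are not modeled. It returns the numpy array that const would hand to _nd.array.

```python
import numpy as np

# Assumed stand-in for TVM's (_base.numeric_types, (bool, list)) check;
# the real numeric_types tuple may also contain some numpy scalar types.
_SCALAR_OR_LIST = (int, float, bool, list)


def mirror_const_conversion(value, dtype=None):
    """Hypothetical helper (not part of TVM): reproduce the numpy-side
    steps of relay.const and return the array it would pass to _nd.array."""
    if isinstance(value, _SCALAR_OR_LIST):
        value = np.array(value, dtype=dtype)
        if not dtype:
            # No dtype given: narrow numpy's 64-bit defaults to 32 bits.
            map_dtype = {
                np.dtype("int64"): np.int32,
                np.dtype("float64"): np.float32,
            }.get(value.dtype, None)
            if map_dtype:
                value = value.astype(map_dtype)
    if isinstance(value, (np.ndarray, np.generic)):
        # numpy inputs reach this point untouched, whatever dtype was passed
        return np.asarray(value)
    # tvm.nd.NDArray inputs are not modeled in this sketch
    raise ValueError("value has to be scalar or NDArray")


for v in (1, 1.0, True, [1, 2]):
    # prints: 1 -> int32, 1.0 -> float32, True -> bool, [1, 2] -> int32
    print(repr(v), "->", mirror_const_conversion(v).dtype)
```

The lookup table only fires when no dtype is given and the value came from a Python scalar or list, which is why an int64 numpy array passed in directly keeps its dtype.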
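This function has a bug: when value is a _np.ndarray or _np.generic, the dtype argument is silently ignored, because dtype is only used inside the first branch, which handles Python scalars and lists (a tvm.nd.NDArray input likewise ignores dtype). To get a specific dtype for a numpy input, cast it first with value.astype(dtype) and then pass the result to const.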
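One way to apply that workaround without editing TVM is to cast numpy inputs before they reach const. The sketch below uses a hypothetical helper named cast_for_const (not a TVM API); it only touches np.ndarray and np.generic values and leaves Python scalars and lists alone, since const already honours dtype for those.

```python
import numpy as np


def cast_for_const(value, dtype=None):
    """Hypothetical helper (not part of TVM): apply dtype to numpy inputs,
    which relay.const would otherwise pass through unchanged."""
    if dtype is not None and isinstance(value, (np.ndarray, np.generic)):
        # astype returns a copy with the requested element type
        value = value.astype(dtype)
    return value


arr = np.arange(3)  # numpy's default integer type
print(cast_for_const(arr, "float16").dtype)  # float16
```

With TVM installed, the call would then be relay.const(cast_for_const(arr, "float16")), so the Constant carries float16 data even though const itself ignores dtype for numpy inputs.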