比较两个 Relay 表达式(expr)在结构(alpha 等效)或数据流上是否一致
analysis.alpha_equal
analysis.graph_equal
定义文件:python/tvm/relay/analysis.py
# file : python/tvm/relay/analysis.py
def alpha_equal(lhs, rhs):
    """Check whether two Relay expressions are structurally (alpha) equivalent.

    Parameters
    ----------
    lhs : relay expr
        First expression to compare.
    rhs : relay expr
        Second expression to compare.

    Returns
    -------
    bool
        The result of the native ``_make._alpha_equal`` check, coerced to bool.
    """
    verdict = _make._alpha_equal(lhs, rhs)
    return bool(verdict)
def assert_alpha_equal(lhs, rhs):
    """Assert that two Relay expressions are alpha equivalent.

    Thin wrapper over the native ``_make._assert_alpha_equal`` check; its
    return value is discarded and this function returns None.
    NOTE(review): presumably the native check raises on mismatch -- confirm.
    """
    check = _make._assert_alpha_equal
    check(lhs, rhs)
def graph_equal(lhs, rhs):
    """Check whether two Relay expressions are data-flow equivalent.

    Unlike ``alpha_equal``, the variables of ``lhs`` and ``rhs`` are not
    required to match; they are treated as sources and mapped to each other.

    Parameters
    ----------
    lhs : relay expr
        First expression to compare.
    rhs : relay expr
        Second expression to compare.

    Returns
    -------
    bool
        The result of the native ``_make._graph_equal`` check, coerced to bool.
    """
    verdict = _make._graph_equal(lhs, rhs)
    return bool(verdict)
def assert_graph_equal(lhs, rhs):
    """Assert that two Relay expressions are data-flow (graph) equivalent.

    Thin wrapper over the native ``_make._assert_graph_equal`` check; its
    return value is discarded and this function returns None.
    NOTE(review): presumably the native check raises on mismatch -- confirm.
    """
    check = _make._assert_graph_equal
    check(lhs, rhs)
比较2个tensor值是否一致
tvm.testing.assert_allclose
定义文件:python/tvm/testing.py
def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
    """Wrapper around ``np.testing.assert_allclose`` with sensible tolerances.

    Both ``rtol`` and ``atol`` default to 1e-7 (numpy's own ``atol`` default
    is 0), and ``verbose=True`` is always passed.

    The arguments ``actual`` and ``desired`` are not interchangeable: the
    check compares ``abs(actual - desired)`` against
    ``atol + rtol * abs(desired)``. Because ``desired`` may be close to zero,
    a non-zero ``atol`` is usually wanted.

    Raises
    ------
    AssertionError
        If the inputs are not close within the given tolerances.
    """
    tolerances = {"rtol": rtol, "atol": atol}
    np.testing.assert_allclose(actual, desired, verbose=True, **tolerances)
它等效于 np.testing.assert_allclose,区别在于将 `rtol` 和 `atol` 设置为更合理的默认值(均为 1e-7,而 numpy 中 `atol` 的默认值为 0),并固定传入 verbose=True。
numpy中还有其它的比较函数,如下:
# numpy.testing.assert_allclose(actual, desired, rtol=1e-07, atol=0, equal_nan=True, err_msg='', verbose=True)
# Raises an AssertionError if two objects are not equal up to desired tolerance.
# The test is equivalent to allclose(actual, desired, rtol, atol) (note that allclose has different default values). It compares the difference between actual and desired to atol + rtol * abs(desired).
x = [1e-5, 1e-3, 1e-1]
y = np.arccos(np.cos(x))
np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)
# numpy.testing.assert_array_almost_equal_nulp(x, y, nulp=1)
# Compare two arrays relatively to their spacing.
# This is a relatively robust method to compare two arrays whose amplitude is variable.
x = np.array([1., 1e-10, 1e-20])
eps = np.finfo(x.dtype).eps
np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)
# numpy.testing.assert_array_max_ulp(a, b, maxulp=1, dtype=None)
# Check that all items of arrays differ in at most N Units in the Last Place.
a = np.linspace(0., 1., 100)
res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))
# numpy.testing.assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True)[source]
# Raises an AssertionError if two objects are not equal up to desired precision.
tvm.testing.assert_allclose 的使用示例:
def test_any_concat():
    """Concatenate a (relay.Any(), 2) tensor with a (1, 2) tensor on axis 0.

    The result is checked against np.concatenate on both the "debug" and
    "vm" executors.
    """
    dyn = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
    fixed = relay.var('y', shape=(1, 2), dtype="float32")
    joined = relay.op.concatenate([dyn, fixed], axis=0)
    module = relay.module.Module()
    module["main"] = relay.Function([dyn, fixed], joined)

    # Concrete inputs: 3 rows for the Any-shaped var, 1 row for the fixed one.
    dyn_data = np.random.uniform(size=(3, 2)).astype('float32')
    fixed_data = np.random.uniform(size=(1, 2)).astype('float32')
    expected = np.concatenate([dyn_data, fixed_data], axis=0)

    for executor_kind in ("debug", "vm"):
        executor = relay.create_executor(
            executor_kind, mod=module, ctx=tvm.cpu(), target="llvm")
        output = executor.evaluate()(dyn_data, fixed_data)
        tvm.testing.assert_allclose(output.asnumpy(), expected)
import numpy as np
import tvm
from tvm import relay
from tvm.contrib.nvcc import have_fp16
def test_basic_build():
    """Build relu(dense(a, b)) + c for LLVM with "b" and "c" bound as params.

    The compiled module is run through the graph runtime and compared
    against a NumPy reference.
    """
    tgt = "llvm"
    ctx = tvm.cpu()

    # Relay function: relu(dense(a, b)) + c
    a = relay.var("a", dtype="float32", shape=(16, 8))
    b = relay.var("b", dtype="float32", shape=(8, 8))
    c = relay.var("c", dtype="float32", shape=(16, 8))
    func = relay.Function([a, b, c], relay.nn.relu(relay.nn.dense(a, b)) + c)

    def _uniform(shape):
        # Random float32 NDArray in [-1, 1) placed on ctx.
        return tvm.nd.array(np.random.uniform(-1, 1, shape).astype("float32"), ctx=ctx)

    # Draw order A, B, C matters for reproducibility with a seeded RNG.
    A = _uniform((16, 8))
    B = _uniform((8, 8))
    C = _uniform((16, 8))

    # "b" and "c" are supplied as params at build time; only "a" is set at run time.
    bound = {"b": B, "c": C}
    targets = {tvm.expr.IntImm("int32", ctx.device_type): tgt}
    graph_json, lib, built_params = relay.build(
        relay.Module.from_expr(func), targets, tgt, params=bound)

    runtime = tvm.contrib.graph_runtime.create(graph_json, lib, ctx)
    runtime.set_input("a", A)
    runtime.load_params(relay.save_param_dict(built_params))
    runtime.run()
    actual = runtime.get_output(0).asnumpy()

    # The reference multiplies by B transposed, matching how the test checks dense.
    a_np, b_np, c_np = A.asnumpy(), B.asnumpy(), C.asnumpy()
    expected = np.maximum(np.dot(a_np, b_np.T), 0) + c_np
    np.testing.assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)