MCP · copy · Index your code
hub / github.com/apache/tvm / test_trace_expr_sum_generated

Function test_trace_expr_sum_generated

tests/python/runtime/test_runtime_trace.py:63–87  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

61
62
def test_trace_expr_sum_generated():
    """Check that traced operands in a compute expression still sum correctly.

    Each element of ``c`` is built as the sum of two ``tvm.tirx.trace`` calls,
    one over ``a[i][j][k]`` and one over ``b[i][j][k]``. Both are routed to the
    ``tvm.tirx.trace_callback3`` callback registered below, whose body does
    nothing. For every dtype in the loop, the test compiles the kernel, runs it
    on all-ones inputs, and asserts that the output equals ``a + b``.
    """

    @tvm.register_global_func("tvm.tirx.trace_callback3")
    def trace_buffer(x):
        # No-op callback: the test only inspects the kernel's output tensor.
        return

    def check_expr_sum(dtype):
        n = 4
        a = te.placeholder((n, n, n), name="a", dtype=dtype)
        b = te.placeholder((n, n, n), name="b", dtype=dtype)
        c = te.compute(
            a.shape,
            lambda i, j, k: (
                tvm.tirx.trace([a[i][j][k]], "tvm.tirx.trace_callback3")
                + tvm.tirx.trace([b[i][j][k]], "tvm.tirx.trace_callback3")
            ),
        )
        f = tvm.compile(te.create_prim_func([a, b, c]))
        # np.ones already returns a fresh array, so no extra np.array() copy.
        xnd = tvm.runtime.tensor(np.ones((n, n, n), dtype=a.dtype))
        ynd = tvm.runtime.tensor(np.ones((n, n, n), dtype=b.dtype))
        znd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=c.dtype))
        f(xnd, ynd, znd)
        # err_msg names the dtype so a failure inside the loop below is
        # attributable; assert_array_equal also reports mismatching elements.
        np.testing.assert_array_equal(
            znd.numpy(), xnd.numpy() + ynd.numpy(), err_msg=f"dtype={dtype}"
        )

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)
88
89
90def test_trace_expr_sum_args():

Callers

nothing calls this directly

Calls 1

check_expr_sum · Function · 0.85

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…