MCP · copy · Index your code
hub / github.com/apache/tvm / test_trace_expr_assign

Function test_trace_expr_assign

tests/python/runtime/test_runtime_trace.py:34–60  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

32
33
def test_trace_expr_assign():
    """Check that ``tvm.tirx.trace`` used as an element-wise copy keeps values intact.

    A no-op packed function is registered as the trace callback. Two chained
    ``te.compute`` stages (X -> Y -> Z) each wrap their element in
    ``tvm.tirx.trace``. After the compiled LLVM function runs, X, Y and Z must
    all hold the ones that were written into X, for every dtype tried.
    """
    callback = "tvm.tirx.trace_callback2"

    @tvm.register_global_func(callback)
    def _ignore_trace(value):
        # Only needs to be callable; the assertions below never inspect
        # what (if anything) this callback returns.
        return

    def run_for_dtype(dtype):
        n = 4
        shape = (n, n, n)

        x = te.placeholder(shape, name="X", dtype=dtype)
        # Both stages iterate over x.shape; i, j, k are kept as the lambda
        # argument names because te.compute derives iteration-variable names
        # from them.
        y = te.compute(x.shape, lambda i, j, k: tvm.tirx.trace([x[i][j][k]], callback))
        z = te.compute(x.shape, lambda i, j, k: tvm.tirx.trace([y[i][j][k]], callback))
        func = tvm.compile(te.create_prim_func([x, y, z]), "llvm")

        # X starts as ones; Y and Z start as zeros so that a stage which
        # failed to write would be caught by the checks below.
        buffers = [
            tvm.runtime.tensor(np.ones(shape, dtype=x.dtype)),
            tvm.runtime.tensor(np.zeros(shape, dtype=y.dtype)),
            tvm.runtime.tensor(np.zeros(shape, dtype=z.dtype)),
        ]
        func(*buffers)

        expected = np.ones(shape)
        for buf in buffers:
            assert np.array_equal(buf.numpy(), expected)

    for dtype in ("float64", "float32", "int64", "int32"):
        run_for_dtype(dtype)
61
62
63def test_trace_expr_sum_generated():

Callers

nothing calls this directly

Calls 1

check_assign (Function) · 0.85

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…