MCPcopy Index your code
hub / github.com/apache/tvm / test_trace_expr_sum_args

Function test_trace_expr_sum_args

tests/python/runtime/test_runtime_trace.py:90–123  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

88
89
def test_trace_expr_sum_args():
    """Check that ``tvm.tirx.trace`` calls can be summed inside a compute body.

    Every operand of the sum is wrapped in a trace intrinsic that forwards
    ``(i, j, k, value)`` to the ``tvm.tirx.trace_silent`` callback. The
    compiled kernel is then run on all-ones inputs for several dtypes, and its
    output must equal the plain element-wise sum of the four inputs.
    """

    @tvm.register_global_func("tvm.tirx.trace_silent")
    def silent(*args):
        # Deliberately a no-op: the test only checks the computed result,
        # not what is reported to the trace callback.
        return

    def check_expr_sum(dtype):
        n = 4
        shape = (n, n, n)
        a = te.placeholder(shape, name="a", dtype=dtype)
        b = te.placeholder(shape, name="b", dtype=dtype)
        e = te.placeholder(shape, name="e", dtype=dtype)
        d = te.placeholder(shape, name="d", dtype=dtype)

        # The final assertion expects c == a + b + d + e. That only holds if
        # each trace(...) term evaluates to its last argument (the traced
        # element), so this also exercises trace's value pass-through.
        c = te.compute(
            a.shape,
            lambda i, j, k: (
                tvm.tirx.trace([i, j, k, a[i][j][k]], "tvm.tirx.trace_silent")
                + tvm.tirx.trace([i, j, k, b[i][j][k]], "tvm.tirx.trace_silent")
                + tvm.tirx.trace([i, j, k, d[i][j][k]], "tvm.tirx.trace_silent")
                + tvm.tirx.trace([i, j, k, e[i][j][k]], "tvm.tirx.trace_silent")
            ),
        )
        f = tvm.compile(te.create_prim_func([a, b, d, e, c]))

        # np.ones already returns a fresh array; wrapping it in np.array only
        # made a redundant second copy.
        a_nd = tvm.runtime.tensor(np.ones(shape, dtype=a.dtype))
        b_nd = tvm.runtime.tensor(np.ones(shape, dtype=b.dtype))
        d_nd = tvm.runtime.tensor(np.ones(shape, dtype=d.dtype))
        e_nd = tvm.runtime.tensor(np.ones(shape, dtype=e.dtype))
        c_nd = tvm.runtime.tensor(np.zeros(shape, dtype=c.dtype))

        f(a_nd, b_nd, d_nd, e_nd, c_nd)
        assert np.array_equal(
            c_nd.numpy(), a_nd.numpy() + b_nd.numpy() + d_nd.numpy() + e_nd.numpy()
        )

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)
124
125
126def test_trace_expr_sum_custom():

Callers

nothing calls this directly

Calls 1

check_expr_sum · Function · 0.85

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…