Repository: github.com/apache/tvm

Function check_expr_sum(dtype)

tests/python/runtime/test_runtime_trace.py, lines 68–84


```python
import numpy as np
import tvm
from tvm import te


# Assumes the enclosing test function has already registered the
# "tvm.tirx.trace_callback3" callback before this helper is called.
def check_expr_sum(dtype):
    n = 4
    a = te.placeholder((n, n, n), name="a", dtype=dtype)
    b = te.placeholder((n, n, n), name="b", dtype=dtype)
    # Every element of a and b is routed through the trace callback before the add.
    c = te.compute(
        a.shape,
        lambda i, j, k: (
            tvm.tirx.trace([a[i][j][k]], "tvm.tirx.trace_callback3")
            + tvm.tirx.trace([b[i][j][k]], "tvm.tirx.trace_callback3")
        ),
    )
    f = tvm.compile(te.create_prim_func([a, b, c]))
    xnd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=a.dtype)))
    ynd = tvm.runtime.tensor(np.array(np.ones((n, n, n), dtype=b.dtype)))
    znd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=c.dtype))
    f(xnd, ynd, znd)
    # Tracing must not change the computed values: c == a + b.
    assert np.array_equal(znd.numpy(), xnd.numpy() + ynd.numpy())


# In the source file, this loop runs inside the enclosing test function.
for t in ["float64", "float32", "int64", "int32"]:
    check_expr_sum(t)
```
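To make the behavior under test concrete without a TVM build, here is a pure-Python sketch of the same check. The registry, `register`, and `trace` below are hypothetical stand-ins, not TVM APIs. `trace` is a simplified model that invokes the named callback on the traced values and then passes the last argument through unchanged, which is what the test's `np.array_equal(..., xnd.numpy() + ynd.numpy())` assertion expects.

```python
import numpy as np

# Hypothetical stand-in for a global-function registry (not TVM's API).
_REGISTRY = {}


def register(name):
    """Register fn under `name` in the hypothetical registry."""
    def wrap(fn):
        _REGISTRY[name] = fn
        return fn
    return wrap


def trace(args, action):
    """Simplified model of a trace call: run the named callback on the
    traced values, then return the last argument unchanged."""
    _REGISTRY[action](*args)
    return args[-1]


calls = []


@register("trace_callback3")
def trace_buffer(x):
    calls.append(x)  # record each traced element; nothing is returned


def check_expr_sum(dtype, n=4):
    a = np.ones((n, n, n), dtype=dtype)
    b = np.ones((n, n, n), dtype=dtype)
    c = np.zeros((n, n, n), dtype=dtype)
    # Elementwise loop playing the role of te.compute's lambda.
    for i in range(n):
        for j in range(n):
            for k in range(n):
                c[i, j, k] = (trace([a[i, j, k]], "trace_callback3")
                              + trace([b[i, j, k]], "trace_callback3"))
    # Tracing is observational only, so the result is still a + b.
    assert np.array_equal(c, a + b)
    return c


for t in ["float64", "float32", "int64", "int32"]:
    check_expr_sum(t)
```

Each call to `check_expr_sum` invokes the callback twice per element, so with n = 4 a single run records 128 traced values.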

Callers 2

test_trace_expr_sum_args · Function · 0.85

Calls 7

placeholder · Method · 0.80
trace · Method · 0.80
ones · Method · 0.80
numpy · Method · 0.80
f · Function · 0.50
compile · Method · 0.45
zeros · Method · 0.45

Tested by

no test coverage detected
