()
| 61 | |
| 62 | |
def test_trace_expr_sum_generated():
    """Check that adding two traced loads yields the plain elementwise sum.

    A no-op Python callback is registered as ``tvm.tirx.trace_callback3`` and
    used as the trace action for both operands of ``a + b``. For every dtype
    in the list below, the compiled kernel's output must equal ``x + y``
    computed by NumPy. In other words, wrapping a load in ``tvm.tirx.trace``
    must not change the value that feeds the surrounding expression.
    """

    # override=True: TVM refuses to re-register an existing global function
    # name by default, which would make this test fail if it is executed a
    # second time in the same process.
    @tvm.register_global_func("tvm.tirx.trace_callback3", override=True)
    def trace_buffer(x):
        # Intentionally a no-op; only the pass-through of the traced value
        # is under test.
        return

    def check_expr_sum(dtype):
        n = 4
        a = te.placeholder((n, n, n), name="a", dtype=dtype)
        b = te.placeholder((n, n, n), name="b", dtype=dtype)
        c = te.compute(
            a.shape,
            lambda i, j, k: (
                tvm.tirx.trace([a[i][j][k]], "tvm.tirx.trace_callback3")
                + tvm.tirx.trace([b[i][j][k]], "tvm.tirx.trace_callback3")
            ),
        )
        f = tvm.compile(te.create_prim_func([a, b, c]))
        # np.ones already returns a freshly allocated array; wrapping it in
        # np.array only made a redundant copy.
        xnd = tvm.runtime.tensor(np.ones((n, n, n), dtype=a.dtype))
        ynd = tvm.runtime.tensor(np.ones((n, n, n), dtype=b.dtype))
        znd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=c.dtype))
        f(xnd, ynd, znd)
        assert np.array_equal(znd.numpy(), xnd.numpy() + ynd.numpy())

    for t in ["float64", "float32", "int64", "int32"]:
        check_expr_sum(t)
| 88 | |
| 89 | |
| 90 | def test_trace_expr_sum_args(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…