()
| 22 | |
| 23 | |
def test_trace_default_action():
    """Smoke-test ``tvm.tirx.trace`` using its default trace action.

    Builds a 2x2x2 float32 elementwise compute in which every output element
    is produced by ``tvm.tirx.trace`` called with the loop indices and the
    corresponding input element (no explicit trace action is passed). It then
    compiles the PrimFunc for the ``llvm`` target and runs it once on an
    all-ones input and an all-zeros output buffer.

    The test passes as long as compilation and execution do not raise; the
    contents of the output buffer are not checked.
    """
    size = 2
    shape = (size, size, size)

    src = te.placeholder(shape, name="X", dtype="float32")
    # The lambda parameter names (i, j, k) are kept as-is because te.compute
    # uses them to name the iteration axes.
    traced = te.compute(
        src.shape,
        lambda i, j, k: tvm.tirx.trace([i, j, k, src[i][j][k]]),
    )

    func = tvm.compile(te.create_prim_func([src, traced]), target="llvm")

    src_nd = tvm.runtime.tensor(np.ones(shape, dtype=src.dtype))
    out_nd = tvm.runtime.tensor(np.zeros(shape, dtype=traced.dtype))
    # NOTE(review): nothing is asserted about out_nd; if the default trace
    # action is confirmed to pass the traced value through, consider
    # asserting out_nd equals src_nd here.
    func(src_nd, out_nd)
| 32 | |
| 33 | |
| 34 | def test_trace_expr_assign(): |