(dev, n, value, dtype)
| 29 | @pytest.mark.skipif(not env.has_rocm(), reason="need rocm") |
| 30 | def test_rocm_inf_nan(): |
| 31 | def check_inf_nan(dev, n, value, dtype): |
| 32 | @I.ir_module(s_tir=True) |
| 33 | class Module: |
| 34 | @T.prim_func(s_tir=True) |
| 35 | def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): |
| 36 | T.func_attr({"tirx.noalias": True}) |
| 37 | for i_0 in T.thread_binding(1, thread="blockIdx.x"): |
| 38 | for i_1 in T.thread_binding(128, thread="threadIdx.x"): |
| 39 | with T.sblock("C"): |
| 40 | v_i = T.axis.spatial(1, i_0 * 128 + i_1) |
| 41 | T.where(i_0 * 128 + i_1 < 1) |
| 42 | T.reads() |
| 43 | T.writes(C[v_i]) |
| 44 | C[v_i] = T.Cast(dtype, value) |
| 45 | |
| 46 | fun = tvm.compile(Module, "rocm") |
| 47 | a = tvm.runtime.empty((n,), dtype, dev) |
| 48 | c = tvm.runtime.empty((n,), dtype, dev) |
| 49 | # Only need to test compiling here |
| 50 | fun(a, c) |
| 51 | |
| 52 | dev = tvm.rocm(0) |
| 53 |
no test coverage detected
searching dependent graphs…