github.com/apache/tvm

Function test_rocm_inf_nan()

tests/python/codegen/test_target_codegen_rocm.py:30–59

@pytest.mark.gpu
@pytest.mark.skipif(not env.has_rocm(), reason="need rocm")
def test_rocm_inf_nan():
    def check_inf_nan(dev, n, value, dtype):
        @I.ir_module(s_tir=True)
        class Module:
            @T.prim_func(s_tir=True)
            def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)):
                T.func_attr({"tirx.noalias": True})
                for i_0 in T.thread_binding(1, thread="blockIdx.x"):
                    for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                        with T.sblock("C"):
                            v_i = T.axis.spatial(1, i_0 * 128 + i_1)
                            T.where(i_0 * 128 + i_1 < 1)
                            T.reads()
                            T.writes(C[v_i])
                            C[v_i] = T.Cast(dtype, value)

        fun = tvm.compile(Module, "rocm")
        a = tvm.runtime.empty((n,), dtype, dev)
        c = tvm.runtime.empty((n,), dtype, dev)
        # Only need to test compiling here
        fun(a, c)

    dev = tvm.rocm(0)

    check_inf_nan(dev, 1, -float("inf"), "float32")
    check_inf_nan(dev, 1, -float("inf"), "float64")
    check_inf_nan(dev, 1, float("inf"), "float32")
    check_inf_nan(dev, 1, float("inf"), "float64")
    check_inf_nan(dev, 1, float("nan"), "float32")
    check_inf_nan(dev, 1, float("nan"), "float64")

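The test only checks that the ROCm backend can compile and launch a kernel whose body stores a constant ±inf or NaN; it never reads `c` back. As a TVM-independent illustration (not part of the test or of TVM's API), the minimal sketch below uses Python's standard `struct` module to show the IEEE-754 bit patterns that a correctly emitted float32/float64 constant must encode for the six (value, dtype) pairs above. NaN has many valid encodings (sign and payload vary by platform), so the sketch checks its exponent and mantissa fields rather than one fixed pattern. The names `CASES`, `LAYOUT`, `bit_pattern`, `is_nan_bits`, and `round_trips` are hypothetical helpers introduced only for this example.

```python
import math
import struct

# (value, dtype) pairs mirroring the check_inf_nan calls in the test.
CASES = [
    (-math.inf, "float32"),
    (-math.inf, "float64"),
    (math.inf, "float32"),
    (math.inf, "float64"),
    (math.nan, "float32"),
    (math.nan, "float64"),
]

# dtype -> (struct float code, same-width unsigned int code,
#           exponent width, mantissa width) for IEEE-754 binary32/binary64.
LAYOUT = {
    "float32": (">f", ">I", 8, 23),
    "float64": (">d", ">Q", 11, 52),
}


def bit_pattern(value: float, dtype: str) -> int:
    """Raw IEEE-754 bits of `value` when stored as `dtype`."""
    float_fmt, int_fmt, _, _ = LAYOUT[dtype]
    return struct.unpack(int_fmt, struct.pack(float_fmt, value))[0]


def is_nan_bits(bits: int, dtype: str) -> bool:
    """A NaN has an all-ones exponent and a non-zero mantissa (any sign)."""
    _, _, exp_width, man_width = LAYOUT[dtype]
    exponent = (bits >> man_width) & ((1 << exp_width) - 1)
    mantissa = bits & ((1 << man_width) - 1)
    return exponent == (1 << exp_width) - 1 and mantissa != 0


def round_trips(value: float, dtype: str) -> bool:
    """Store `value` as `dtype`, load it back, and compare."""
    float_fmt = LAYOUT[dtype][0]
    back = struct.unpack(float_fmt, struct.pack(float_fmt, value))[0]
    # NaN never compares equal to itself, so compare by category instead.
    if math.isnan(value):
        return math.isnan(back)
    return back == value


for value, dtype in CASES:
    print(f"{dtype:8} {value!r:>5} -> {bit_pattern(value, dtype):#x}")
```

For the infinities this prints 0xff800000 and 0x7f800000 (float32) and 0xfff0000000000000 and 0x7ff0000000000000 (float64); the printed NaN patterns may differ in sign or payload between platforms, which is why `is_nan_bits` tests fields instead of exact values.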
Callers

nothing calls this directly

Calls (2)

rocm (method) · 0.80
check_inf_nan (function) · 0.70

Tested by

no test coverage detected
