MCPcopy Index your code
hub / github.com/apache/tvm / test_rocm_vectorize_add

Function test_rocm_vectorize_add

tests/python/codegen/test_target_codegen_rocm.py:82–109  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

@pytest.mark.gpu
@pytest.mark.skipif(not env.has_rocm(), reason="need rocm")
def test_rocm_vectorize_add():
    """Compile and run a vector-typed elementwise add-one kernel on ROCm.

    Buffer elements use vector dtypes such as ``float32x2``, so each thread
    loads, adds and stores a whole vector. The device result is compared
    against the same computation performed on the host with NumPy.
    """

    def check_rocm(dtype, n, lanes):
        """Build ``B[i] = A[i] + broadcast(1)`` for ``n`` vector elements and verify it.

        Parameters
        ----------
        dtype : str
            Scalar element type, e.g. ``"float32"`` or ``"float16"``.
        n : int
            Number of vector elements in each buffer. Must be a multiple of
            ``threads_per_block``: the launch grid is exactly
            ``n // threads_per_block`` blocks and the kernel has no tail
            handling.
        lanes : int
            Vector width; the buffers use dtype ``f"{dtype}x{lanes}"``.
        """
        vec_dtype = f"{dtype}x{lanes}"
        threads_per_block = 4
        # A non-multiple n would leave the last n % threads_per_block output
        # elements unwritten, and the comparison below would then read
        # uninitialized device memory and fail confusingly.
        assert n % threads_per_block == 0, (
            f"n={n} must be a multiple of {threads_per_block}"
        )
        num_blocks = n // threads_per_block

        @I.ir_module(s_tir=True)
        class Module:
            @T.prim_func(s_tir=True)
            def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)):
                # NOTE(review): presumably declares that A and B do not alias.
                T.func_attr({"tirx.noalias": True})
                for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"):
                    for i_1 in T.thread_binding(threads_per_block, thread="threadIdx.x"):
                        with T.sblock("B"):
                            # Flat element index handled by this (block, thread) pair.
                            v_i = T.axis.spatial(n, i_0 * threads_per_block + i_1)
                            T.reads(A[v_i])
                            T.writes(B[v_i])
                            # Add a scalar 1 replicated across every lane of the vector.
                            B[v_i] = A[v_i] + T.Broadcast(T.Cast(dtype, 1), lanes)

        fun = tvm.compile(Module, target="rocm")

        dev = tvm.rocm(0)
        # Host data has shape (n, lanes): one row per vector element.
        a = tvm.runtime.empty((n,), vec_dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
        c = tvm.runtime.empty((n,), vec_dtype, dev)
        fun(a, c)
        tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)

    check_rocm("float32", 64, 2)
    check_rocm("float16", 64, 2)
110
111
112@pytest.mark.gpu

Callers

nothing calls this directly

Calls 1

check_rocm · Function · 0.85

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…