@pytest.mark.gpu
@pytest.mark.skipif(not env.has_rocm(), reason="need rocm")
def test_rocm_vectorized_exp():
    """Build a 4-lane vectorized ``exp2`` kernel for ROCm and check it against NumPy."""

    @T.prim_func(s_tir=True)
    def func(
        A_handle: T.handle,
        B_handle: T.handle,
    ):
        # Both arguments are bound to 4-element float32 buffers.
        A = T.match_buffer(A_handle, (4,), dtype="float32")
        B = T.match_buffer(B_handle, (4,), dtype="float32")

        # Extent-1 bindings on blockIdx.x and threadIdx.x; the four
        # elements are covered by the vectorized loop below.
        for bx in T.thread_binding(1, thread="blockIdx.x"):
            for tx in T.thread_binding(1, thread="threadIdx.x"):
                with T.sblock("test"):
                    for lane in T.vectorized(0, 4):
                        B[lane] = T.exp2(A[lane])

    rocm_kernel = tvm.compile(func, target="rocm")
    device = tvm.rocm(0)

    # Input is all ones, output starts as all zeros.
    a = tvm.runtime.tensor(np.ones((4,), dtype="float32"), device)
    b = tvm.runtime.tensor(np.zeros((4,), dtype="float32"), device)

    rocm_kernel(a, b)

    # Reference result computed by NumPy from the same input tensor.
    expected = np.exp2(a.numpy())
    tvm.testing.assert_allclose(b.numpy(), expected)
| 162 | |
| 163 | |
| 164 | @pytest.mark.gpu |