(dtype, n)
| 63 | @pytest.mark.skipif(not env.has_rocm(), reason="need rocm") |
| 64 | def test_rocm_copy(): |
| 65 | def check_rocm(dtype, n): |
| 66 | dev = tvm.rocm(0) |
| 67 | a_np = np.random.uniform(size=(n,)).astype(dtype) |
| 68 | a = tvm.runtime.empty((n,), dtype, dev).copyfrom(a_np) |
| 69 | b_np = a.numpy() |
| 70 | tvm.testing.assert_allclose(a_np, b_np) |
| 71 | tvm.testing.assert_allclose(a_np, a.numpy()) |
| 72 | |
| 73 | for _ in range(100): |
| 74 | dtype = np.random.choice(["float32", "float16", "int8", "int32"]) |