(target="rocm")
| 34 | C = hipblas.matmul(A, B, dtype=out_dtype) |
| 35 | |
def verify(target="rocm"):
    """Compile the hipBLAS matmul for ``target``, run it, and check it against NumPy.

    Uses ``A``, ``B``, ``C``, ``n``, ``l``, ``m`` and ``rtol`` from the enclosing
    scope, where ``C`` is ``hipblas.matmul(A, B, dtype=out_dtype)``. Prints a
    skip message and returns early when ``tvm.testing`` reports ``target`` as
    not enabled, or when the hipBLAS extern function is not registered.

    Parameters
    ----------
    target : str
        Compilation target; it also selects the device the kernel runs on.
        NOTE(review): assumed to be a plain device-kind string such as "rocm".
    """
    # Check the target first: without a usable device there is nothing to run,
    # even when the extern symbol happens to be registered in this build.
    if not tvm.testing.device_enabled(target):
        print("skip because %s is not enabled" % target)
        return
    if not tvm.get_global_func("tvm.contrib.hipblas.matmul", True):
        print("skip because extern function is not available")
        return
    # Derive the device from `target` instead of hard-coding ROCm, so the
    # compiled module and its buffers always live on the same device.
    dev = tvm.device(target, 0)
    f = tvm.compile(te.create_prim_func([A, B, C]), target=target)
    # Keep host copies of the inputs so the reference needs no device round trip.
    a_np = np.random.uniform(0, 128, size=(n, l)).astype(A.dtype)
    b_np = np.random.uniform(0, 128, size=(l, m)).astype(B.dtype)
    a = tvm.runtime.tensor(a_np, dev)
    b = tvm.runtime.tensor(b_np, dev)
    c = tvm.runtime.tensor(np.zeros((n, m), dtype=C.dtype), dev)
    f(a, b, c)
    # Reference result: NumPy matrix product computed in the output dtype.
    tvm.testing.assert_allclose(
        c.numpy(), np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=rtol
    )
| 50 | verify() |
| 51 |
no test coverage detected
searching dependent graphs…