(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5)
| 55 | |
| 56 | |
def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5):
    """Run ``hipblas.batch_matmul`` on ROCm and compare it with ``np.matmul``.

    Two placeholders of shapes ``Ashape`` and ``Bshape`` (dtype ``in_dtype``)
    are fed to ``hipblas.batch_matmul`` with result dtype ``out_dtype``. The
    resulting prim func is compiled for the ``"rocm"`` target and executed on
    ``tvm.rocm(0)``.

    Parameters
    ----------
    Ashape, Bshape : shape of the first / second operand.
    Cshape : shape used to allocate the zero-initialised output buffer.
    in_dtype : dtype string of both operands.
    out_dtype : dtype passed to ``hipblas.batch_matmul`` for the result.
    rtol : relative tolerance forwarded to ``tvm.testing.assert_allclose``.

    Raises
    ------
    Whatever ``tvm.testing.assert_allclose`` raises when the device result
    and the NumPy reference disagree beyond ``rtol``.
    """
    lhs_ph = te.placeholder(Ashape, name="A", dtype=in_dtype)
    rhs_ph = te.placeholder(Bshape, name="B", dtype=in_dtype)
    out_ph = hipblas.batch_matmul(lhs_ph, rhs_ph, dtype=out_dtype)

    device = tvm.rocm(0)
    kernel = tvm.compile(te.create_prim_func([lhs_ph, rhs_ph, out_ph]), target="rocm")

    # Host-side random operands. When the dtype name contains "int" the
    # samples are drawn from uniform(1, 10) before the cast; otherwise the
    # default np.random.uniform range is used.
    if "int" in in_dtype:
        lhs_host = np.random.uniform(1, 10, size=Ashape).astype(in_dtype)
        rhs_host = np.random.uniform(1, 10, size=Bshape).astype(in_dtype)
    else:
        lhs_host = np.random.uniform(size=Ashape).astype(lhs_ph.dtype)
        rhs_host = np.random.uniform(size=Bshape).astype(rhs_ph.dtype)

    lhs_dev = tvm.runtime.tensor(lhs_host, device)
    rhs_dev = tvm.runtime.tensor(rhs_host, device)
    out_dev = tvm.runtime.tensor(np.zeros(Cshape, dtype=out_ph.dtype), device)

    kernel(lhs_dev, rhs_dev, out_dev)

    # Reference: cast both operands to the output dtype first, multiply with
    # NumPy, and cast the product to the output dtype as well.
    result_dtype = out_ph.dtype
    expected = np.matmul(
        lhs_dev.numpy().astype(result_dtype),
        rhs_dev.numpy().astype(result_dtype),
    ).astype(result_dtype)
    tvm.testing.assert_allclose(out_dev.numpy(), expected, rtol=rtol)
| 79 | |
| 80 | |
| 81 | @pytest.mark.gpu |
no test coverage detected
searching dependent graphs…