Repository: github.com/apache/tvm

Function verify_batch_matmul

tests/python/contrib/test_hipblas.py, lines 57–78
(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5)
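The signature does not say how the three shapes relate. The body allocates `c` with `Cshape` and compares it against `np.matmul(a, b)`, so `Cshape` has to be the batched product shape of `Ashape` and `Bshape`. A sketch of that rule follows, assuming 3-D shapes `(batch, rows, cols)` and no transposes. `batch_matmul_out_shape` is a hypothetical helper, not part of TVM, and it follows NumPy's rule of broadcasting a batch size of 1, which this page does not confirm that hipBLAS supports.

```python
import numpy as np


def batch_matmul_out_shape(a_shape, b_shape):
    """Hypothetical helper: the Cshape np.matmul yields for 3-D
    A (batch, M, K) and B (batch, K, N), with no transposes."""
    if len(a_shape) != 3 or len(b_shape) != 3:
        raise ValueError("expected 3-D shapes (batch, rows, cols)")
    batch_a, m, k_a = a_shape
    batch_b, k_b, n = b_shape
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: K={k_a} vs K={k_b}")
    # NumPy broadcasts a batch size of 1; otherwise the sizes must match.
    if batch_a != batch_b and 1 not in (batch_a, batch_b):
        raise ValueError(f"batch sizes {batch_a} and {batch_b} do not broadcast")
    batch = batch_b if batch_a == 1 else batch_a
    return (batch, m, n)


# Cross-check against NumPy itself on a few illustrative shape pairs.
for a_shape, b_shape in [((4, 3, 5), (4, 5, 2)), ((1, 3, 5), (4, 5, 2))]:
    expected = np.matmul(np.zeros(a_shape), np.zeros(b_shape)).shape
    assert batch_matmul_out_shape(a_shape, b_shape) == expected

print(batch_matmul_out_shape((4, 3, 5), (4, 5, 2)))  # (4, 3, 2)
```

Passing a `Cshape` that disagrees with this rule would make the final comparison fail on a shape mismatch rather than on values.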


```python
import numpy as np

import tvm
import tvm.testing
from tvm import te
from tvm.contrib import hipblas


def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5):
    # Symbolic inputs and the hipBLAS batched-matmul extern op.
    A = te.placeholder(Ashape, name="A", dtype=in_dtype)
    B = te.placeholder(Bshape, name="B", dtype=in_dtype)
    C = hipblas.batch_matmul(A, B, dtype=out_dtype)

    # Compile a PrimFunc over (A, B, C) for the ROCm target; tensors live on device 0.
    dev = tvm.rocm(0)
    f = tvm.compile(te.create_prim_func([A, B, C]), target="rocm")

    # Integer inputs: uniform in [1, 10), truncated by the cast. Float inputs: uniform in [0, 1).
    if "int" in in_dtype:
        a = tvm.runtime.tensor(np.random.uniform(1, 10, size=Ashape).astype(in_dtype), dev)
        b = tvm.runtime.tensor(np.random.uniform(1, 10, size=Bshape).astype(in_dtype), dev)
    else:
        a = tvm.runtime.tensor(np.random.uniform(size=Ashape).astype(A.dtype), dev)
        b = tvm.runtime.tensor(np.random.uniform(size=Bshape).astype(B.dtype), dev)

    # Zero-initialised output buffer, written in place by the compiled function.
    c = tvm.runtime.tensor(np.zeros(Cshape, dtype=C.dtype), dev)
    f(a, b, c)

    # Reference: cast both operands to the output dtype, then batched matmul in NumPy.
    tvm.testing.assert_allclose(
        c.numpy(),
        np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype),
        rtol=rtol,
    )
```
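The expected value casts both operands to `out_dtype` before calling `np.matmul`, and `np.matmul` on 3-D arrays multiplies batch by batch. The NumPy-only sketch below needs no ROCm device and checks both points. `reference_batch_matmul` and `loop_batch_matmul` are hypothetical helpers, the int8 inputs with int32 output are an assumed example of an integer case, and the shapes are illustrative.

```python
import numpy as np


def reference_batch_matmul(a, b, out_dtype):
    # Mirrors the expected value in verify_batch_matmul: cast both operands
    # to the output dtype first, then multiply batch-wise with np.matmul.
    return np.matmul(a.astype(out_dtype), b.astype(out_dtype)).astype(out_dtype)


def loop_batch_matmul(a, b, out_dtype):
    # Explicit per-batch loop, C[i] = A[i] @ B[i], for comparison.
    # Assumes equal batch sizes (no broadcasting).
    batch, m, _ = a.shape
    n = b.shape[2]
    c = np.zeros((batch, m, n), dtype=out_dtype)
    for i in range(batch):
        c[i] = a[i].astype(out_dtype) @ b[i].astype(out_dtype)
    return c


# Integer inputs drawn like the test's integer branch: uniform(1, 10)
# truncated to 1..9. Illustrative shapes: batch=4, M=3, K=5, N=2.
rng = np.random.default_rng(0)
a = rng.uniform(1, 10, size=(4, 3, 5)).astype("int8")
b = rng.uniform(1, 10, size=(4, 5, 2)).astype("int8")

ref = reference_batch_matmul(a, b, "int32")
print(ref.shape, ref.dtype)  # (4, 3, 2) int32
assert np.array_equal(ref, loop_batch_matmul(a, b, "int32"))

# Why the cast comes first: with K=16 and every entry equal to 9, each
# output element is 16 * 9 * 9 = 1296, which int8 cannot represent.
a9 = np.full((2, 3, 16), 9, dtype="int8")
b9 = np.full((2, 16, 4), 9, dtype="int8")
cast_first = reference_batch_matmul(a9, b9, "int32")
int8_product = np.matmul(a9, b9).astype("int32")  # computed in int8, so it wraps
print(cast_first[0, 0, 0])  # 1296
assert not np.array_equal(cast_first, int8_product)
```

Widening before the multiply keeps the accumulation in the output type, which is presumably why the reference is written this way rather than casting only the final product.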

Callers (1)

test_batch_matmul (function) · 0.70

Calls (10)

placeholder (method) · 0.80
batch_matmul (method) · 0.80
rocm (method) · 0.80
uniform (method) · 0.80
numpy (method) · 0.80
f (function) · 0.50
compile (method) · 0.45
astype (method) · 0.45
zeros (method) · 0.45
matmul (method) · 0.45

Tested by

no test coverage detected
