MCPcopy
hub / github.com/apache/tvm / verify_batch_matmul

Function verify_batch_matmul

tests/python/contrib/test_cblas.py:198–265  ·  view source on GitHub ↗

Tests the matmul op on batched matrices.

(
    batch_a,
    batch_b,
    matrix_m,
    matrix_l,
    matrix_n,
    lib,
    transa=False,
    transb=False,
    dtype="float32",
)

Source from the content-addressed store, hash-verified

196
197
198def verify_batch_matmul(
199 batch_a,
200 batch_b,
201 matrix_m,
202 matrix_l,
203 matrix_n,
204 lib,
205 transa=False,
206 transb=False,
207 dtype="float32",
208):
209 """Tests matmul op where matrices are in batch"""
210 batch = max(batch_a, batch_b)
211 ashape = (batch_a, matrix_l, matrix_n) if transa else (batch_a, matrix_n, matrix_l)
212 bshape = (batch_b, matrix_m, matrix_l) if transb else (batch_b, matrix_l, matrix_m)
213 input1_data = te.placeholder(ashape, name="input1_data", dtype=dtype)
214 input2_data = te.placeholder(bshape, name="input2_data", dtype=dtype)
215 matmul_result = lib.batch_matmul(input1_data, input2_data, transa, transb)
216 final_result = te.compute(
217 matmul_result.shape, lambda k, i, j: matmul_result[k, i, j], name="final_result"
218 )
219
220 def get_numpy(a, b, transa, transb):
221 if transa:
222 a = a.transpose(0, 2, 1)
223 if not transb:
224 b = b.transpose(0, 2, 1)
225 return tvm.topi.testing.batch_matmul(a, b)
226
227 def compiling(f, name="test_batch_matmul", ext=".so"):
228 path = name + ext
229 f.export_library(path)
230 mod = tvm.runtime.load_module(path)
231 f = mod[name]
232 return f
233
234 def verify(target="llvm"):
235 if not tvm.testing.device_enabled(target):
236 print(f"skip because {target} is not enabled...")
237 return
238 if not tvm.get_global_func(lib.__name__ + ".matmul", True):
239 print("skip because extern function is not available")
240 return
241 dev = tvm.cpu(0)
242 name = "test_batch_matmul"
243 f = tvm.compile(
244 te.create_prim_func([input1_data, input2_data, final_result]), target=target
245 )
246 if target == "c":
247 f = compiling(f, name)
248 matrix_input1 = tvm.runtime.tensor(
249 np.random.uniform(size=ashape).astype(input1_data.dtype), dev
250 )
251 matrix_input2 = tvm.runtime.tensor(
252 np.random.uniform(size=bshape).astype(input2_data.dtype), dev
253 )
254 matrix_result = tvm.runtime.tensor(
255 np.zeros((batch, matrix_n, matrix_m), dtype=final_result.dtype), dev

Callers 1

test_batch_matmul — Function · 0.70

Calls 4

placeholder — Method · 0.80
batch_matmul — Method · 0.80
verify — Function · 0.70
max — Function · 0.50

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…