Tests the matmul op on batched matrices
(
batch_a,
batch_b,
matrix_m,
matrix_l,
matrix_n,
lib,
transa=False,
transb=False,
dtype="float32",
)
| 196 | |
| 197 | |
| 198 | def verify_batch_matmul( |
| 199 | batch_a, |
| 200 | batch_b, |
| 201 | matrix_m, |
| 202 | matrix_l, |
| 203 | matrix_n, |
| 204 | lib, |
| 205 | transa=False, |
| 206 | transb=False, |
| 207 | dtype="float32", |
| 208 | ): |
| 209 | """Tests matmul op where matrices are in batch""" |
| 210 | batch = max(batch_a, batch_b) |
| 211 | ashape = (batch_a, matrix_l, matrix_n) if transa else (batch_a, matrix_n, matrix_l) |
| 212 | bshape = (batch_b, matrix_m, matrix_l) if transb else (batch_b, matrix_l, matrix_m) |
| 213 | input1_data = te.placeholder(ashape, name="input1_data", dtype=dtype) |
| 214 | input2_data = te.placeholder(bshape, name="input2_data", dtype=dtype) |
| 215 | matmul_result = lib.batch_matmul(input1_data, input2_data, transa, transb) |
| 216 | final_result = te.compute( |
| 217 | matmul_result.shape, lambda k, i, j: matmul_result[k, i, j], name="final_result" |
| 218 | ) |
| 219 | |
| 220 | def get_numpy(a, b, transa, transb): |
| 221 | if transa: |
| 222 | a = a.transpose(0, 2, 1) |
| 223 | if not transb: |
| 224 | b = b.transpose(0, 2, 1) |
| 225 | return tvm.topi.testing.batch_matmul(a, b) |
| 226 | |
| 227 | def compiling(f, name="test_batch_matmul", ext=".so"): |
| 228 | path = name + ext |
| 229 | f.export_library(path) |
| 230 | mod = tvm.runtime.load_module(path) |
| 231 | f = mod[name] |
| 232 | return f |
| 233 | |
| 234 | def verify(target="llvm"): |
| 235 | if not tvm.testing.device_enabled(target): |
| 236 | print(f"skip because {target} is not enabled...") |
| 237 | return |
| 238 | if not tvm.get_global_func(lib.__name__ + ".matmul", True): |
| 239 | print("skip because extern function is not available") |
| 240 | return |
| 241 | dev = tvm.cpu(0) |
| 242 | name = "test_batch_matmul" |
| 243 | f = tvm.compile( |
| 244 | te.create_prim_func([input1_data, input2_data, final_result]), target=target |
| 245 | ) |
| 246 | if target == "c": |
| 247 | f = compiling(f, name) |
| 248 | matrix_input1 = tvm.runtime.tensor( |
| 249 | np.random.uniform(size=ashape).astype(input1_data.dtype), dev |
| 250 | ) |
| 251 | matrix_input2 = tvm.runtime.tensor( |
| 252 | np.random.uniform(size=bshape).astype(input2_data.dtype), dev |
| 253 | ) |
| 254 | matrix_result = tvm.runtime.tensor( |
| 255 | np.zeros((batch, matrix_n, matrix_m), dtype=final_result.dtype), dev |
no test coverage detected
searching dependent graphs…