MCPcopy Index your code
hub / github.com/apache/tvm / test_vm_multi_func

Function test_vm_multi_func

tests/python/disco/test_session.py:270–334  ·  view source on GitHub ↗
(session_kind)

Source from the content-addressed store, hash-verified

268
269@pytest.mark.parametrize("session_kind", _all_session_kinds)
270def test_vm_multi_func(session_kind):
271 num_workers = 4
272 sess = session_kind(num_workers=num_workers)
273
274 # pylint: disable=invalid-name
275 @I.ir_module(s_tir=True)
276 class TestMod:
277 @T.prim_func(s_tir=True)
278 def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
279 for i, j in T.grid(16, 8):
280 with T.sblock("t1"):
281 vi, vj = T.axis.remap("SS", [i, j])
282 B[vi, vj] = A[vj, vi]
283
284 @T.prim_func(s_tir=True)
285 def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")):
286 for i, j in T.grid(8, 16):
287 with T.sblock("t2"):
288 vi, vj = T.axis.remap("SS", [i, j])
289 B[vi, vj] = A[vj, vi]
290
291 @R.function
292 def transpose_1(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor(
293 (16, 8), dtype="float32"
294 ):
295 R.func_attr({"global_symbol": "transpose_1"})
296 cls = TestMod
297 with R.dataflow():
298 B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32"))
299 R.output(B)
300 return B
301
302 @R.function
303 def transpose_2(A: R.Tensor((16, 8), dtype="float32")) -> R.Tensor(
304 (8, 16), dtype="float32"
305 ):
306 R.func_attr({"global_symbol": "transpose_2"})
307 cls = TestMod
308 with R.dataflow():
309 B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32"))
310 R.output(B)
311 return B
312
313 # pylint: enable=invalid-name
314 with tempfile.TemporaryDirectory() as tmpdir:
315 path = tmpdir + "/test.so"
316 device = tvm.cpu()
317 x_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
318 y_np = x_np.transpose()
319
320 tvm.compile(TestMod, target="llvm").export_library(path)
321 mod = sess.load_vm_module(path, device=device)
322
323 x_disc = _numpy_to_worker_0(sess, x_np, device=device)
324 y_disc = mod["transpose_1"](x_disc)
325 z_disc = mod["transpose_2"](y_disc)
326 y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype)
327 z_nd = _numpy_from_worker_0(sess, z_disc, shape=x_np.shape, dtype=x_np.dtype)

Callers

nothing calls this directly

Calls 11

_numpy_to_worker_0 · Function · 0.85
_numpy_from_worker_0 · Function · 0.85
load_vm_module · Method · 0.80
_sync_worker · Method · 0.80
cpu · Method · 0.45
reshape · Method · 0.45
astype · Method · 0.45
arange · Method · 0.45
transpose · Method · 0.45
export_library · Method · 0.45
compile · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…