(session_kind)
| 268 | |
| 269 | @pytest.mark.parametrize("session_kind", _all_session_kinds) |
| 270 | def test_vm_multi_func(session_kind): |
| 271 | num_workers = 4 |
| 272 | sess = session_kind(num_workers=num_workers) |
| 273 | |
| 274 | # pylint: disable=invalid-name |
| 275 | @I.ir_module(s_tir=True) |
| 276 | class TestMod: |
| 277 | @T.prim_func(s_tir=True) |
| 278 | def t1(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")): |
| 279 | for i, j in T.grid(16, 8): |
| 280 | with T.sblock("t1"): |
| 281 | vi, vj = T.axis.remap("SS", [i, j]) |
| 282 | B[vi, vj] = A[vj, vi] |
| 283 | |
| 284 | @T.prim_func(s_tir=True) |
| 285 | def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")): |
| 286 | for i, j in T.grid(8, 16): |
| 287 | with T.sblock("t2"): |
| 288 | vi, vj = T.axis.remap("SS", [i, j]) |
| 289 | B[vi, vj] = A[vj, vi] |
| 290 | |
| 291 | @R.function |
| 292 | def transpose_1(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor( |
| 293 | (16, 8), dtype="float32" |
| 294 | ): |
| 295 | R.func_attr({"global_symbol": "transpose_1"}) |
| 296 | cls = TestMod |
| 297 | with R.dataflow(): |
| 298 | B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32")) |
| 299 | R.output(B) |
| 300 | return B |
| 301 | |
| 302 | @R.function |
| 303 | def transpose_2(A: R.Tensor((16, 8), dtype="float32")) -> R.Tensor( |
| 304 | (8, 16), dtype="float32" |
| 305 | ): |
| 306 | R.func_attr({"global_symbol": "transpose_2"}) |
| 307 | cls = TestMod |
| 308 | with R.dataflow(): |
| 309 | B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32")) |
| 310 | R.output(B) |
| 311 | return B |
| 312 | |
| 313 | # pylint: enable=invalid-name |
| 314 | with tempfile.TemporaryDirectory() as tmpdir: |
| 315 | path = tmpdir + "/test.so" |
| 316 | device = tvm.cpu() |
| 317 | x_np = np.arange(8 * 16).astype("float32").reshape([8, 16]) |
| 318 | y_np = x_np.transpose() |
| 319 | |
| 320 | tvm.compile(TestMod, target="llvm").export_library(path) |
| 321 | mod = sess.load_vm_module(path, device=device) |
| 322 | |
| 323 | x_disc = _numpy_to_worker_0(sess, x_np, device=device) |
| 324 | y_disc = mod["transpose_1"](x_disc) |
| 325 | z_disc = mod["transpose_2"](y_disc) |
| 326 | y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype) |
| 327 | z_nd = _numpy_from_worker_0(sess, z_disc, shape=x_np.shape, dtype=x_np.dtype) |
nothing calls this directly
no test coverage detected
searching dependent graphs…