Simulate lazy loading of parameters in a callback. The test mimics the output of a lazy parameter loading, which would accept a callback to load the parameters.
()
| 33 | @pytest.mark.gpu |
| 34 | @pytest.mark.skipif(not env.has_nccl(), reason="need nccl") |
| 35 | def test_callback(): |
| 36 | """Simulate lazy loading of parameters in a callback |
| 37 | |
| 38 | The output of a lazy parameter loading, which would accept a |
| 39 | callback to load the parameters. |
| 40 | """ |
| 41 | |
| 42 | @R.function |
| 43 | def transform_params( |
| 44 | rank_arg: R.Prim(value="rank"), |
| 45 | fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object), |
| 46 | ): |
| 47 | rank = T.int64() |
| 48 | |
| 49 | A = fget_item(R.str("A"), R.prim_value(0)) |
| 50 | A = R.match_cast(A, R.Tensor([4, 4], "int32")) |
| 51 | A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2]) |
| 52 | |
| 53 | B = fget_item(R.str("B"), R.prim_value(1)) |
| 54 | B = R.match_cast(B, R.Tensor([2, 2], "float32")) |
| 55 | B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1]) |
| 56 | |
| 57 | return (A, B) |
| 58 | |
| 59 | pipeline = tvm.ir.transform.Sequential( |
| 60 | [ |
| 61 | tvm.relax.transform.LegalizeOps(), |
| 62 | tvm.s_tir.dlight.ApplyDefaultSchedule(tvm.s_tir.dlight.gpu.Fallback()), |
| 63 | ], |
| 64 | name="pipeline", |
| 65 | ) |
| 66 | |
| 67 | with tvm.target.Target("cuda"): |
| 68 | mod = tvm.IRModule.from_expr(transform_params) |
| 69 | mod = pipeline(mod) |
| 70 | built = tvm.compile(mod, "cuda") |
| 71 | |
| 72 | num_shards = 2 |
| 73 | |
| 74 | session = tvm.runtime.disco.ProcessSession(num_workers=num_shards) |
| 75 | session.import_python_module("tvm.exec.disco_worker") |
| 76 | session.init_ccl("nccl", *range(num_shards)) |
| 77 | |
| 78 | worker_device = session.get_global_func("runtime.disco.device")() |
| 79 | worker_id = session.get_global_func("runtime.disco.worker_rank")() |
| 80 | callback_maker = session.get_global_func("tests.disco.test_callback") |
| 81 | fget_item = callback_maker(worker_device) |
| 82 | |
| 83 | with tempfile.TemporaryDirectory() as temp_dir: |
| 84 | temp_dir = pathlib.Path(temp_dir) |
| 85 | |
| 86 | # TODO(Lunderberg): Update `disco.Session.load_vm_module` to |
| 87 | # allow a `tvm.runtime.Module` argument. This would avoid the |
| 88 | # need for a temporary file. |
| 89 | shlib_path = temp_dir.joinpath("libtemp.so") |
| 90 | built.export_library(shlib_path) |
| 91 | vm = session.load_vm_module(shlib_path.as_posix()) |
| 92 | transform_params = vm["transform_params"] |
nothing calls this directly
no test coverage detected
searching dependent graphs…