
Function test_callback

tests/python/disco/test_callback.py:35–130 in github.com/apache/tvm

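Simulate lazy loading of parameters in a callback. The test mimics the output of lazy parameter loading: a `transform_params` function that, instead of receiving its parameters up front, accepts a callback (`fget_item`) that loads each parameter on demand.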
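As a rough illustration of the callback pattern, here is a pure-Python analogue. It mirrors the `fget_item(name, index)` signature and the per-rank slicing used in the test, but the helper `make_callback`, the nested-list `Matrix` stand-in for tensors, and the example data are all hypothetical; none of this is TVM's API.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a tensor: a list of rows.
Matrix = List[List[float]]


def make_callback(store: Dict[str, Matrix], log: List[str]) -> Callable[[str, int], Matrix]:
    """Hypothetical factory, loosely analogous to `callback_maker(worker_device)`.

    The returned closure fetches one named parameter on demand and records
    each request, which makes the lazy, one-at-a-time loading visible.
    """

    def fget_item(name: str, index: int) -> Matrix:
        log.append(f"load {name} (index {index})")
        return store[name]

    return fget_item


def transform_params(rank: int, fget_item: Callable[[str, int], Matrix]) -> Tuple[Matrix, Matrix]:
    """Pure-Python analogue of the Relax `transform_params` in the test."""
    a = fget_item("A", 0)
    a_shard = a[rank * 2 : (rank + 1) * 2]  # rows [2*rank, 2*rank + 2) of a 4x4 matrix

    b = fget_item("B", 1)
    b_shard = [row[rank : rank + 1] for row in b]  # column `rank` of a 2x2 matrix

    return a_shard, b_shard


# Example data (hypothetical): A holds 0..15 row by row, B holds 0.0..3.0.
store = {
    "A": [[4 * r + c for c in range(4)] for r in range(4)],
    "B": [[0.0, 1.0], [2.0, 3.0]],
}
log: List[str] = []
fget_item = make_callback(store, log)
a1, b1 = transform_params(1, fget_item)
print(a1)   # [[8, 9, 10, 11], [12, 13, 14, 15]]
print(b1)   # [[1.0], [3.0]]
print(log)  # ['load A (index 0)', 'load B (index 1)']
```

The factory-plus-closure shape loosely reflects the test's `callback_maker(worker_device)` call, which yields a per-worker `fget_item`; in this sketch a plain dict stands in for whatever storage a real loader would read from.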

test_callback()


```python
import pathlib
import tempfile

import pytest

import tvm
import tvm.runtime.disco
from tvm.script import relax as R
from tvm.script import tir as T

# NOTE: the import that provides `env` (used in the skip condition below)
# is not shown in this excerpt.


@pytest.mark.gpu
@pytest.mark.skipif(not env.has_nccl(), reason="need nccl")
def test_callback():
    """Simulate lazy loading of parameters in a callback

    The output of a lazy parameter loading, which would accept a
    callback to load the parameters.
    """

    @R.function
    def transform_params(
        rank_arg: R.Prim(value="rank"),
        fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
    ):
        # `rank` is the symbolic variable bound by `rank_arg`.
        rank = T.int64()

        # Load A through the callback, then keep rows [2*rank, 2*rank + 2).
        A = fget_item(R.str("A"), R.prim_value(0))
        A = R.match_cast(A, R.Tensor([4, 4], "int32"))
        A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2])

        # Load B through the callback, then keep column `rank`.
        B = fget_item(R.str("B"), R.prim_value(1))
        B = R.match_cast(B, R.Tensor([2, 2], "float32"))
        B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1])

        return (A, B)

    # Lower Relax ops to TIR, then give the kernels a default GPU schedule.
    pipeline = tvm.ir.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.s_tir.dlight.ApplyDefaultSchedule(tvm.s_tir.dlight.gpu.Fallback()),
        ],
        name="pipeline",
    )

    with tvm.target.Target("cuda"):
        mod = tvm.IRModule.from_expr(transform_params)
        mod = pipeline(mod)
        built = tvm.compile(mod, "cuda")

    num_shards = 2

    # One worker process per shard, communicating through NCCL.
    session = tvm.runtime.disco.ProcessSession(num_workers=num_shards)
    session.import_python_module("tvm.exec.disco_worker")
    session.init_ccl("nccl", *range(num_shards))

    # Each worker builds its own `fget_item` callback for its own device.
    worker_device = session.get_global_func("runtime.disco.device")()
    worker_id = session.get_global_func("runtime.disco.worker_rank")()
    callback_maker = session.get_global_func("tests.disco.test_callback")
    fget_item = callback_maker(worker_device)

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = pathlib.Path(temp_dir)

        # TODO(Lunderberg): Update `disco.Session.load_vm_module` to
        # allow a `tvm.runtime.Module` argument. This would avoid the
        # need for a temporary file.
        shlib_path = temp_dir.joinpath("libtemp.so")
        built.export_library(shlib_path)
        vm = session.load_vm_module(shlib_path.as_posix())
        transform_params = vm["transform_params"]

        # ... (the remainder of the test, through line 130 of the source
        # file, is not shown in this excerpt)
```
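The two `R.strided_slice` calls split each parameter across the `num_shards = 2` workers: A along axis 0 (two of its four rows per rank) and B along axis 1 (one of its two columns per rank). A NumPy sketch of the shards each rank should end up holding follows. The `np.arange` values are hypothetical, since the data returned by the registered `tests.disco.test_callback` callback is not shown in this excerpt, and the visible code does not show how the test checks its results.

```python
import numpy as np

NUM_SHARDS = 2  # matches `num_shards` in the test

# Hypothetical parameter values with the shapes and dtypes from the match_cast calls.
A = np.arange(16, dtype="int32").reshape(4, 4)
B = np.arange(4, dtype="float32").reshape(2, 2)


def expected_shard(rank: int):
    """NumPy equivalent of the two R.strided_slice calls for one rank."""
    a_shard = A[rank * 2 : (rank + 1) * 2, :]  # axes=[0]: 2 of the 4 rows
    b_shard = B[:, rank * 1 : (rank + 1) * 1]  # axes=[1]: 1 of the 2 columns
    return a_shard, b_shard


shards = [expected_shard(rank) for rank in range(NUM_SHARDS)]
for rank, (a_shard, b_shard) in enumerate(shards):
    print(rank, a_shard.shape, b_shard.shape)  # (2, 4) and (2, 1) on every rank

# Concatenating the shards along their split axes recovers the full parameters.
A_full = np.concatenate([a for a, _ in shards], axis=0)
B_full = np.concatenate([b for _, b in shards], axis=1)
assert np.array_equal(A_full, A) and np.array_equal(B_full, B)
```

Because every rank slices with the same symbolic `rank`, a single compiled `transform_params` serves all workers; only the runtime value bound to `rank_arg` differs.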

Callers

nothing calls this directly

Calls (12)

transform_params · Function · 0.85
from_expr · Method · 0.80
import_python_module · Method · 0.80
init_ccl · Method · 0.80
get_global_func · Method · 0.80
load_vm_module · Method · 0.80
debug_get_from_remote · Method · 0.80
numpy · Method · 0.80
compile · Method · 0.45
export_library · Method · 0.45
cuda · Method · 0.45
cpu · Method · 0.45

Tested by

no test coverage detected
