MCP · Copy · Index your code
hub / github.com/apache/tvm / _kernel_compile

Function _kernel_compile

tests/python/disco/test_nvshmem.py:282–363  ·  view source on GitHub ↗

Compile and run a kernel that calls NVSHMEM functions. Runs in a fresh process, so setting the env var is safe.

Signature: `_kernel_compile(compile_mode)`

Source from the content-addressed store, hash-verified

280
281
282def _kernel_compile(compile_mode):
283 """Compile and run a kernel that calls NVSHMEM functions.
284
285 Runs in a fresh process, so setting the env var is safe.
286 """
287 os.environ["TVM_CUDA_COMPILE_MODE"] = compile_mode
288
289 num_workers = 2
290 sess = di.ProcessSession(num_workers=num_workers)
291
292 f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
293 uid = f_init_nvshmem_uid()
294 init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
295 init_dfunc(uid, num_workers, 0)
296 sess.sync_worker_0()
297
298 try:
299
300 @I.ir_module(s_tir=True)
301 class NvshmemQueryModule:
302 @T.prim_func(s_tir=True)
303 def query_pe(
304 my_pe_out: T.Buffer((1,), "int32"),
305 n_pes_out: T.Buffer((1,), "int32"),
306 ):
307 with T.sblock("root"):
308 T.reads()
309 T.writes(my_pe_out[0:1], n_pes_out[0:1])
310 T.call_kernel(
311 NVSHMEM_QUERY_KERNEL_SOURCE,
312 ((1,), (1,)), # grid=(1,), block=(1,)
313 my_pe_out.data,
314 n_pes_out.data,
315 kernel_name="nvshmem_query_kernel",
316 )
317
318 @R.function
319 def main() -> R.Tuple(R.Tensor((1,), "int32"), R.Tensor((1,), "int32")):
320 cls = NvshmemQueryModule
321 with R.dataflow():
322 my_pe = R.call_tir(
323 cls.query_pe,
324 (),
325 out_sinfo=[
326 R.Tensor((1,), "int32"),
327 R.Tensor((1,), "int32"),
328 ],
329 )
330 R.output(my_pe)
331 return my_pe
332
333 tmpdir = tempfile.mkdtemp()
334 try:
335 path = tmpdir + "/test_nvshmem_kernel.so"
336
337 target = tvm.target.Target("cuda")
338 tvm.compile(NvshmemQueryModule, target=target).export_library(path)
339 mod = sess.load_vm_module(path)

Callers

nothing calls this directly

Calls 9

get_global_func · Method · 0.80
sync_worker_0 · Method · 0.80
load_vm_module · Method · 0.80
debug_get_from_remote · Method · 0.80
numpy · Method · 0.80
_sync_all · Method · 0.80
export_library · Method · 0.45
compile · Method · 0.45
shutdown · Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…