Compile and run a kernel that calls NVSHMEM functions. Runs in a fresh process, so setting the env var is safe.
(compile_mode)
| 280 | |
| 281 | |
| 282 | def _kernel_compile(compile_mode): |
| 283 | """Compile and run a kernel that calls NVSHMEM functions. |
| 284 | |
| 285 | Runs in a fresh process, so setting the env var is safe. |
| 286 | """ |
| 287 | os.environ["TVM_CUDA_COMPILE_MODE"] = compile_mode |
| 288 | |
| 289 | num_workers = 2 |
| 290 | sess = di.ProcessSession(num_workers=num_workers) |
| 291 | |
| 292 | f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid") |
| 293 | uid = f_init_nvshmem_uid() |
| 294 | init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem") |
| 295 | init_dfunc(uid, num_workers, 0) |
| 296 | sess.sync_worker_0() |
| 297 | |
| 298 | try: |
| 299 | |
| 300 | @I.ir_module(s_tir=True) |
| 301 | class NvshmemQueryModule: |
| 302 | @T.prim_func(s_tir=True) |
| 303 | def query_pe( |
| 304 | my_pe_out: T.Buffer((1,), "int32"), |
| 305 | n_pes_out: T.Buffer((1,), "int32"), |
| 306 | ): |
| 307 | with T.sblock("root"): |
| 308 | T.reads() |
| 309 | T.writes(my_pe_out[0:1], n_pes_out[0:1]) |
| 310 | T.call_kernel( |
| 311 | NVSHMEM_QUERY_KERNEL_SOURCE, |
| 312 | ((1,), (1,)), # grid=(1,), block=(1,) |
| 313 | my_pe_out.data, |
| 314 | n_pes_out.data, |
| 315 | kernel_name="nvshmem_query_kernel", |
| 316 | ) |
| 317 | |
| 318 | @R.function |
| 319 | def main() -> R.Tuple(R.Tensor((1,), "int32"), R.Tensor((1,), "int32")): |
| 320 | cls = NvshmemQueryModule |
| 321 | with R.dataflow(): |
| 322 | my_pe = R.call_tir( |
| 323 | cls.query_pe, |
| 324 | (), |
| 325 | out_sinfo=[ |
| 326 | R.Tensor((1,), "int32"), |
| 327 | R.Tensor((1,), "int32"), |
| 328 | ], |
| 329 | ) |
| 330 | R.output(my_pe) |
| 331 | return my_pe |
| 332 | |
| 333 | tmpdir = tempfile.mkdtemp() |
| 334 | try: |
| 335 | path = tmpdir + "/test_nvshmem_kernel.so" |
| 336 | |
| 337 | target = tvm.target.Target("cuda") |
| 338 | tvm.compile(NvshmemQueryModule, target=target).export_library(path) |
| 339 | mod = sess.load_vm_module(path) |
nothing calls this directly
no test coverage detected
searching dependent graphs…