()
| 66 | @pytest.mark.skipif(not env.has_cuda(), reason="need cuda") |
| 67 | @pytest.mark.skip(reason="nvshmem doesn't work with pytest") |
| 68 | def test_codegen_nvshmem(): |
| 69 | def _test_func(): |
| 70 | ############ setup ############ |
| 71 | sess = di.ProcessSession(num_workers=NUM_WORKERS) |
| 72 | f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid") |
| 73 | uid = f_init_nvshmem_uid() |
| 74 | init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem") |
| 75 | init_dfunc(uid, NUM_WORKERS, 0) |
| 76 | sess.sync_worker_0() |
| 77 | |
def test_thread_info(sess, nwarps=1):
    """Check that each worker reports its own NVSHMEM PE index and the PE count.

    A single-CTA kernel writes ``T.nvshmem.my_pe()`` to ``res[0]`` and
    ``T.nvshmem.n_pes()`` to ``res[1]`` on every worker. The buffer is then
    read back from each worker and checked.

    Parameters
    ----------
    sess :
        Disco session whose workers have already initialized NVSHMEM.
    nwarps : int, optional
        Number of warps in the launched CTA. It only sizes the launch.
        Every thread writes the same two values, so the result does not
        depend on it. It is an explicit parameter (like in ``test_transfer``)
        rather than a free variable resolved from an outer scope.
    """

    @T.prim_func
    def main(res: T.Buffer((2,), "int32")):
        T.device_entry()
        # Launch configuration: one CTA of nwarps * 32 threads.
        cta_id = T.cta_id([1])
        tid = T.thread_id([nwarps * 32])
        res[0] = T.nvshmem.my_pe()
        res[1] = T.nvshmem.n_pes()

    res_array = sess.empty((2,), "int32")
    run_prim_func(sess, main, res_array)

    # Read the result back from every worker and verify it, so the test
    # fails on wrong values and not only on a crash.
    # NOTE(review): assumes NVSHMEM PE index == disco worker index, as the
    # per-worker checks in test_transfer appear to; confirm.
    for worker in range(NUM_WORKERS):
        got = res_array.debug_get_from_remote(worker).numpy()
        assert got[0] == worker, f"worker {worker}: my_pe() returned {got[0]}"
        assert got[1] == NUM_WORKERS, f"worker {worker}: n_pes() returned {got[1]}"
| 89 | |
| 90 | def test_transfer(sess, scope, shape, nwarps, nelems, op_name): |
| 91 | """Tests data transfer operations (get/put) at thread, warp, and block scopes.""" |
| 92 | dtype = "float32" |
| 93 | is_get = "get" in op_name |
| 94 | op_func = getattr(T.nvshmem, op_name) |
| 95 | if scope != "thread": |
| 96 | op_func = getattr(op_func, scope) |
| 97 | |
| 98 | # fmt: off |
| 99 | @T.prim_func |
| 100 | def main(A: T.Buffer(shape, dtype), B: T.Buffer(shape, dtype)): |
| 101 | T.device_entry() |
| 102 | cta_id = T.cta_id([1]) |
| 103 | warp_id = T.warp_id([nwarps]) |
| 104 | lane_id = T.lane_id([32]) |
| 105 | tid = T.thread_id([nwarps * 32]) |
| 106 | |
| 107 | my_pe = T.nvshmem.my_pe() |
| 108 | n_pes = T.nvshmem.n_pes() |
| 109 | offset = T.if_then_else( |
| 110 | scope == "block", 0, T.if_then_else(scope == "thread", tid, warp_id * 32) |
| 111 | ) |
| 112 | op_func(dst=B.ptr_to([offset]), src=A.ptr_to([offset]), nelems=nelems, pe=(my_pe + 1) % n_pes) # noqa: E501 |
| 113 | T.nvshmem.quiet() |
| 114 | # fmt: on |
| 115 | |
| 116 | def init_fn(i, s, d): |
| 117 | return np.arange(s[0], dtype=d) + i * 100 |
| 118 | |
| 119 | A_array = create_nvshmem_array(sess, shape, dtype, init_fn) |
| 120 | B_array = create_nvshmem_array(sess, shape, dtype) |
| 121 | sess.sync_worker_0() |
| 122 | run_prim_func(sess, main, A_array, B_array) |
| 123 | |
| 124 | for i in range(NUM_WORKERS): |
| 125 | if is_get: |
no test coverage detected
searching dependent graphs…