MCP · copy · Index your code
hub / github.com/apache/tvm / test_codegen_nvshmem

Function test_codegen_nvshmem

tests/python/tirx/codegen/test_codegen_nvshmem.py:68–301  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

66@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
67@pytest.mark.skip(reason="nvshmem doesn't work with pytest")
68def test_codegen_nvshmem():
69 def _test_func():
70 ############ setup ############
71 sess = di.ProcessSession(num_workers=NUM_WORKERS)
72 f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
73 uid = f_init_nvshmem_uid()
74 init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
75 init_dfunc(uid, NUM_WORKERS, 0)
76 sess.sync_worker_0()
77
78 def test_thread_info(sess):
79 @T.prim_func
80 def main(res: T.Buffer((2,), "int32")):
81 T.device_entry()
82 cta_id = T.cta_id([1])
83 tid = T.thread_id([nwarps * 32])
84 res[0] = T.nvshmem.my_pe()
85 res[1] = T.nvshmem.n_pes()
86
87 res_array = sess.empty((2,), "int32")
88 run_prim_func(sess, main, res_array)
89
90 def test_transfer(sess, scope, shape, nwarps, nelems, op_name):
91 """Tests data transfer operations (get/put) at thread, warp, and block scopes."""
92 dtype = "float32"
93 is_get = "get" in op_name
94 op_func = getattr(T.nvshmem, op_name)
95 if scope != "thread":
96 op_func = getattr(op_func, scope)
97
98 # fmt: off
99 @T.prim_func
100 def main(A: T.Buffer(shape, dtype), B: T.Buffer(shape, dtype)):
101 T.device_entry()
102 cta_id = T.cta_id([1])
103 warp_id = T.warp_id([nwarps])
104 lane_id = T.lane_id([32])
105 tid = T.thread_id([nwarps * 32])
106
107 my_pe = T.nvshmem.my_pe()
108 n_pes = T.nvshmem.n_pes()
109 offset = T.if_then_else(
110 scope == "block", 0, T.if_then_else(scope == "thread", tid, warp_id * 32)
111 )
112 op_func(dst=B.ptr_to([offset]), src=A.ptr_to([offset]), nelems=nelems, pe=(my_pe + 1) % n_pes) # noqa: E501
113 T.nvshmem.quiet()
114 # fmt: on
115
116 def init_fn(i, s, d):
117 return np.arange(s[0], dtype=d) + i * 100
118
119 A_array = create_nvshmem_array(sess, shape, dtype, init_fn)
120 B_array = create_nvshmem_array(sess, shape, dtype)
121 sess.sync_worker_0()
122 run_prim_func(sess, main, A_array, B_array)
123
124 for i in range(NUM_WORKERS):
125 if is_get:

Callers 1

Calls 3

sendMethod · 0.95
recvMethod · 0.95
PopenWorkerClass · 0.90

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…