
Function reshape

deepspeed/checkpoint/reshape_meg_2d.py:182–208

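reshape(src, tgt)

src and tgt are [tp_size, pp_size, dp_size] lists (tensor-, pipeline- and data-parallel sizes) for the source and target layouts, as the docstring states: reshape([tp_size_src, pp_size_src, dp_size_src], [tp_size_tgt, pp_size_tgt, dp_size_tgt]). The function only prints rank-group mappings, first for the TP contraction and then for the PP contraction. It returns None and moves no checkpoint data. dp_size_tgt is unpacked but never used.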
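The body changes one dimension at a time instead of jumping straight from src to tgt. The sketch below reproduces only the layout arithmetic behind the three get_mpu_ranks calls: TP is switched to its target size first, then PP, while DP stays at dp_size_src throughout. reshape_stages is a hypothetical helper for illustration, not part of DeepSpeed.

```python
from typing import List, Sequence, Tuple

Layout = Tuple[int, int, int]  # (tp_size, pp_size, dp_size)


def reshape_stages(src: Sequence[int], tgt: Sequence[int]) -> List[Layout]:
    """Hypothetical helper: the three layouts reshape() asks get_mpu_ranks for."""
    tp_src, pp_src, dp_src = src
    tp_tgt, pp_tgt, _dp_tgt = tgt  # reshape() unpacks dp_size_tgt but never uses it
    return [
        (tp_src, pp_src, dp_src),  # stage 1: source layout
        (tp_tgt, pp_src, dp_src),  # stage 2: after TP contraction
        (tp_tgt, pp_tgt, dp_src),  # stage 3: after PP contraction, DP still at source size
    ]


for tp, pp, dp in reshape_stages([4, 2, 1], [2, 1, 1]):
    print(f"tp={tp} pp={pp} dp={dp} world={tp * pp * dp}")
```

Because DP is never changed, a target that also differs in dp_size is reported as if it kept the source DP size.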


def reshape(src, tgt):
    """
    reshape([tp_size_src, pp_size_src, dp_size_src],
            [tp_size_tgt, pp_size_tgt, dp_size_tgt])
    """

    print(f"\n\n*** Reshaping: {src} => {tgt}")

    tp_size_src, pp_size_src, dp_size_src = src
    tp_size_tgt, pp_size_tgt, dp_size_tgt = tgt

    tp_ranks1, pp_ranks1, dp_ranks1 = get_mpu_ranks(tp_size=tp_size_src, pp_size=pp_size_src, dp_size=dp_size_src)
    tp_ranks2, pp_ranks2, dp_ranks2 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_src, dp_size=dp_size_src)
    tp_ranks3, pp_ranks3, dp_ranks3 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_src)

    # handle tp contraction first
    print("\n*** TP contraction:")

    for i, r in enumerate(tp_ranks1):
        print(f'{tp_ranks1[i]} => {tp_ranks2[i]}')

    # handle pp contraction next

    print("\n*** PP contraction:")

    for i, r in enumerate(pp_ranks1):
        print(f'{pp_ranks2[i]} => {pp_ranks3[i]}')
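To show what the TP contraction loop prints, the sketch below pairs source TP groups with the intermediate (stage 2) TP groups. It assumes, rather than reads from get_mpu_ranks, a Megatron-style layout in which each TP group is a contiguous block of tp_size ranks. contiguous_tp_groups and tp_contraction are hypothetical names used only for illustration.

```python
from typing import List, Sequence, Tuple


def contiguous_tp_groups(tp: int, pp: int, dp: int) -> List[List[int]]:
    """Hypothetical helper: TP groups as contiguous blocks of tp ranks (assumed layout)."""
    world = tp * pp * dp
    return [list(range(start, start + tp)) for start in range(0, world, tp)]


def tp_contraction(src: Sequence[int], tgt: Sequence[int]) -> List[Tuple[List[int], List[int]]]:
    """Pair each source TP group with the TP group at the same index after contraction."""
    tp_src, pp_src, dp_src = src
    tp_tgt = tgt[0]
    before = contiguous_tp_groups(tp_src, pp_src, dp_src)
    # Stage 2 keeps pp_src and dp_src, so both lists hold pp_src * dp_src groups.
    after = contiguous_tp_groups(tp_tgt, pp_src, dp_src)
    return list(zip(before, after))


for old, new in tp_contraction([4, 2, 1], [2, 1, 1]):
    print(f"{old} => {new}")
```

With src = [4, 2, 1] and tgt = [2, 1, 1] this prints [0, 1, 2, 3] => [0, 1] and [4, 5, 6, 7] => [2, 3]. Using zip instead of enumerate plus indexing expresses the same one-to-one pairing; under the assumed layout both lists have the same length, so no group is dropped.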

Callers

nothing calls this directly

Calls (1)

get_mpu_ranks · Function · 0.85

Tested by

no test coverage detected
