reshape([tp_size_src, pp_size_src, dp_size_src], [tp_size_tgt, pp_size_tgt, dp_size_tgt])
(src, tgt)
| 180 | |
| 181 | |
def reshape(src, tgt):
    """
    Print how rank groups map when moving from one parallel layout to another.

    reshape([tp_size_src, pp_size_src, dp_size_src],
            [tp_size_tgt, pp_size_tgt, dp_size_tgt])

    The change is reported in two stages: first only the TP size goes from
    its source to its target value, then only the PP size does. Nothing is
    returned; the mapping is written to stdout.

    Args:
        src: three-item sequence (tp_size, pp_size, dp_size), source layout.
        tgt: three-item sequence (tp_size, pp_size, dp_size), target layout.
            dp_size_tgt is unpacked but not used: every stage is computed
            with dp_size_src, so a DP change is not reported.

    Raises:
        ValueError: if src or tgt does not unpack into exactly three items.

    NOTE(review): assumes get_mpu_ranks returns (tp_groups, pp_groups,
    dp_groups), each an ordered sequence of rank groups - confirm against
    its definition.
    """

    print(f"\n\n*** Reshaping: {src} => {tgt}")

    tp_size_src, pp_size_src, dp_size_src = src
    # The target DP size is not consumed below; unpacking it still enforces
    # that tgt has exactly three items.
    tp_size_tgt, pp_size_tgt, _dp_size_tgt = tgt

    # Three layouts: the source, the source with the target TP size, and the
    # source with both the target TP and PP sizes.
    tp_ranks1, _, _ = get_mpu_ranks(tp_size=tp_size_src, pp_size=pp_size_src, dp_size=dp_size_src)
    tp_ranks2, pp_ranks2, _ = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_src, dp_size=dp_size_src)
    _, pp_ranks3, _ = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_src)

    # handle tp contraction first
    print("\n*** TP contraction:")

    for before, after in zip(tp_ranks1, tp_ranks2):
        print(f'{before} => {after}')

    # handle pp contraction next
    print("\n*** PP contraction:")

    # Pair the two lists that are actually printed. The loop used to be
    # bounded by pp_ranks1 (the layout before the TP change), which can be a
    # different length from pp_ranks2/pp_ranks3 and made the index run out of
    # range. zip stops at the shorter of the two lists.
    for before, after in zip(pp_ranks2, pp_ranks3):
        print(f'{before} => {after}')
| 210 | |
| 211 | # easy |
nothing calls this directly
no test coverage detected
searching dependent graphs…