()
| 538 | |
| 539 | |
def mpi_rank():
    """Return the rank of the current process.

    When ``mpi_disabled()`` is false, the rank is read from the MPI
    communicator returned by ``mpi_comm()`` if ``ENABLE_MULTI_DEVICE`` is
    truthy, and is ``0`` otherwise.

    When ``mpi_disabled()`` is true, the rank is taken from
    ``torch.distributed.get_rank()``; if that call raises ``ValueError``
    the function falls back to ``0``.

    Returns:
        The rank reported by the selected backend, or ``0`` on the
        fallback paths described above.
    """
    if not mpi_disabled():
        # MPI path: only query the communicator in multi-device builds.
        if ENABLE_MULTI_DEVICE:
            return mpi_comm().Get_rank()
        return 0

    try:
        rank = torch.distributed.get_rank()
    except ValueError:
        # Fallback: return 0 when MPI is absent (Ray / Slurm PMIx)
        # NOTE(review): presumably raised when no default process group has
        # been initialised -- confirm against the installed PyTorch version.
        return 0
    return rank
| 548 | |
| 549 | |
| 550 | def global_mpi_rank(): |