| 63 | |
| 64 | |
| 65 | def init_distributed(rank=-1, local_rank=-1, size=-1, use_gpu=False, backend=""): |
| 66 | global myreq |
| 67 | global my_rank |
| 68 | global my_size |
| 69 | global my_local_rank |
| 70 | global my_local_size |
| 71 | global a2a_impl |
| 72 | global alltoall_supported |
| 73 | |
| 74 | # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2) |
| 75 | num_mpi_ranks = env2int( |
| 76 | ["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"] |
| 77 | ) |
| 78 | if backend == "" and num_mpi_ranks > 1: |
| 79 | if torch_ccl and env2int(["CCL_WORKER_COUNT"]) > 0: |
| 80 | backend = "ccl" |
| 81 | elif use_gpu and dist.is_nccl_available(): |
| 82 | backend = "nccl" |
| 83 | elif dist.is_mpi_available(): |
| 84 | backend = "mpi" |
| 85 | else: |
| 86 | print( |
| 87 | "WARNING: MPI multi-process launch detected but PyTorch MPI backend not available." |
| 88 | ) |
| 89 | backend = "gloo" |
| 90 | |
| 91 | if backend != "": |
| 92 | # guess Rank and size |
| 93 | if rank == -1: |
| 94 | rank = env2int( |
| 95 | ["PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK", "RANK"], 0 |
| 96 | ) |
| 97 | if size == -1: |
| 98 | size = env2int( |
| 99 | [ |
| 100 | "PMI_SIZE", |
| 101 | "OMPI_COMM_WORLD_SIZE", |
| 102 | "MV2_COMM_WORLD_SIZE", |
| 103 | "WORLD_SIZE", |
| 104 | ], |
| 105 | 1, |
| 106 | ) |
| 107 | if not os.environ.get("RANK", None) and rank != -1: |
| 108 | os.environ["RANK"] = str(rank) |
| 109 | if not os.environ.get("WORLD_SIZE", None) and size != -1: |
| 110 | os.environ["WORLD_SIZE"] = str(size) |
| 111 | if not os.environ.get("MASTER_PORT", None): |
| 112 | os.environ["MASTER_PORT"] = "29500" |
| 113 | if not os.environ.get("MASTER_ADDR", None): |
| 114 | local_size = env2int( |
| 115 | [ |
| 116 | "MPI_LOCALNRANKS", |
| 117 | "OMPI_COMM_WORLD_LOCAL_SIZE", |
| 118 | "MV2_COMM_WORLD_LOCAL_SIZE", |
| 119 | ], |
| 120 | 1, |
| 121 | ) |
| 122 | if local_size != size and backend != "mpi": |