Set up a distributed process group.
()
| 19 | |
| 20 | |
def setup_dist():
    """
    Set up the torch.distributed default process group from the MPI world.

    Does nothing if a process group is already initialized. Otherwise:
    - picks the "nccl" backend when CUDA is available and "gloo" otherwise;
    - has MPI rank 0 choose the master address and a free port;
    - broadcasts both to every rank;
    - exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE;
    - calls ``dist.init_process_group`` with ``init_method="env://"``.
    """
    if dist.is_initialized():
        return

    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"
    is_root = comm.rank == 0

    # Only rank 0's values survive the bcast(root=0) calls below, so only
    # rank 0 computes them. Other ranks skip the DNS lookup, which could
    # raise on a node with broken name resolution even though the result
    # would be discarded. They also skip binding a throwaway socket.
    if not is_root:
        hostname = None
    elif backend == "gloo":
        # NOTE(review): gloo is used when CUDA is unavailable. A "localhost"
        # master address only works if all ranks run on the same machine.
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    port = comm.bcast(_find_free_port() if is_root else None, root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")
| 42 | |
| 43 | |
| 44 | def dev(): |
nothing calls this directly
no test coverage detected