Wrapper to run a function in a distributed environment. The function `fn` should have the signature `def fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None`. Usage: launch with `CUDA_VISIBLE_DEVICES=0,1,2,3 python my_script.py` and call `cli(fn, None, verbose=True)` inside an `if __name__ == "__main__":` guard.
(fn: Callable, args: Any, verbose: bool = False)
| 302 | |
| 303 | |
def cli(fn: Callable, args: Any, verbose: bool = False) -> bool:
    """Wrapper to run a function in a distributed environment.

    The function `fn` should have the following signature:

    ```python
    def fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None:
        pass
    ```

    Usage:

    ```python
    # Launch with "CUDA_VISIBLE_DEVICES=0,1,2,3 python my_script.py"
    if __name__ == "__main__":
        cli(fn, None, verbose=True)
    ```

    Three launch modes are handled:

    * Open MPI launch (detected via ``OMPI_COMM_WORLD_SIZE``, e.g. multi-node
      ``mpirun``): this process runs a single rank whose local rank, world
      rank and world size are read from the ``OMPI_COMM_WORLD_*`` variables.
    * More than one visible CUDA device: one worker process per device is
      spawned on this machine, with ``MASTER_ADDR=localhost`` and
      ``MASTER_PORT`` set from ``_find_free_port()``.
    * A single visible CUDA device: the worker runs in this process as
      rank 0 of a world of size 1.

    Args:
        fn: Per-rank entry point; it is handed to ``_distributed_worker``,
            which is expected to call it with the signature above.
        args: Opaque value forwarded to ``fn``.
        verbose: Forwarded to ``_distributed_worker`` in every mode, and
            enables progress messages when spawned workers are terminated.

    Returns:
        In the spawned (multi-GPU) mode, ``True`` once every worker has been
        joined, including after workers were terminated because of a
        ``KeyboardInterrupt``, which is swallowed here. In the other two
        modes, whatever ``_distributed_worker`` returns.

    Raises:
        AssertionError: If CUDA is not available. This is an ``assert``, so
            the check is skipped under ``python -O``.
        Exceptions from ``ProcessContext.join`` propagate when a spawned
            worker exits with a non-zero status.
    """
    assert torch.cuda.is_available(), "CUDA device is required!"

    if "OMPI_COMM_WORLD_SIZE" in os.environ:  # multi-node
        # The MPI launcher already created one process per rank; run this
        # rank in-process using the rank layout it exported.
        local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])  # dist.get_world_size()
        world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])  # dist.get_rank()
        return _distributed_worker(
            world_rank, world_size, fn, args, local_rank, verbose
        )

    world_size = torch.cuda.device_count()
    if world_size <= 1:
        # Single device: run in-process as rank 0 of 1. Use the same
        # positional layout as the spawn path below (local_rank=None), and
        # forward `verbose` so the flag is not silently dropped in this mode.
        return _distributed_worker(0, 1, fn, args, None, verbose)

    # Several devices on one machine: spawn one worker per device.
    # torch.multiprocessing.spawn prepends the process index (0..nprocs-1),
    # which becomes the worker's world rank.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(_find_free_port())
    process_context = torch.multiprocessing.spawn(
        _distributed_worker,
        args=(world_size, fn, args, None, verbose),
        nprocs=world_size,
        join=False,
    )
    try:
        # ProcessContext.join() returns as soon as *any* worker exits and
        # reports False while others are still running. Keep joining until
        # every worker is done; spawn(join=True) does the same internally.
        # If a worker fails, join() terminates the rest and raises.
        while not process_context.join():
            pass
    except KeyboardInterrupt:
        # This is important: if we do not explicitly terminate all launched
        # subprocesses, they keep running after this main process ends and
        # eventually make the machine unusable.
        for i, process in enumerate(process_context.processes):
            if process.is_alive():
                if verbose:
                    print(f"terminating process {i}...")
                process.terminate()
            process.join()
            if verbose:
                print(f"process {i} finished")
    return True