hub / github.com/nerfstudio-project/gsplat / cli

Function cli

gsplat/distributed.py:304–360

Wrapper to run a function in a distributed environment. The function `fn` should have the signature `fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None`.

(fn: Callable, args: Any, verbose: bool = False)
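As an illustration of the expected `fn` signature, here is a minimal sketch of a worker that splits a list of work items across ranks. The names `shard`, `worker`, and `main`, and the scene names passed as `args`, are hypothetical and not part of gsplat; only `cli` and its import path `gsplat.distributed` come from the listing on this page.

```python
from typing import Any, List, Sequence


def shard(items: Sequence[Any], world_rank: int, world_size: int) -> List[Any]:
    """Hypothetical helper: give each rank every world_size-th item."""
    return list(items[world_rank::world_size])


def worker(local_rank: int, world_rank: int, world_size: int, args: Any) -> None:
    # `args` is whatever object was handed to cli(); this sketch assumes a list.
    my_items = shard(args, world_rank, world_size)
    print(f"rank {world_rank}/{world_size} (local {local_rank}): {my_items}")


def main() -> None:
    # Imported lazily so the sketch can be read and tested without a GPU.
    from gsplat.distributed import cli

    # Assumed payload; cli() forwards it unchanged to every worker.
    cli(worker, ["scene_a", "scene_b", "scene_c", "scene_d"], verbose=True)
```

Because the multi-GPU path starts workers through `torch.multiprocessing.spawn`, the worker should be defined at module top level so it can be pickled.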

Source from the content-addressed store, hash-verified

def cli(fn: Callable, args: Any, verbose: bool = False) -> bool:
    """Wrapper to run a function in a distributed environment.

    The function `fn` should have the following signature:

    ```python
    def fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None:
        pass
    ```

    Usage:

    ```python
    # Launch with "CUDA_VISIBLE_DEVICES=0,1,2,3 python my_script.py"
    if __name__ == "__main__":
        cli(fn, None, verbose=True)
    ```
    """
    assert torch.cuda.is_available(), "CUDA device is required!"
    if "OMPI_COMM_WORLD_SIZE" in os.environ:  # multi-node
        local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])  # dist.get_world_size()
        world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])  # dist.get_rank()
        return _distributed_worker(
            world_rank, world_size, fn, args, local_rank, verbose
        )

    world_size = torch.cuda.device_count()
    distributed = world_size > 1

    if distributed:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(_find_free_port())
        process_context = torch.multiprocessing.spawn(
            _distributed_worker,
            args=(world_size, fn, args, None, verbose),
            nprocs=world_size,
            join=False,
        )
        try:
            process_context.join()
        except KeyboardInterrupt:
            # this is important.
            # if we do not explicitly terminate all launched subprocesses,
            # they would continue living even after this main process ends,
            # eventually making the OD machine unusable!
            for i, process in enumerate(process_context.processes):
                if process.is_alive():
                    if verbose:
                        print("terminating process " + str(i) + "...")
                    process.terminate()
                process.join()
                if verbose:
                    print("process " + str(i) + " finished")
        return True
    else:
        return _distributed_worker(0, 1, fn=fn, args=args)
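The listing above picks one of three launch paths. The sketch below restates that decision as a pure function so it can be read without CUDA or MPI. `LaunchPlan` and `plan_launch` are hypothetical names introduced for illustration. The environment variable names are the ones read in the listing, and `device_count` stands in for `torch.cuda.device_count()`.

```python
from typing import Mapping, NamedTuple, Optional


class LaunchPlan(NamedTuple):
    mode: str                  # "mpi", "spawn", or "single"
    world_size: int
    world_rank: Optional[int]  # None when spawn() assigns it per process
    local_rank: Optional[int]  # None when left to the worker's default


def plan_launch(env: Mapping[str, str], device_count: int) -> LaunchPlan:
    # Open MPI launch: ranks come from the environment, no processes are spawned.
    if "OMPI_COMM_WORLD_SIZE" in env:
        return LaunchPlan(
            mode="mpi",
            world_size=int(env["OMPI_COMM_WORLD_SIZE"]),
            world_rank=int(env["OMPI_COMM_WORLD_RANK"]),
            local_rank=int(env["OMPI_COMM_WORLD_LOCAL_RANK"]),
        )
    # Several visible GPUs on one machine: one spawned process per device.
    if device_count > 1:
        return LaunchPlan("spawn", device_count, None, None)
    # Single GPU: the worker runs in the current process as rank 0 of 1.
    return LaunchPlan("single", 1, 0, None)
```

For example, `plan_launch({}, 4)` selects the spawn path with a world size of 4, which matches launching with four devices listed in `CUDA_VISIBLE_DEVICES`.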

Callers 10

batch.py · File · 0.90
main.py · File · 0.90
test_all_gather_int32 · Function · 0.90
test_all_to_all_int32 · Function · 0.90
simple_viewer.py · File · 0.90
simple_trainer.py · File · 0.90

Calls 2

_distributed_worker · Function · 0.85
_find_free_port · Function · 0.85
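The body of `_find_free_port` is not shown on this page. A common way to obtain a free port for `MASTER_PORT`, given here only as an assumption about what such a helper might do, is to bind a socket to port 0 and let the operating system choose. The name `find_free_port` below is hypothetical.

```python
import socket


def find_free_port() -> int:
    """Ask the OS for an unused TCP port (a sketch, not gsplat's code)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Port 0 tells the OS to pick any currently free port.
        sock.bind(("", 0))
        # getsockname() returns (host, port) for AF_INET sockets.
        return sock.getsockname()[1]
```

The socket is closed before the port is used, so another process could claim the port in between. For a single-machine launch this race is usually tolerated.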

Tested by 4

test_all_gather_int32 · Function · 0.72
test_all_to_all_int32 · Function · 0.72