Initialize distributed environment for distributed training. Args: config (str): Config file path. launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default. master_port (int): The master port for distributed training. 8888 by default. seed (int, optional): Specified random seed for every process. 1024 by default.
(
config: str,
launcher: str = "slurm",
master_port: int = 8888,
seed: int = 1024,
args_check=True,
)
| 427 | |
| 428 | |
def initialize_distributed_env(
    config: str,
    launcher: str = "slurm",
    master_port: int = 8888,
    seed: int = 1024,
    args_check=True,
):
    """
    Initialize distributed environment for distributed training.

    Args:
        config (str): Config file path.
        launcher (str): Launcher for launching distributed environment, can be slurm or torch.
            "slurm" by default.
        master_port (int): The master port for distributed training. Only forwarded to the
            slurm launcher. 8888 by default.
        seed (int, optional): Specified random seed for every process. 1024 by default.
        args_check (bool, optional): Whether to call args_sanity_check() after the launcher
            has run. True by default.

    Raises:
        AssertionError: If launcher is neither "slurm" nor "torch".
    """

    torch.cuda.empty_cache()

    if launcher == "torch":
        launch_from_torch(config=config, seed=seed)
    elif launcher == "slurm":
        launch_from_slurm(
            config=config,
            host=get_master_node(),
            port=master_port,
            seed=seed,
        )
    else:
        # Raise explicitly instead of using an `assert` statement: asserts are
        # stripped under `python -O`, which would let an unsupported launcher
        # fall through without initializing anything. The exception type and
        # message are kept so existing callers see the same error.
        raise AssertionError("launcher only support slurm or torch")

    if args_check:
        args_sanity_check()
no test coverage detected