(args)
| 48 | args.comm = None |
| 49 | |
def init_cross_silo(args):
    """Populate intra-silo topology attributes on ``args`` for a cross-silo run.

    ``manage_mpi_args`` is applied first. Rank 0 is then pinned to a
    single-node, single-process silo. Every other rank takes its layout from
    the ``WORLD_SIZE``, ``LOCAL_RANK``, ``RANK``, ``MASTER_ADDR`` and
    ``MASTER_PORT`` environment variables (the values torchrun exports),
    falling back to single-process defaults when a variable is absent.

    Args:
        args: Mutable configuration object exposing a ``rank`` attribute.

    Returns:
        The same ``args`` object, updated in place.
    """
    manage_mpi_args(args)

    env = os.environ
    if args.rank != 0:
        # Silo topology: number of processes in this silo's process group.
        args.n_proc_in_silo = int(env.get("WORLD_SIZE", 1))

        # Rank within the node doubles as the process id.
        local_rank = int(env.get("LOCAL_RANK", 0))
        args.rank_in_node = local_rank
        args.process_id = local_rank

        # Rank within the silo (process group).
        args.proc_rank_in_silo = int(env.get("RANK", 0))

        # Process-group master endpoint.
        # NOTE(review): the port is a str when MASTER_PORT is set but the int
        # 29300 on fallback -- consumers currently have to accept both types.
        args.pg_master_address = env.get("MASTER_ADDR", "127.0.0.1")
        args.pg_master_port = env.get("MASTER_PORT", 29300)
    else:
        # Rank 0 always runs as one process on one node.
        args.proc_rank_in_silo = 0
        args.rank_in_node = 0
        args.n_proc_in_silo = 1
        args.n_node_in_silo = 1

    # Fill in defaults for anything still missing (or falsy, for the
    # per-node process count).
    # NOTE(review): these defaults are applied to every rank here; confirm
    # they were not meant to be limited to the non-zero-rank branch.
    if not hasattr(args, "n_node_in_silo"):
        args.n_node_in_silo = 1
    if not getattr(args, "n_proc_per_node", None):
        args.n_proc_per_node = 1

    return args
| 82 | |
| 83 | |
| 84 | def init_simulation_sp(args): |
no test coverage detected