| 31 | return args |
| 32 | |
def manage_mpi_args(args) -> None:
    """Attach MPI communicator details to ``args`` when the MPI backend is used.

    If ``args`` has a ``backend`` attribute equal to ``"MPI"``, the following
    attributes are set on ``args`` in place:

    * ``comm`` -- ``MPI.COMM_WORLD``
    * ``rank`` -- this process's rank in ``comm``
    * ``role`` -- ``"server"`` for rank 0, ``"client"`` for every other rank

    Otherwise only ``args.comm`` is set (to ``None``); ``rank`` and ``role``
    are left untouched.

    Args:
        args: Mutable attribute container (e.g. an argparse-style namespace).
            In MPI mode it must already carry an integer ``worker_num``.

    Raises:
        AssertionError: In MPI mode, if the communicator size is not
            ``args.worker_num + 1`` (presumably one server plus
            ``worker_num`` clients).
    """
    if not (hasattr(args, "backend") and args.backend == "MPI"):
        args.comm = None
        return

    # Imported lazily so mpi4py is only needed when the MPI backend is chosen.
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    process_id = comm.Get_rank()
    world_size = comm.Get_size()

    args.comm = comm
    args.rank = process_id
    args.role = "server" if process_id == 0 else "client"

    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # when Python runs with -O; the exception type and message are kept so
    # existing callers observe the same failure.
    if args.worker_num + 1 != world_size:
        raise AssertionError(
            f"Invalid number of mpi processes. Expected {args.worker_num + 1}"
        )
| 49 | |
| 50 | def init_cross_silo(args): |
| 51 | manage_mpi_args(args) |