| 17 | |
| 18 | |
| 19 | class FedMLRunner: |
def __init__(
    self,
    args,
    device,
    dataset,
    model,
    client_trainer: ClientTrainer = None,
    server_aggregator: ServerAggregator = None,
    algorithm_flow: FedMLAlgorithmFlow = None,
):
    """Select and build the runner for ``args.training_type``; store it on ``self.runner``.

    When ``algorithm_flow`` is given, it is used as the runner directly and
    no platform dispatch happens. Otherwise ``args.training_type`` picks one
    of the ``_init_*_runner`` factory methods. That method is called with
    ``(args, device, dataset, model, client_trainer, server_aggregator)``,
    and its return value becomes ``self.runner``.

    Args:
        args: Configuration object; must expose ``training_type`` unless
            ``algorithm_flow`` is given.
        device: Passed through unchanged to the selected factory.
        dataset: Passed through unchanged to the selected factory.
        model: Passed through unchanged to the selected factory.
        client_trainer: Optional trainer, passed through to the factory.
        server_aggregator: Optional aggregator, passed through to the factory.
        algorithm_flow: Optional ready-made flow. When not ``None``, it
            short-circuits dispatch.

    Raises:
        ValueError: ``args.training_type`` is not a supported platform.
            ``ValueError`` subclasses ``Exception``, so callers that caught
            the previous bare ``Exception`` still catch it.
    """
    if algorithm_flow is not None:
        # A caller-supplied flow wins; the platform-specific factories are skipped.
        self.runner = algorithm_flow
        return

    training_type = args.training_type

    # Resolve only the matching factory method, lazily and in first-match
    # order, so no other ``_init_*_runner`` attribute is touched.
    if training_type == FEDML_TRAINING_PLATFORM_SIMULATION:
        init_runner_func = self._init_simulation_runner
    elif training_type == FEDML_TRAINING_PLATFORM_CROSS_SILO:
        init_runner_func = self._init_cross_silo_runner
    elif training_type == FEDML_TRAINING_PLATFORM_CROSS_CLOUD:
        init_runner_func = self._init_cheetah_runner
    elif training_type == FEDML_TRAINING_PLATFORM_SERVING:
        init_runner_func = self._init_model_serving_runner
    elif training_type == FEDML_TRAINING_PLATFORM_CROSS_DEVICE:
        init_runner_func = self._init_cross_device_runner
    else:
        # Report the rejected value so a misconfiguration is diagnosable from
        # the traceback. The original "no such setting" prefix is kept.
        raise ValueError(
            "no such setting: training_type={!r}".format(training_type)
        )

    self.runner = init_runner_func(
        args, device, dataset, model, client_trainer, server_aggregator
    )
| 54 | |
| 55 | def _init_simulation_runner( |
| 56 | self, args, device, dataset, model, client_trainer=None, server_aggregator=None |
| 57 | ): |
| 58 | if hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_SP: |
| 59 | from .simulation.simulator import SimulatorSingleProcess |
| 60 | |
| 61 | runner = SimulatorSingleProcess( |
| 62 | args, device, dataset, model, client_trainer, server_aggregator |
| 63 | ) |
| 64 | elif hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_MPI: |
| 65 | from .simulation.simulator import SimulatorMPI |
| 66 | |
| 67 | runner = SimulatorMPI( |
| 68 | args, device, dataset, model, client_trainer, server_aggregator |
| 69 | ) |
| 70 | elif hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_NCCL: |
| 71 | from .simulation.simulator import SimulatorNCCL |
| 72 | |
| 73 | runner = SimulatorNCCL( |
| 74 | args, device, dataset, model, client_trainer, server_aggregator |
| 75 | ) |
| 76 | else: |
no outgoing calls
no test coverage detected