MCPcopy
hub / github.com/FedML-AI/FedML / FedMLRunner

Class FedMLRunner

python/fedml/runner.py:19–183  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

17
18
19class FedMLRunner:
def __init__(
    self,
    args,
    device,
    dataset,
    model,
    client_trainer: ClientTrainer = None,
    server_aggregator: ServerAggregator = None,
    algorithm_flow: FedMLAlgorithmFlow = None,
):
    """Build the platform-specific runner and store it on ``self.runner``.

    If ``algorithm_flow`` is given, it is used as the runner directly and the
    remaining arguments are ignored. Otherwise ``args.training_type`` selects
    one of the ``_init_*_runner`` factory methods. The selected factory is
    called with ``(args, device, dataset, model, client_trainer,
    server_aggregator)``, and its return value becomes ``self.runner``.

    Args:
        args: Configuration object; must expose ``training_type`` unless
            ``algorithm_flow`` is provided.
        device: Forwarded unchanged to the selected factory.
        dataset: Forwarded unchanged to the selected factory.
        model: Forwarded unchanged to the selected factory.
        client_trainer: Optional client-side trainer, forwarded to the factory.
        server_aggregator: Optional server-side aggregator, forwarded to the
            factory.
        algorithm_flow: Optional pre-built flow; when not ``None`` it
            short-circuits factory selection entirely.

    Raises:
        ValueError: If ``args.training_type`` matches none of the supported
            platforms. ``ValueError`` subclasses ``Exception``, so callers
            that previously caught ``Exception`` still catch it.
    """
    # A caller-supplied flow takes precedence over platform dispatch.
    if algorithm_flow is not None:
        self.runner = algorithm_flow
        return

    training_type = args.training_type

    # Only the matched factory method is looked up, so the factories for
    # other platforms are never touched.
    if training_type == FEDML_TRAINING_PLATFORM_SIMULATION:
        init_runner_func = self._init_simulation_runner
    elif training_type == FEDML_TRAINING_PLATFORM_CROSS_SILO:
        init_runner_func = self._init_cross_silo_runner
    elif training_type == FEDML_TRAINING_PLATFORM_CROSS_CLOUD:
        init_runner_func = self._init_cheetah_runner
    elif training_type == FEDML_TRAINING_PLATFORM_SERVING:
        init_runner_func = self._init_model_serving_runner
    elif training_type == FEDML_TRAINING_PLATFORM_CROSS_DEVICE:
        init_runner_func = self._init_cross_device_runner
    else:
        # Keep the original "no such setting" prefix for anyone matching on
        # the message, but report the value that was actually rejected.
        raise ValueError(
            f"no such setting: unsupported training_type {training_type!r}"
        )

    self.runner = init_runner_func(
        args, device, dataset, model, client_trainer, server_aggregator
    )
54
55 def _init_simulation_runner(
56 self, args, device, dataset, model, client_trainer=None, server_aggregator=None
57 ):
58 if hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_SP:
59 from .simulation.simulator import SimulatorSingleProcess
60
61 runner = SimulatorSingleProcess(
62 args, device, dataset, model, client_trainer, server_aggregator
63 )
64 elif hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_MPI:
65 from .simulation.simulator import SimulatorMPI
66
67 runner = SimulatorMPI(
68 args, device, dataset, model, client_trainer, server_aggregator
69 )
70 elif hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_NCCL:
71 from .simulation.simulator import SimulatorNCCL
72
73 runner = SimulatorNCCL(
74 args, device, dataset, model, client_trainer, server_aggregator
75 )
76 else:

Calls

no outgoing calls

Tested by

no test coverage detected