MCPcopy
hub / github.com/deepspeedai/DeepSpeed / DistributedExec

Class DistributedExec

tests/unit/common.py:139–362  ·  view source on GitHub ↗

Base class for distributed execution of functions/methods. Contains common methods needed for DistributedTest and DistributedFixture.

Source from the content-addressed store, hash-verified

137
138
139class DistributedExec(ABC):
140 """
141 Base class for distributed execution of functions/methods. Contains common
142 methods needed for DistributedTest and DistributedFixture.
143 """
144 world_size = 2
145 backend = get_accelerator().communication_backend_name()
146 init_distributed = True
147 set_dist_env = True
148 requires_cuda_env = True
149 reuse_dist_env = False
150 non_daemonic_procs = False
151 _pool_cache = {}
152 exec_timeout = DEEPSPEED_TEST_TIMEOUT
153
@abstractmethod
def run(self):
    """Work to execute; every concrete subclass must override this.

    Any parameters besides ``self`` are looked up as pytest fixtures (via
    ``_get_fixture_kwargs``) when the instance is invoked with a request.
    """
157
def __call__(self, request):
    """Pytest entry point: gather ``run``'s fixture arguments, then launch.

    The kwargs for ``self.run`` are resolved from ``request`` and stored on
    ``self._fixture_kwargs``. When ``requires_cuda_env`` is truthy and no
    accelerator is available, the test is skipped via ``pytest.skip``;
    otherwise the work is handed to ``self._launch_with_file_store`` together
    with ``self.world_size``.
    """
    self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
    procs = self.world_size
    needs_accelerator = self.requires_cuda_env
    if needs_accelerator and not get_accelerator().is_available():
        # pytest.skip raises, so the launch below is never reached.
        pytest.skip("only supported in accelerator environments.")
    self._launch_with_file_store(request, procs)
165
def _get_fixture_kwargs(self, request, func):
    """Resolve the pytest fixture / parametrize values that ``func`` asks for.

    Args:
        request: the pytest ``request`` object. When falsy (e.g. ``None``)
            no lookup is attempted and an empty dict is returned.
        func: callable whose positional parameter names are looked up with
            ``request.getfixturevalue``. A ``self`` parameter, which
            ``inspect.getfullargspec`` reports for bound methods, is ignored
            if present.

    Returns:
        dict mapping each parameter name that resolved successfully to its
        value. Names whose lookup raises ``FixtureLookupError`` are omitted.
    """
    if not request:
        return {}
    # getfullargspec keeps the already-bound ``self`` for bound methods. Drop
    # it only when present, rather than assuming it exists: a plain function
    # or staticmethod would otherwise make list.remove raise ValueError.
    params = [p for p in inspect.getfullargspec(func).args if p != "self"]
    fixture_kwargs = {}
    for p in params:
        try:
            fixture_kwargs[p] = request.getfixturevalue(p)
        except FixtureLookupError:
            # Test methods can have kwargs that are not fixtures; skip them.
            continue
    return fixture_kwargs
179
180 def _launch_daemonic_procs(self, num_procs, init_method):
181 # Create process pool or use cached one
182 master_port = None
183
184 if get_accelerator().device_name() == 'hpu':
185 if self.reuse_dist_env:
186 print("Ignoring reuse_dist_env for hpu")
187 self.reuse_dist_env = False
188
189 if self.reuse_dist_env:
190 if num_procs not in self._pool_cache:
191 self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
192 master_port = get_master_port()
193 pool = self._pool_cache[num_procs]
194 else:
195 pool = mp.Pool(processes=num_procs)
196 master_port = get_master_port()

Callers

nothing calls this directly

Calls 2

get_accelerator (Function) · 0.90

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…