Base class for distributed execution of functions/methods. Contains common methods needed for DistributedTest and DistributedFixture.
| 137 | |
| 138 | |
class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """
    # Read by __call__ and handed to _launch_with_file_store; presumably the
    # number of processes to launch -- TODO confirm against that method.
    world_size = 2
    # Evaluated once, when this class body is executed.
    backend = get_accelerator().communication_backend_name()
    # NOTE(review): not read in the visible part of the class -- confirm usage.
    init_distributed = True
    # NOTE(review): not read in the visible part of the class -- confirm usage.
    set_dist_env = True
    # When true, __call__ calls pytest.skip if get_accelerator().is_available()
    # is false.
    requires_cuda_env = True
    # When true, _launch_daemonic_procs takes its mp.Pool from _pool_cache
    # instead of creating a new one; that method resets the flag to False when
    # the accelerator device name is 'hpu'.
    reuse_dist_env = False
    # NOTE(review): not read in the visible part of the class -- confirm usage.
    non_daemonic_procs = False
    # Class-level dict, so it is shared by every instance that does not override
    # it. _launch_daemonic_procs stores one mp.Pool per process count here.
    _pool_cache = {}
    # NOTE(review): not read in the visible part of the class -- confirm usage.
    exec_timeout = DEEPSPEED_TEST_TIMEOUT
| 153 | |
@abstractmethod
def run(self):
    """Hook that every concrete subclass has to implement.

    ``__call__`` passes this method to ``_get_fixture_kwargs``, which reads
    its parameter names to look up matching pytest fixture values.
    """
| 157 | |
| 158 | def __call__(self, request): |
| 159 | self._fixture_kwargs = self._get_fixture_kwargs(request, self.run) |
| 160 | world_size = self.world_size |
| 161 | if self.requires_cuda_env and not get_accelerator().is_available(): |
| 162 | pytest.skip("only supported in accelerator environments.") |
| 163 | |
| 164 | self._launch_with_file_store(request, world_size) |
| 165 | |
| 166 | def _get_fixture_kwargs(self, request, func): |
| 167 | if not request: |
| 168 | return {} |
| 169 | # Grab fixture / parametrize kwargs from pytest request object |
| 170 | fixture_kwargs = {} |
| 171 | params = inspect.getfullargspec(func).args |
| 172 | params.remove("self") |
| 173 | for p in params: |
| 174 | try: |
| 175 | fixture_kwargs[p] = request.getfixturevalue(p) |
| 176 | except FixtureLookupError: |
| 177 | pass # test methods can have kwargs that are not fixtures |
| 178 | return fixture_kwargs |
| 179 | |
| 180 | def _launch_daemonic_procs(self, num_procs, init_method): |
| 181 | # Create process pool or use cached one |
| 182 | master_port = None |
| 183 | |
| 184 | if get_accelerator().device_name() == 'hpu': |
| 185 | if self.reuse_dist_env: |
| 186 | print("Ignoring reuse_dist_env for hpu") |
| 187 | self.reuse_dist_env = False |
| 188 | |
| 189 | if self.reuse_dist_env: |
| 190 | if num_procs not in self._pool_cache: |
| 191 | self._pool_cache[num_procs] = mp.Pool(processes=num_procs) |
| 192 | master_port = get_master_port() |
| 193 | pool = self._pool_cache[num_procs] |
| 194 | else: |
| 195 | pool = mp.Pool(processes=num_procs) |
| 196 | master_port = get_master_port() |
nothing calls this directly
no test coverage detected
searching dependent graphs…