(self, local_rank, num_procs, master_port, init_method, skip_msg="")
| 293 | self._launch_daemonic_procs(num_procs, init_method) |
| 294 | |
def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""):
    """Initialize deepspeed.comm (if needed) and execute the user function.

    If ``dist`` already reports an initialized process group (e.g. the
    environment is being reused from a previous run), only the accelerator
    device is re-selected, using ``dist.get_rank()``. Otherwise:
      * the rendezvous environment variables are exported when
        ``self.set_dist_env`` is true,
      * ``NCCL_DEBUG`` is removed from the environment,
      * the accelerator is made visible and set to ``local_rank``,
      * ``deepspeed.init_distributed`` plus a ``dist.barrier()`` run when
        ``self.init_distributed`` is true.
    Finally ``self.run(**self._fixture_kwargs)`` is executed.

    Args:
        local_rank: Rank of this process. Unit tests are single-node, so it
            is also written to RANK and passed as the global rank.
        num_procs: Number of processes, used for LOCAL_SIZE, WORLD_SIZE and
            ``world_size``.
        master_port: Written to MASTER_PORT when ``self.set_dist_env``.
        init_method: Forwarded to ``deepspeed.init_distributed``.
        skip_msg: Channel for reporting a ``Skipped`` raised by ``self.run``.
            When ``self.non_daemonic_procs`` is true it must provide ``put``
            (presumably a multiprocessing queue; confirm at the caller), and
            the skip message is put on it. Otherwise it is a plain string
            that gets replaced by the skip message.

    Returns:
        ``skip_msg``: the skip message string, or the same queue-like object
        that was passed in.

    Raises:
        Any exception other than ``Skipped`` propagates unchanged, whether it
        comes from setup or from ``self.run``.
    """
    if dist.is_initialized():
        if get_accelerator().is_available():
            # local_rank might not match the rank in the previous run if you are reusing the environment
            get_accelerator().set_device(dist.get_rank())
    else:
        if self.set_dist_env:
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = str(master_port)
            os.environ['LOCAL_RANK'] = str(local_rank)
            # NOTE: unit tests don't support multi-node so local_rank == global rank
            os.environ['RANK'] = str(local_rank)
            # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
            # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly
            os.environ['LOCAL_SIZE'] = str(num_procs)
            os.environ['WORLD_SIZE'] = str(num_procs)

        # turn off NCCL logging if set
        os.environ.pop('NCCL_DEBUG', None)

        if get_accelerator().is_available():
            set_accelerator_visible()

        if get_accelerator().is_available():
            get_accelerator().set_device(local_rank)

        if self.init_distributed:
            deepspeed.init_distributed(dist_backend=self.backend,
                                       init_method=init_method,
                                       rank=local_rank,
                                       world_size=num_procs)
            dist.barrier()

    try:
        self.run(**self._fixture_kwargs)
    except Skipped as e:
        # Only skips are intercepted; every other exception propagates with
        # its original traceback.
        if self.non_daemonic_procs:
            # Non-daemonic workers report the skip through the passed-in channel.
            skip_msg.put(e.msg)
        else:
            skip_msg = e.msg

    return skip_msg
| 342 | def _launch_with_file_store(self, request, world_size): |
| 343 | tmpdir = request.getfixturevalue("tmpdir") |
nothing calls this directly
no test coverage detected