Returns the group of the current device for `parallel_mode`. Args: parallel_mode: The chosen parallel mode. Returns: torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
(self, parallel_mode: ParallelMode)
| 333 | return self._world_sizes.get(parallel_mode, 1) |
| 334 | |
def get_group(self, parallel_mode: ParallelMode):
    """Look up the group of the current device under `parallel_mode`.

    The mode is first handed to `self._check_parallel_mode`; afterwards the
    group is read from `self._groups` using the mode as the key.

    Args:
        parallel_mode: The parallel mode whose group is requested.

    Returns:
        torch.distributed.ProcessGroup: The entry of `self._groups` stored
        for `parallel_mode`.
    """
    # Run the mode check before the lookup, in the same order as before.
    self._check_parallel_mode(parallel_mode)
    registered_groups = self._groups
    return registered_groups[parallel_mode]
| 346 | |
| 347 | def get_ranks_in_group(self, parallel_mode: ParallelMode): |
| 348 | """Returns the rank of the current device for `parallel_mode` in the group. |
no test coverage detected