`_get_data_parallel_group()`: Get the data parallel group the caller rank belongs to.
| 612 | |
| 613 | |
def _get_data_parallel_group():
    """Return the data-parallel process group for the calling rank.

    Resolution order:
      1. If ``mesh_device`` is set, the group for its ``"data_parallel"``
         mesh dimension.
      2. Otherwise, if an ``mpu`` is registered:
         - ``None`` when the mpu has an ``initialize_sequence_parallel``
           attribute;
         - else whatever ``mpu.get_data_parallel_group()`` returns.
      3. Otherwise, a clone of the distributed world group
         (via ``_clone_world_group()``).

    Raises:
        AssertionError: if ``dist.is_initialized()`` is false.
    """
    assert dist.is_initialized(), 'dist is not initialized'

    # A configured device mesh takes precedence over everything else.
    if mesh_device is not None:
        return mesh_device.get_group(mesh_dim="data_parallel")

    # No mesh and no mpu: fall back to a clone of the world group.
    if mpu is None:
        return _clone_world_group()

    # NOTE(review): an mpu exposing `initialize_sequence_parallel` yields
    # None here rather than a group -- presumably callers handle this; confirm.
    sequence_parallel_mpu = hasattr(mpu, 'initialize_sequence_parallel')
    return None if sequence_parallel_mpu else mpu.get_data_parallel_group()
| 628 | |
| 629 | |
| 630 | def _get_data_parallel_group_ranks(): |
no test coverage detected
searching dependent graphs…