Return world size for the data parallel group.
()
| 682 | |
| 683 | |
def _get_data_parallel_world_size():
    """Return world size for the data parallel group.

    Resolution order:
      1. If a device mesh (``mesh_device``) is configured, return the world
         size of its ``"data_parallel"`` mesh-dimension group.
      2. Otherwise, if a model-parallel unit (``mpu``) is registered, defer to
         ``mpu.get_data_parallel_world_size()`` -- unless ``mpu`` exposes
         ``initialize_sequence_parallel``, in which case ``None`` is returned.
      3. Otherwise, return the world size of the group produced by
         ``_get_data_parallel_group()``.

    Returns:
        The data-parallel world size, or ``None`` in the sequence-parallel
        ``mpu`` case described in step 2.

    Raises:
        AssertionError: if the distributed backend is not initialized.
    """
    # Explicit raise instead of ``assert`` so the precondition still holds when
    # Python runs with ``-O``; exception type and message are unchanged.
    if not dist.is_initialized():
        raise AssertionError('dist is not initialized')

    if mesh_device is not None:
        return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))

    # ``mpu`` is only read (never assigned) here, so no ``global`` is needed.
    if mpu is not None:
        if hasattr(mpu, 'initialize_sequence_parallel'):
            # NOTE(review): returning None from a "world size" getter is
            # surprising; presumably callers handle the sequence-parallel case
            # separately -- confirm before relying on this value.
            return None
        return mpu.get_data_parallel_world_size()

    return dist.get_world_size(group=_get_data_parallel_group())
| 696 | |
| 697 | |
| 698 | def _get_model_parallel_world_size(): |
no test coverage detected
searching dependent graphs…