Flatten the logical mesh into an effective 1d logical mesh.
(self)
| 479 | return global_pg_ranks |
| 480 | |
def flatten(self):
    """
    Flatten the logical mesh into an effective 1d logical mesh.

    All devices of this mesh are placed along a single logical axis of
    size ``self.num_devices``. The flattened mesh has exactly one
    dimension, so it is given exactly one ``mesh_alpha`` and one
    ``mesh_beta`` entry. Each entry is the largest corresponding
    coefficient of this mesh.

    Returns:
        DeviceMesh: a new 1d mesh with the following properties.
            - It is built from the same ``self._physical_mesh_id``.
            - Its shape is ``(self.num_devices,)``.
            - It is created with the same ``init_process_group`` setting
              as this mesh.

    Raises:
        RuntimeError: if this mesh was created with
            ``DeviceMesh.from_process_group``. In that case no global rank
            information is available to flatten.
    """
    if self._is_init_from_process_group:
        raise RuntimeError(
            "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
        )

    flatten_mesh_shape = (self.num_devices,)
    # Size alpha/beta by the dimensionality of the *result* (always 1).
    # The previous ``len(self._mesh_shape) - 1`` only equalled 1 for 2d
    # meshes. It yielded an empty list for 1d meshes and surplus entries
    # for 3d+ meshes.
    # NOTE(review): assumes DeviceMesh expects one alpha/beta entry per
    # mesh dimension -- confirm against DeviceMesh.__init__.
    num_flatten_dims = len(flatten_mesh_shape)
    return DeviceMesh(
        self._physical_mesh_id,
        flatten_mesh_shape,
        mesh_alpha=[max(self.mesh_alpha)] * num_flatten_dims,
        mesh_beta=[max(self.mesh_beta)] * num_flatten_dims,
        init_process_group=self._init_process_group,
    )
| 499 | |
| 500 | def all_gather_cost(self, num_bytes, mesh_dim): |
| 501 | num_devices = self.logical_mesh_id.shape[mesh_dim] |