hub / github.com/hpcaitech/ColossalAI / flatten

Method flatten

colossalai/device/device_mesh.py:481–498

Flatten the logical mesh into an effective 1D logical mesh.

(self)

Source from the content-addressed store, hash-verified

479         return global_pg_ranks
480
481     def flatten(self):
482         """
483         Flatten the logical mesh into an effective 1d logical mesh,
484         """
485         if self._is_init_from_process_group:
486             raise RuntimeError(
487                 "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
488             )
489
490         flatten_mesh_shape_size = len(self._mesh_shape)
491         flatten_mesh_shape = [self.num_devices]
492         return DeviceMesh(
493             self._physical_mesh_id,
494             tuple(flatten_mesh_shape),
495             mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
496             mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
497             init_process_group=self._init_process_group,
498         )
499
500     def all_gather_cost(self, num_bytes, mesh_dim):
501         num_devices = self.logical_mesh_id.shape[mesh_dim]
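
The shape and cost arithmetic in flatten can be traced with a small stand-in. The sketch below is not the ColossalAI class: MiniMesh is a hypothetical dataclass that keeps only the fields the method reads, and it assumes num_devices equals the product of the mesh shape, which the excerpt does not show. It mirrors lines 490–496 only and omits the process-group check and the physical mesh ID.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Tuple


@dataclass
class MiniMesh:
    """Hypothetical stand-in for DeviceMesh; not part of ColossalAI."""

    mesh_shape: Tuple[int, ...]
    mesh_alpha: List[float]
    mesh_beta: List[float]

    @property
    def num_devices(self) -> int:
        # Assumption: the device count is the product of the logical mesh shape.
        return prod(self.mesh_shape)

    def flatten(self) -> "MiniMesh":
        # Same arithmetic as the excerpt: one axis holding every device,
        # with the largest alpha/beta repeated (ndim - 1) times.
        ndim = len(self.mesh_shape)
        return MiniMesh(
            mesh_shape=(self.num_devices,),
            mesh_alpha=[max(self.mesh_alpha)] * (ndim - 1),
            mesh_beta=[max(self.mesh_beta)] * (ndim - 1),
        )


# A 2 x 4 logical mesh collapses to a single axis of 8 devices.
mesh_2d = MiniMesh(mesh_shape=(2, 4), mesh_alpha=[1.0, 2.0], mesh_beta=[0.1, 0.5])
flat = mesh_2d.flatten()
print(flat.mesh_shape, flat.mesh_alpha, flat.mesh_beta)  # (8,) [2.0] [0.5]

# A 2 x 2 x 2 mesh also collapses to 8 devices, but the cost lists get two entries.
mesh_3d = MiniMesh(mesh_shape=(2, 2, 2), mesh_alpha=[1.0, 2.0, 3.0], mesh_beta=[0.1, 0.2, 0.3])
flat_3d = mesh_3d.flatten()
print(flat_3d.mesh_shape, flat_3d.mesh_alpha)  # (8,) [3.0, 3.0]
```

For a 2D input the flattened mesh carries exactly one alpha and one beta, matching its single axis. For a 3D input the same formula yields two entries for a one-axis mesh; whether the DeviceMesh constructor accepts that mismatch is not visible in this excerpt.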

Callers 15

find_params · Method · 0.45
init_param_data · Method · 0.45
chunk_and_pad · Function · 0.45
get_shard · Function · 0.45
forward · Method · 0.45
forward · Method · 0.45
forward · Method · 0.45
forward · Method · 0.45
append · Method · 0.45
__init__ · Method · 0.45
_all_reduce_fp8 · Function · 0.45

Calls 1

DeviceMesh · Class · 0.85

Tested by 7

exam_dist_adafactor_base · Function · 0.36
check_packed_seq · Function · 0.36
split_ddp_grad · Function · 0.36
forward · Method · 0.36
forward · Method · 0.36
split_ddp_grad · Function · 0.36