MCPcopy
hub / github.com/modelscope/ms-swift / split

Method split

swift/sequence_parallel/sequence_parallel.py:434–456  ·  view source on GitHub ↗

Split tensor for sequence parallel

(self, input, dim: int, position_ids=None)

Source from the content-addressed store, hash-verified

432 return torch.cat(local_values, dim=dim).contiguous()
433
def split(self, input, dim: int, position_ids=None):
    """Return the local shard of ``input`` along ``dim`` for sequence parallel.

    Args:
        input: Tensor to shard. Returned untouched when ``self.world_size``
            is 1.
        dim: Dimension along which ``input`` is cut.
        position_ids: Only read when ``self.rp_world_size > 1``, where it is
            handed to ``get_cu_seqlens_from_position_ids`` to obtain the
            cumulative sequence lengths.
            NOTE(review): the default of ``None`` is forwarded as-is on that
            path -- confirm the helper accepts it.

    Returns:
        ``input`` itself when ``self.world_size == 1``; otherwise a
        contiguous tensor holding the slice selected by ``self.sp_rank``.

    Raises:
        AssertionError: If, on the ``rp_world_size > 1`` path, ``input`` has
            fewer than 2 dimensions or some ``cu_seqlens`` entry is not a
            multiple of ``2 * self.rp_world_size``; or, on the other path,
            the size of ``dim`` is not a multiple of ``self.sp_world_size``.
    """
    # A single participant means there is nothing to shard.
    if self.world_size == 1:
        return input

    if self.rp_world_size > 1:
        assert input.dim() >= 2
        cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
        # Every cumulative length must be divisible by 2 * rp_world_size.
        assert torch.all(cu_seqlens % (2 * self.rp_world_size) == 0)
        packed = self._split_packed(input, cu_seqlens, dim=dim)
        # Cut the packed result into sp_world_size pieces and keep ours.
        return packed.chunk(self.sp_world_size, dim=dim)[self.sp_rank].contiguous()

    # rp_world_size <= 1: slice ``dim`` into sp_world_size equal parts.
    sp_rank = self.sp_rank
    dim_size = input.size(dim)
    assert dim_size % self.sp_world_size == 0, (
        f'The dimension to split ({dim_size}) is not a multiple of '
        f'world size ({self.sp_world_size}), cannot split tensor evenly')

    shards = torch.split(input, dim_size // self.sp_world_size, dim=dim)
    return shards[sp_rank].contiguous()
457
458 def pad_and_split_mm_tokens(self, visual_mask, mm_embeds):
459 input_ids = self.extra_kwargs['input_ids']

Callers 15

pad_and_split_inputs · Method · 0.95
parse_line · Function · 0.80
parse_require_file · Function · 0.80
get_cache_mapping · Function · 0.80
get_url_suffix · Function · 0.80
get_cache_mapping · Function · 0.80
_get_metric · Method · 0.80
llm_exp · Function · 0.80
get_selected_cases · Function · 0.80
get_case_model_info · Function · 0.80

Calls 2

_split_packed · Method · 0.95

Tested by 2

get_case_model_info · Function · 0.64
_start · Method · 0.64