MCPcopy Index your code
hub / github.com/modelscope/ms-swift / gather

Method gather

swift/sequence_parallel/sequence_parallel.py:357–414  ·  view source on GitHub ↗

Gather a tensor that was split for sequence parallelism back into one full tensor (the reverse of `split`).

(self, local_output, dim: int, position_ids=None)

Source from the content-addressed store, hash-verified

355 return _do_pad(tensor)
356
def gather(self, local_output, dim: int, position_ids=None):
    """Gather a sequence-parallel tensor back to full length (reverse of ``split``).

    Args:
        local_output: This rank's shard of the tensor.
        dim: Axis along which the shards are concatenated (presumably the
            sequence axis -- TODO confirm against ``split``).
        position_ids: Position ids of the packed batch. Required when
            ``self.rp_world_size > 1``, where they delimit the packed sequences
            whose zigzag layout must be undone. Ignored otherwise.

    Returns:
        A contiguous tensor with the same dtype as ``local_output`` holding all
        shards concatenated along ``dim``. When ``self.rp_world_size > 1``, the
        zigzag chunk order is restored per packed sequence, and each sequence
        occupies its length rounded up to a multiple of ``2 * self.world_size``.
        When ``self.world_size == 1``, ``local_output`` is returned unchanged.

    Raises:
        ValueError: If ``self.rp_world_size > 1`` and ``local_output`` has
            fewer than 2 dims or ``position_ids`` is None.
    """
    if self.world_size == 1:
        return local_output

    # Collectives need a contiguous input; this also makes the empty_like buffers below contiguous.
    local_output = local_output.contiguous()

    if self.rp_world_size > 1:
        if local_output.dim() < 2:
            raise ValueError(f'gather expects a tensor with at least 2 dims, got {local_output.dim()}')
        if position_ids is None:
            # Packed-sequence boundaries come from position_ids; fail with a clear message
            # before running any collective instead of erroring deep inside the reordering.
            raise ValueError('position_ids is required to gather when rp_world_size > 1')
        position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)

        # Step 1: gather the sp shards; concatenated they form this rp rank's chunk.
        # empty_like is enough because all_gather overwrites every receive buffer.
        sp_shards = [torch.empty_like(local_output) for _ in range(self.sp_world_size)]
        dist.all_gather(sp_shards, local_output, group=self.sp_group)
        rp_chunk = torch.cat(sp_shards, dim=dim)

        # Step 2: gather the chunks of every rp rank.
        rp_chunks = [torch.empty_like(rp_chunk) for _ in range(self.rp_world_size)]
        dist.all_gather(rp_chunks, rp_chunk, group=self.rp_group)

        # Each packed sequence occupies its length rounded up to a multiple of 2 * world_size
        # (exact integer ceiling, no float round-trip).
        unit = 2 * self.world_size
        cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
        padded_lengths = []
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            length = int(end - start)
            padded_lengths.append((length + unit - 1) // unit * unit)

        # Undo the zigzag layout along ``dim``: a sequence of padded length L is cut into
        # 2 * R chunks of size L / (2 * R); rp rank i holds chunk i followed by chunk
        # 2 * R - 1 - i, and the sequences sit back to back in every rank's buffer.
        pieces = []
        offset = 0
        for length in padded_lengths:
            local_length = length // self.rp_world_size
            half = local_length // 2
            local_parts = [chunk.narrow(dim, offset, local_length) for chunk in rp_chunks]
            # First halves go in rank order 0..R-1, second halves in reverse rank order.
            pieces.extend(part.narrow(dim, 0, half) for part in local_parts)
            pieces.extend(part.narrow(dim, half, local_length - half) for part in reversed(local_parts))
            offset += local_length

        if not pieces:
            # No sequences: return an empty slice that keeps dtype/device and the other dims.
            return rp_chunks[0].narrow(dim, 0, 0).contiguous()
        # torch.cat keeps the input dtype, unlike the previous float32 torch.zeros buffer.
        return torch.cat(pieces, dim=dim).contiguous()

    gathered_sp = torch.empty(
        [local_output.shape[0] * self.sp_world_size] + list(local_output.shape[1:]),
        dtype=local_output.dtype,
        device=local_output.device)
    dist.all_gather_into_tensor(gathered_sp, local_output, group=self.sp_group)
    # all_gather_into_tensor stacks the shards along dim 0; re-split and join them along ``dim``.
    gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim)
    return gathered_sp.contiguous()

Callers 15

run_parallel · Method · 0.80
__call__ · Method · 0.80
run_async · Method · 0.80
run_async · Method · 0.80
__call__ · Method · 0.80
run · Method · 0.80
vocab_parallel_topk · Function · 0.80
tp_gather_topk · Function · 0.80
_run_async_funcs · Method · 0.80
_run_all · Method · 0.80
_gen · Method · 0.80

Calls 4

pad · Method · 0.95
append · Method · 0.80
split · Method · 0.80

Tested by 2

run_parallel · Method · 0.64
__call__ · Method · 0.64