Gather tensor for sequence parallel - reverse of split
(self, local_output, dim: int, position_ids=None)
| 355 | return _do_pad(tensor) |
| 356 | |
def gather(self, local_output, dim: int, position_ids=None):
    """Gather tensor for sequence parallel - reverse of split.

    Collects every rank's shard of ``local_output`` along ``dim`` and
    reassembles the full tensor.

    * ``world_size == 1``: ``local_output`` is returned unchanged.
    * ``rp_world_size > 1``: shards are all-gathered across ``sp_group`` and
      concatenated along ``dim`` (rebuilding this rp rank's chunk). Those
      chunks are then all-gathered across ``rp_group``, and the zigzag layout
      is undone per packed sequence. Sequence boundaries come from
      ``position_ids`` (padded via ``self.pad``). Every sequence length is
      rounded up to a multiple of ``2 * world_size``.
    * otherwise: shards are all-gathered across ``sp_group`` only and
      concatenated along ``dim``.

    The result has the same dtype and device as ``local_output``.

    Args:
        local_output: This rank's shard (must be at least 2-D on the rp path).
        dim: Axis along which the sequence was split.
        position_ids: Packed-sequence position ids. Required when
            ``rp_world_size > 1``.

    Returns:
        The gathered tensor, made contiguous.

    Raises:
        ValueError: If ``rp_world_size > 1`` and ``position_ids`` is None.
    """
    if self.world_size == 1:
        return local_output

    if self.rp_world_size > 1:
        assert local_output.dim() >= 2
        if position_ids is None:
            # Sequence boundaries are needed to undo the zigzag layout; fail
            # here rather than after both collectives have already run.
            raise ValueError("position_ids is required to gather when rp_world_size > 1")
        # NOTE(review): assumes self.pad mirrors the padding applied on the
        # split side, so the boundaries match the sharded data — confirm.
        position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)

        # Collectives need contiguous tensors; the gather buffers are fully
        # overwritten, so they need not be zero-initialised.
        local_output = local_output.contiguous()

        # Step 1: gather the sp shards of this rp rank and rebuild its chunk.
        gathered_sp = [torch.empty_like(local_output) for _ in range(self.sp_world_size)]
        dist.all_gather(gathered_sp, local_output, group=self.sp_group)
        rp_chunk = torch.cat(gathered_sp, dim=dim)

        # Step 2: gather the rebuilt chunks of every rp rank.
        gathered_rp = [torch.empty_like(rp_chunk) for _ in range(self.rp_world_size)]
        dist.all_gather(gathered_rp, rp_chunk, group=self.rp_group)

        # Padded length of each packed sequence: rounded up to a multiple of
        # 2 * world_size.
        cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
        if isinstance(cu_seqlens, torch.Tensor):
            cu_seqlens = cu_seqlens.tolist()  # one host sync, not one per element
        pad_multiple = 2 * self.world_size
        padded_lengths = []
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            length = int(end - start)
            padded_lengths.append((length + pad_multiple - 1) // pad_multiple * pad_multiple)

        # Same shape as the input except along `dim`; new_zeros keeps the
        # input's dtype and device (the previous allocation defaulted to
        # float32).
        out_shape = list(local_output.shape)
        out_shape[dim] = sum(padded_lengths)
        full_output = local_output.new_zeros(out_shape)

        # Undo the zigzag. Within each padded sequence, rp rank r holds
        # chunk r followed by chunk (2 * rp_world_size - 1 - r), each of
        # length seq_len // (2 * rp_world_size). Sequences are laid out back
        # to back inside every rp chunk.
        rp_ws = self.rp_world_size
        chunk_offset = 0  # start of the current sequence inside each rp chunk
        for seq_len in padded_lengths:
            local_len = seq_len // rp_ws
            half = local_len // 2
            seq_base = chunk_offset * rp_ws  # start of this sequence in full_output
            for rp_rank, rp_tensor in enumerate(gathered_rp):
                piece = rp_tensor.narrow(dim, chunk_offset, local_len)
                front = seq_base + rp_rank * half
                back = seq_base + (2 * rp_ws - rp_rank - 1) * half
                full_output.narrow(dim, front, half).copy_(piece.narrow(dim, 0, half))
                full_output.narrow(dim, back, half).copy_(piece.narrow(dim, half, local_len - half))
            chunk_offset += local_len

        return full_output.contiguous()

    # Pure sequence parallel: gather shards stacked along dim 0, then
    # re-lay them out along `dim`.
    gathered_sp = torch.empty(
        [local_output.shape[0] * self.sp_world_size] + list(local_output.shape[1:]),
        dtype=local_output.dtype,
        device=local_output.device)
    dist.all_gather_into_tensor(gathered_sp, local_output.contiguous(), group=self.sp_group)
    gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim)
    return gathered_sp.contiguous()