Split tensor for sequence parallel
(self, input, dim: int, position_ids=None)
| 432 | return torch.cat(local_values, dim=dim).contiguous() |
| 433 | |
def split(self, input, dim: int, position_ids=None):
    """Return this rank's local shard of ``input`` along ``dim``.

    Three cases are handled:

    * ``self.world_size == 1``: nothing to shard, ``input`` is returned
      as-is (same object, no copy).
    * ``self.rp_world_size > 1``: sequence boundaries are derived from
      ``position_ids`` via ``get_cu_seqlens_from_position_ids``, the
      tensor is rearranged by ``self._split_packed`` and the result is
      chunked into ``self.sp_world_size`` pieces along ``dim``.
    * otherwise: ``input`` is cut into ``self.sp_world_size`` equal
      pieces along ``dim``.

    In the two sharding cases the piece at index ``self.sp_rank`` is
    returned as a contiguous tensor.

    Args:
        input: Tensor to shard.
        dim: Dimension along which to shard.
        position_ids: Only read when ``self.rp_world_size > 1``.
            NOTE(review): the default ``None`` is forwarded unchanged to
            ``get_cu_seqlens_from_position_ids`` -- confirm that callers
            always supply it on that path.

    Raises:
        AssertionError: If ``input`` has fewer than 2 dimensions or any
            cumulative sequence length is not a multiple of
            ``2 * self.rp_world_size`` (``rp_world_size > 1`` path), or if
            the size of ``dim`` is not a multiple of ``self.sp_world_size``
            (plain path).
    """
    if self.world_size == 1:
        return input

    if self.rp_world_size <= 1:
        # Plain even split across the sequence-parallel group.
        shard_index = self.sp_rank
        dim_size = input.size(dim)
        assert dim_size % self.sp_world_size == 0, (
            f'The dimension to split ({dim_size}) is not a multiple of '
            f'world size ({self.sp_world_size}), cannot split tensor evenly')

        shards = torch.split(input, dim_size // self.sp_world_size, dim=dim)
        return shards[shard_index].contiguous()

    # rp_world_size > 1: split guided by the cumulative sequence lengths.
    assert input.dim() >= 2
    cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
    # Every sequence boundary must be divisible by 2 * rp_world_size.
    assert torch.all(cu_seqlens % (2 * self.rp_world_size) == 0)
    packed = self._split_packed(input, cu_seqlens, dim=dim)
    sp_shards = packed.chunk(self.sp_world_size, dim=dim)
    return sp_shards[self.sp_rank].contiguous()
| 457 | |
| 458 | def pad_and_split_mm_tokens(self, visual_mask, mm_embeds): |
| 459 | input_ids = self.extra_kwargs['input_ids'] |