Pad tensor for sequence parallel
(self, tensor, padding_value, position_ids=None, dim=1)
| 317 | return query, key, value |
| 318 | |
def pad(self, tensor, padding_value, position_ids=None, dim=1):
    """Right-pad ``tensor`` along ``dim`` so its length is a multiple of the alignment unit.

    The alignment unit is ``self.world_size``, doubled when
    ``self.rp_world_size > 1``.

    Args:
        tensor: Tensor to pad.
        padding_value: Either a scalar fill value (used for id-like tensors)
            or a ``torch.Tensor`` that is tiled to build the padding (used
            for embeddings).
        position_ids: Optional position ids. When given and
            ``self.rp_world_size > 1``, the tensor is cut into segments at
            the boundaries returned by ``get_cu_seqlens_from_position_ids``
            and every segment is padded on its own before re-concatenation.
        dim: Dimension to pad. The per-segment path only supports ``1``
            and ``-1``.

    Returns:
        The padded tensor. When no padding is needed and the per-segment
        path is not taken, the input object itself is returned.

    Raises:
        NotImplementedError: On the per-segment path when ``dim`` is
            neither ``1`` nor ``-1``.
    """
    multiple = self.world_size * 2 if self.rp_world_size > 1 else self.world_size

    def _pad_to_multiple(chunk):
        # Number of trailing elements beyond the last full multiple.
        remainder = chunk.shape[dim] % multiple
        if remainder == 0:
            return chunk
        missing = multiple - remainder

        if isinstance(padding_value, torch.Tensor):
            # Embeddings: tile the padding vector to (chunk.shape[0], missing, ...).
            # NOTE(review): this layout presumes a 3-D chunk padded on dim 1 - confirm.
            filler = padding_value.unsqueeze(0).repeat(chunk.shape[0], missing, 1)
        else:
            # Ids: same shape as the chunk except for the padded dimension.
            leading = tuple(chunk.shape[:dim])
            # For dim == -1, ``shape[dim + 1:]`` would be the whole shape, so
            # there are no trailing dimensions to carry over.
            trailing = () if dim == -1 else tuple(chunk.shape[dim + 1:])
            filler = torch.full(
                (*leading, missing, *trailing),
                padding_value,
                dtype=chunk.dtype,
                device=chunk.device,
            )
        return torch.cat([chunk, filler], dim=dim)

    if position_ids is not None and self.rp_world_size > 1:
        cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
        padded_segments = []
        for idx in range(1, len(cu_seqlens)):
            start, stop = cu_seqlens[idx - 1], cu_seqlens[idx]
            if dim == 1:
                segment = tensor[:, start:stop]
            elif dim == -1:
                segment = tensor[..., start:stop]
            else:
                raise NotImplementedError()
            padded_segments.append(_pad_to_multiple(segment))
        tensor = torch.cat(padded_segments, dim=dim)

    # Always runs, including after the per-segment path above.
    return _pad_to_multiple(tensor)
| 356 | |
| 357 | def gather(self, local_output, dim: int, position_ids=None): |
| 358 | """Gather tensor for sequence parallel - reverse of split""" |
no test coverage detected