MCPcopy Index your code
hub / github.com/modelscope/ms-swift / pad

Method pad

swift/sequence_parallel/sequence_parallel.py:319–355  ·  view source on GitHub ↗

Pad a tensor for sequence parallelism, so that its length along `dim` is a multiple of the parallel world size.

(self, tensor, padding_value, position_ids=None, dim=1)

Sourced from the content-addressed store (hash-verified)

317 return query, key, value
318
def pad(self, tensor, padding_value, position_ids=None, dim=1):
    """Pad ``tensor`` along ``dim`` so its length is a multiple of the SP world size.

    The target multiple is ``self.world_size``, doubled when
    ``self.rp_world_size > 1``. When ``position_ids`` is given and
    ``self.rp_world_size > 1``, the tensor is first cut into the packed
    sub-sequences delimited by ``get_cu_seqlens_from_position_ids(position_ids)``.
    Each piece is padded on its own and the pieces are concatenated back
    along ``dim``.

    Args:
        tensor: The tensor to pad.
        padding_value: Either a non-tensor fill value (e.g. a pad token id or
            an ignore label), or a ``torch.Tensor`` that is broadcast over the
            padding region. The tensor form is presumably a single embedding
            row of shape ``(hidden,)`` for a ``(batch, seq, hidden)`` input
            (verify against callers). Any shape broadcastable to the padding
            region is accepted.
        position_ids: Optional position ids of packed sequences. Only used
            when ``self.rp_world_size > 1``.
        dim: The dimension to pad (any valid dimension, negative allowed).

    Returns:
        The padded tensor. If no padding is needed and the packed-sequence
        path is not taken, the input tensor object is returned unchanged.
    """
    world_size = self.world_size * 2 if self.rp_world_size > 1 else self.world_size

    def _do_pad(tensor):
        # Elements needed to reach the next multiple of world_size (0 if aligned).
        pad_num = -tensor.shape[dim] % world_size
        if pad_num == 0:
            return tensor
        # Same shape as the input, except `pad_num` along the padded dimension.
        pad_shape = list(tensor.shape)
        pad_shape[dim] = pad_num
        if isinstance(padding_value, torch.Tensor):
            # Embeddings: broadcast the padding tensor over the padding region.
            pad = padding_value.expand(pad_shape)
        else:
            # Ids / labels: fill the padding region with a scalar.
            pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, pad], dim=dim)

    if position_ids is not None and self.rp_world_size > 1:
        cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
        # Slice along `dim` only; all other dimensions are taken whole.
        index = [slice(None)] * tensor.ndim
        pieces = []
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            index[dim] = slice(start, end)
            pieces.append(_do_pad(tensor[tuple(index)]))
        tensor = torch.cat(pieces, dim=dim)

    return _do_pad(tensor)
356
357 def gather(self, local_output, dim: int, position_ids=None):
358 """Gather tensor for sequence parallel - reverse of split"""

Callers 15

sdpa_mask · Method · 0.95
_attention · Method · 0.95
gather · Method · 0.95
pad_and_split_inputs · Method · 0.95
_get_encoded_batch · Method · 0.80
forward · Method · 0.80
forward · Method · 0.80

Calls 2

append · Method · 0.80

Tested by

no test coverage detected