hub / github.com/NVIDIA/TensorRT-LLM / split

Function split

tensorrt_llm/functional.py:3748–3821

Add an operation that splits a tensor into sub-tensors. This operation creates a list of tensors that are obtained from the input tensor by slicing it along the dimension 'dim'. If 'split_size_or_sections' is an integer, the tensor is split into 'input.shape[dim] / split_size_or_sections' slices.

(tensor: Tensor,
          split_size_or_sections: Union[int, Sequence[int]],
          dim: int = 0)
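
The two accepted forms of 'split_size_or_sections' can be illustrated with a small plain-Python sketch. It is not TensorRT-LLM code: `split_plan` is a hypothetical helper, written only to show which (start, size) pairs each form implies along 'dim', under the constraints stated in the docstring.

```python
from typing import List, Sequence, Tuple, Union


def split_plan(dim_size: int,
               split_size_or_sections: Union[int, Sequence[int]]
               ) -> List[Tuple[int, int]]:
    """Return (start, size) pairs along one dimension (hypothetical helper)."""
    if isinstance(split_size_or_sections, int):
        # Integer form: equal slices, so the dimension must divide evenly.
        assert dim_size % split_size_or_sections == 0
        count = dim_size // split_size_or_sections
        sections = [split_size_or_sections] * count
    else:
        # Sequence form: explicit slice sizes that must sum to the dimension.
        sections = list(split_size_or_sections)
        assert sum(sections) == dim_size

    plan = []
    start = 0
    for size in sections:
        plan.append((start, size))
        start += size
    return plan


print(split_plan(6, 2))          # [(0, 2), (2, 2), (4, 2)]
print(split_plan(6, [1, 2, 3]))  # [(0, 1), (1, 2), (3, 3)]
```

An integer argument is the size of each slice, not the number of slices, so `split_plan(6, 2)` yields three slices of two elements.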

Source from the content-addressed store, hash-verified

def split(tensor: Tensor,
          split_size_or_sections: Union[int, Sequence[int]],
          dim: int = 0) -> Sequence[Tensor]:
    '''
    Add an operation that splits a tensor into sub-tensors.

    This operation creates a list of tensors that are obtained from the input
    tensor by slicing it along the dimension 'dim'. If 'split_size_or_sections'
    is an integer, the tensor is split into 'input.shape[dim] /
    split_size_or_sections' slices. If 'split_size_or_sections' is a list of
    sizes, the tensor is split into 'len(split_size_or_sections)' slices and
    the size of the ith slice is given by 'split_size_or_sections[i]'.

    There are several constraints with the current implementation:

      - The input tensor must be static (no dynamic dimension),
      - If 'split_size_or_sections' is an integer, the number of elements in
        the 'dim' dimension of the input must be a multiple of
        'split_size_or_sections': 'input.shape[dim] % split_size_or_sections == 0'.
      - If 'split_size_or_sections' is a sequence, the sum of the elements in
        'split_size_or_sections' must be equal to the size in the dimension
        'dim': 'input.shape[dim] == sum(ii for ii in split_size_or_sections)'.

    That operation is implemented using a 'slice' operation for each output
    slice.

    Parameters:
        tensor : Tensor
            The input tensor to slice.

        split_size_or_sections : Union[int, Sequence[int]]
            If it is an integer, it encodes the size of each slice. Otherwise,
            if it is a sequence, it is the size of each slice.

        dim : int
            The dimension of the tensor to slice.

    Returns:
        The list of tensors produced by the different operations.
    '''
    assert not tensor.is_dynamic(dim)

    ndim = tensor.ndim()
    if dim < 0:
        dim += ndim
    dim_value = tensor.size()[dim]
    starts = [constant(dims_array([0])) for _ in range(ndim)]
    sizes = [shape(tensor, i) for i in range(ndim)]

    if isinstance(split_size_or_sections, int):
        # TODO: support non-divisible cases
        assert dim_value % split_size_or_sections == 0
        num_sections = dim_value // split_size_or_sections
        sizes[dim] = constant(dims_array([split_size_or_sections]))

        outputs = []
        for i in range(num_sections):
            starts[dim] = constant(dims_array([split_size_or_sections * i]))
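
The listing above stops inside the integer branch. As a rough illustration of the per-slice pattern it sets up (one 'slice' per output, with only the start and size along 'dim' changing), here is a plain-Python emulation on nested lists. It is a sketch under stated assumptions, not the library's implementation: `shape_of`, `slice_nd`, and `split_equal` are hypothetical names, and nested lists stand in for tensors.

```python
def shape_of(data):
    """Shape of a non-empty rectangular nested list (hypothetical helper)."""
    dims = []
    while isinstance(data, list):
        dims.append(len(data))
        data = data[0]
    return dims


def slice_nd(data, starts, sizes):
    """Take sizes[k] elements starting at starts[k] along every axis k."""
    if not starts:
        return data
    lo, n = starts[0], sizes[0]
    return [slice_nd(item, starts[1:], sizes[1:]) for item in data[lo:lo + n]]


def split_equal(data, split_size, dim=0):
    """Emulate the integer branch: equal slices along one axis."""
    full = shape_of(data)
    ndim = len(full)
    if dim < 0:
        dim += ndim  # normalize a negative axis, as the listing does
    dim_value = full[dim]
    assert dim_value % split_size == 0  # non-divisible sizes are rejected

    starts = [0] * ndim   # every axis starts at 0 ...
    sizes = list(full)    # ... and spans its full extent,
    sizes[dim] = split_size  # except the split axis

    outputs = []
    for i in range(dim_value // split_size):
        starts[dim] = split_size * i  # only the split axis offset moves
        outputs.append(slice_nd(data, starts, sizes))
    return outputs


data = [[1, 2, 3, 4],
        [5, 6, 7, 8]]
print(split_equal(data, 2, dim=-1))
# [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
```

Each output keeps the full extent of every other axis, which is why `sizes` is initialized from the whole shape and only the entry at 'dim' is overwritten.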

Callers 5

split · Method · 0.70
chunk · Function · 0.70
unbind · Function · 0.70

Calls 9

constant · Function · 0.85
dims_array · Function · 0.85
slice · Function · 0.85
concat · Function · 0.85
is_dynamic · Method · 0.80
shape · Function · 0.70
ndim · Method · 0.45
size · Method · 0.45
append · Method · 0.45

Tested by

no test coverage detected