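The docstring of `split` defines how many slices are produced and how large each one is, depending on whether 'split_size_or_sections' is an integer or a sequence, and lists the constraints the implementation asserts. The pure-Python sketch below models only that shape arithmetic. It does not build any network operation. The helper name `section_sizes` is hypothetical and not part of the library, and the rejection of non-positive integer sizes is an extra guard added here as an assumption.

```python
from typing import List, Sequence, Union


def section_sizes(dim_value: int,
                  split_size_or_sections: Union[int, Sequence[int]]) -> List[int]:
    """Return the size of each slice along 'dim' (hypothetical helper).

    Mirrors the documented constraints of split():
    - an integer split size must evenly divide dim_value;
    - a sequence of sizes must sum to dim_value.
    """
    if isinstance(split_size_or_sections, int):
        # Extra guard (assumption): avoid ZeroDivisionError / negative counts.
        if split_size_or_sections <= 0:
            raise ValueError("split size must be positive")
        if dim_value % split_size_or_sections != 0:
            raise ValueError(
                f"{dim_value} is not a multiple of {split_size_or_sections}")
        # dim_value // split_size_or_sections slices, all of the same size.
        count = dim_value // split_size_or_sections
        return [split_size_or_sections] * count
    sizes = list(split_size_or_sections)
    if sum(sizes) != dim_value:
        raise ValueError(f"sum({sizes}) != {dim_value}")
    # One slice per entry; slice i has size split_size_or_sections[i].
    return sizes
```

For a dimension of size 6, an integer size of 2 yields three slices of size 2, the sequence `[1, 2, 3]` yields three slices of sizes 1, 2 and 3, and an integer size of 4 is rejected because 6 is not a multiple of 4.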
def split(tensor: Tensor,
          split_size_or_sections: Union[int, Sequence[int]],
          dim: int = 0) -> Sequence[Tensor]:
    '''
    Add an operation that splits a tensor into sub-tensors.

    This operation creates a list of tensors that are obtained from the input
    tensor by slicing it along the dimension 'dim'. If 'split_size_or_sections'
    is an integer, the tensor is split into 'input.shape[dim] //
    split_size_or_sections' slices of equal size. If 'split_size_or_sections'
    is a sequence of sizes, the tensor is split into
    'len(split_size_or_sections)' slices and the size of the i-th slice is
    given by 'split_size_or_sections[i]'.

    The current implementation has the following constraints:

        - The 'dim' dimension of the input tensor must be static (not
          dynamic).
        - If 'split_size_or_sections' is an integer, the number of elements in
          the 'dim' dimension of the input must be a multiple of
          'split_size_or_sections': 'input.shape[dim] % split_size_or_sections == 0'.
        - If 'split_size_or_sections' is a sequence, the sum of its elements
          must be equal to the size of the input in the dimension 'dim':
          'input.shape[dim] == sum(split_size_or_sections)'.

    This operation is implemented with one 'slice' operation per output
    tensor.

    Parameters:
        tensor : Tensor
            The input tensor to split.

        split_size_or_sections : Union[int, Sequence[int]]
            If it is an integer, it is the size of every slice. If it is a
            sequence, its i-th element is the size of the i-th slice.

        dim : int
            The dimension along which to split the tensor. A negative value
            counts from the last dimension.

    Returns:
        The list of sub-tensors, one per slice, ordered along 'dim'.
    '''
    assert not tensor.is_dynamic(dim)

    ndim = tensor.ndim()
    if dim < 0:
        dim += ndim
    dim_value = tensor.size()[dim]
    starts = [constant(dims_array([0])) for _ in range(ndim)]
    sizes = [shape(tensor, i) for i in range(ndim)]

    if isinstance(split_size_or_sections, int):
        # TODO: support non-divisible cases
        assert dim_value % split_size_or_sections == 0
        num_sections = dim_value // split_size_or_sections
        sizes[dim] = constant(dims_array([split_size_or_sections]))

        outputs = []
        for i in range(num_sections):
            starts[dim] = constant(dims_array([split_size_or_sections * i]))
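The `split` code builds, for each output, a per-dimension 'starts' vector that is zero everywhere except along 'dim', and a 'sizes' vector that keeps the full input shape except along 'dim', then emits one 'slice' operation from them. The sketch below reproduces that bookkeeping with NumPy so the result can be compared against `numpy.split`. The name `split_via_slices` is hypothetical, NumPy basic slicing stands in for the network's 'slice' operation, and the handling of a sequence of sizes (each start is the previous start plus the previous size) is an assumption rather than a transcription of the code above.

```python
from typing import List, Sequence, Union

import numpy as np


def split_via_slices(x: np.ndarray,
                     split_size_or_sections: Union[int, Sequence[int]],
                     dim: int = 0) -> List[np.ndarray]:
    """Split x along dim with one slice per output (hypothetical helper)."""
    ndim = x.ndim
    if dim < 0:
        dim += ndim  # same negative-dim normalization as split()
    dim_value = x.shape[dim]
    if isinstance(split_size_or_sections, int):
        assert dim_value % split_size_or_sections == 0
        count = dim_value // split_size_or_sections
        lengths = [split_size_or_sections] * count
    else:
        lengths = list(split_size_or_sections)
        assert sum(lengths) == dim_value

    starts = [0] * ndim       # every output starts at 0 outside 'dim'
    sizes = list(x.shape)     # and spans the full extent outside 'dim'
    outputs = []
    for length in lengths:
        sizes[dim] = length
        # Equivalent of one 'slice' op: the region [starts, starts + sizes).
        region = tuple(slice(s, s + n) for s, n in zip(starts, sizes))
        outputs.append(x[region])
        starts[dim] += length  # the next slice begins where this one ends
    return outputs
```

With `x = np.arange(24).reshape(2, 6, 2)`, calling `split_via_slices(x, [1, 2, 3], dim=-2)` returns arrays of shapes `(2, 1, 2)`, `(2, 2, 2)` and `(2, 3, 2)`, matching `np.split(x, [1, 3], axis=1)`. For the integer case, the running start used here equals `split_size_or_sections * i`, which is the value the code above assigns to `starts[dim]` on iteration `i`.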