(tensor: Tensor, chunks: int, dim: int = 0)
```python
def chunk(tensor: Tensor, chunks: int, dim: int = 0) -> Sequence[Tensor]:
    '''
    Add an operation that splits a tensor into sub-tensors.

    This operation creates a list of tensors that are obtained from the input
    tensor by chunking it along the dimension 'dim'. It produces 'chunks'
    sub-tensors of equal size.

    This operation is only defined for static tensors (no dynamic dimension),
    and the size of the tensor in the dimension 'dim' must be a multiple of
    'chunks': 'input.shape[dim] % chunks == 0'.

    It maps to 'split' with 'split_size = input.shape[dim] // chunks'.

    Parameters:
        tensor : Tensor
            The input tensor to split.

        chunks : int
            The number of sub-tensors to split the input tensor into.

        dim : int
            The dimension along which to split. A negative value counts
            from the last dimension (-1 is the last one).

    Returns:
        The list of 'chunks' tensors produced by the split.
    '''
    # Normalize a negative 'dim' before using it as an index.
    ndim = tensor.ndim()
    if dim < 0:
        dim += ndim

    # The split dimension must have a static size.
    assert not tensor.is_dynamic(dim)

    # Every chunk must receive the same number of elements.
    dim_value = tensor.size()[dim]
    assert dim_value % chunks == 0

    return split(tensor, dim_value // chunks, dim)


def unbind(input: Tensor, dim: int = 0):
    ...
```
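To make the semantics of `chunk` concrete without building a TensorRT graph, here is a minimal pure-Python sketch. `chunk_shapes` and `chunk_nested` are hypothetical helpers, not part of the library. The first reproduces the shape arithmetic described in the docstring (negative `dim` normalization, the divisibility check, and `split_size = shape[dim] // chunks`); the second applies the same rule to a rectangular nested list, so you can see which elements end up in each sub-tensor. Treating `-1` in a shape as a dynamic dimension is an assumption of this sketch.

```python
from typing import Any, List, Sequence, Tuple


def chunk_shapes(shape: Sequence[int], chunks: int, dim: int = 0) -> List[Tuple[int, ...]]:
    """Hypothetical helper: shapes that chunk(tensor, chunks, dim) would yield."""
    ndim = len(shape)
    if dim < 0:  # same normalization as chunk(): -1 means the last dimension
        dim += ndim
    if not 0 <= dim < ndim:
        raise ValueError(f"dim {dim} is out of range for a {ndim}-D shape")
    dim_value = shape[dim]
    if dim_value < 0:  # assumption: -1 marks a dynamic dimension
        raise ValueError("the split dimension must be static")
    if chunks <= 0:
        raise ValueError("chunks must be positive")
    if dim_value % chunks != 0:
        raise ValueError(f"size {dim_value} is not divisible by chunks={chunks}")
    out = list(shape)
    out[dim] = dim_value // chunks  # the split_size that chunk() passes to split()
    return [tuple(out)] * chunks


def _ndim(data: Any) -> int:
    """Depth of a rectangular nested list (a scalar has depth 0)."""
    depth = 0
    while isinstance(data, list):
        depth += 1
        data = data[0] if data else None
    return depth


def _chunk_along(data: list, chunks: int, dim: int) -> List[list]:
    if dim == 0:
        n = len(data)
        if n % chunks != 0:
            raise ValueError(f"size {n} is not divisible by chunks={chunks}")
        size = n // chunks
        return [data[i * size:(i + 1) * size] for i in range(chunks)]
    # Chunk every sub-list one level down, then regroup the k-th piece of each.
    per_item = [_chunk_along(item, chunks, dim - 1) for item in data]
    return [[pieces[k] for pieces in per_item] for k in range(chunks)]


def chunk_nested(data: list, chunks: int, dim: int = 0) -> List[list]:
    """Hypothetical helper: chunk() semantics on a rectangular nested list."""
    ndim = _ndim(data)
    if dim < 0:
        dim += ndim
    if not 0 <= dim < ndim:
        raise ValueError(f"dim {dim} is out of range for {ndim}-D data")
    if chunks <= 0:
        raise ValueError("chunks must be positive")
    return _chunk_along(data, chunks, dim)


if __name__ == "__main__":
    x = [[1, 2, 3, 4],
         [5, 6, 7, 8]]
    print(chunk_shapes((2, 4), 2, dim=-1))  # [(2, 2), (2, 2)]
    print(chunk_nested(x, 2, dim=-1))       # [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
```

Chunking along an inner dimension is done by chunking each sub-list one level down and then regrouping the k-th piece of every sub-list, which mirrors how slicing a tensor along an inner axis keeps all outer indices intact. A size that is not a multiple of `chunks` raises `ValueError` here, where `chunk` itself uses an `assert`.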