Removes a tensor dimension. Returns a list of all slices along the given dimension, each with that dimension removed.
`unbind(input: Tensor, dim: int = 0)`
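`unbind` works by splitting the input into size-1 slices along `dim` and then viewing each slice with that dimension dropped. The sketch below mirrors those semantics so they can be checked on concrete data. It assumes NumPy arrays stand in for the graph `Tensor` type, with `np.split` playing the role of `split` and `reshape` the role of `view`. `unbind_reference` is a hypothetical name used only for illustration, and unlike a plain `i != dim` filter it normalizes negative dims first.

```python
import numpy as np


def unbind_reference(x: np.ndarray, dim: int = 0) -> list:
    """Hypothetical NumPy sketch of unbind: split into size-1 slices, then drop `dim`."""
    ndim = x.ndim
    if dim < 0:  # accept negative dims such as -1
        dim += ndim
    # Splitting into x.shape[dim] equal sections yields slices of size 1 along `dim`.
    slices = np.split(x, x.shape[dim], axis=dim)
    # Every output keeps all dimensions except `dim`.
    out_shape = [x.shape[i] for i in range(ndim) if i != dim]
    return [s.reshape(out_shape) for s in slices]


x = np.arange(6).reshape(2, 3)
rows = unbind_reference(x, dim=0)   # two arrays of shape (3,)
cols = unbind_reference(x, dim=-1)  # three arrays of shape (2,)
print([r.tolist() for r in rows])  # [[0, 1, 2], [3, 4, 5]]
print([c.tolist() for c in cols])  # [[0, 3], [1, 4], [2, 5]]
```

Each returned slice equals indexing the input at one position along `dim`, which is why the number of outputs always matches `x.shape[dim]`.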
```python
def unbind(input: Tensor, dim: int = 0):
    '''
    Removes a tensor dimension.

    Returns a list of all slices along the given dimension, each with that
    dimension removed.
    '''
    ndim = input.ndim()
    # Normalize a negative dim (e.g. -1) so the shape filter below matches it;
    # otherwise `i != dim` is always true and no dimension gets removed.
    if dim < 0:
        dim += ndim
    # Split the input into size-1 slices along `dim`.
    outputs = split(input, 1, dim)
    # Each slice keeps every dimension except `dim`.
    output_shape = [input.shape[i] for i in range(ndim) if i != dim]
    # View each slice without its singleton dimension.
    return [output.view(output_shape) for output in outputs]


class AllReduceStrategy(IntEnum):
```