Add an operation to remove singleton dimensions of a tensor. This function creates an operation that removes singleton dimensions (dimensions of size 1) at positions 'dim' in the input tensor. It works with negative values for 'dim'. For example, for a tensor 'input' of shape [1, 4, 1, 4], squeeze(input, 0) will produce an output of shape [4, 1, 4].
squeeze(input: Tensor,
dim: Optional[Union[int, Sequence[int]]] = None,
zero_is_placeholder: bool = False)
| 1884 | |
| 1885 | # NOTE: Jointly added with Apple |
| 1886 | def squeeze(input: Tensor, |
| 1887 | dim: Optional[Union[int, Sequence[int]]] = None, |
| 1888 | zero_is_placeholder: bool = False): |
| 1889 | ''' |
| 1890 | Add an operation to remove singleton dimensions of a tensor. |
| 1891 | |
| 1892 | This function creates an operation that removes singleton dimensions |
| 1893 | (dimensions of size 1) at positions 'dim' in the input tensor. It works with |
| 1894 | negative values for 'dim'. |
| 1895 | |
| 1896 | For example, for a tensor 'input' of shape [1, 4, 1, 4]: |
| 1897 | |
| 1898 | squeeze(input, 0) will produce an output of shape [4, 1, 4], |
| 1899 | squeeze(input, 2) will produce an output of shape [1, 4, 4], |
| 1900 | squeeze(input, [0, 2]) will produce an output of shape [4, 4], |
| 1901 | squeeze(input, [-2]) will produce an output of shape [1, 4, 4]. |
| 1902 | |
| 1903 | Parameters: |
| 1904 | input : Tensor |
| 1905 | The input tensor from which the singleton dimensions will be removed. |
| 1906 | |
| 1907 | dim : Optional[Union[int, Sequence[int]]] |
| 1908 | The indices of the singleton dimensions to remove. If None, every singleton dimension is removed. |
| 1909 | |
| 1910 | Returns: |
| 1911 | The tensor produced by the layer. |
| 1912 | ''' |
| 1913 | if dim is None: |
| 1914 | dim = list(range(input.ndim())) |
| 1915 | if isinstance(dim, int): |
| 1916 | dim = (dim, ) |
| 1917 | dim = dim_resolve_negative(dim, input.ndim()) |
| 1918 | |
| 1919 | new_shape = [] |
| 1920 | for i, s in enumerate(input.shape): |
| 1921 | if s == 1 and i in dim: |
| 1922 | continue |
| 1923 | new_shape.append(shape(input, i)) |
| 1924 | |
| 1925 | new_shape = concat(new_shape) if len(new_shape) > 0 else [] |
| 1926 | input = input.view(new_shape, zero_is_placeholder=zero_is_placeholder) |
| 1927 | return input |
| 1928 | |
| 1929 | |
| 1930 | def unsqueeze(input: Tensor, axis: int): |
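The shape rule that the squeeze listing above applies can be sketched in plain Python, without building a network. This is an illustration only: `squeezed_shape` is a hypothetical helper, not part of the library, and it works on plain lists of integers rather than on `Tensor` objects. It assumes that negative indices are resolved by adding the number of dimensions, which is what `dim_resolve_negative` appears to do in the listing.

```python
from typing import List, Optional, Sequence, Union


def squeezed_shape(shape: Sequence[int],
                   dim: Optional[Union[int, Sequence[int]]] = None) -> List[int]:
    """Hypothetical helper: compute the output shape squeeze would produce."""
    ndim = len(shape)
    if dim is None:
        # No 'dim' given: every dimension is a candidate for removal.
        dim = range(ndim)
    elif isinstance(dim, int):
        dim = (dim,)
    # Assumption: negative indices count from the end, as in Python indexing.
    resolved = {d + ndim if d < 0 else d for d in dim}
    # Drop a dimension only if it was selected AND has size 1.
    return [s for i, s in enumerate(shape) if not (s == 1 and i in resolved)]


# Reproduce the docstring examples for a shape of [1, 4, 1, 4].
for dim in (0, 2, [0, 2], [-2], None):
    print(dim, squeezed_shape([1, 4, 1, 4], dim))
```

Going by the condition `s == 1 and i in dim` in the listing, an index that points at a non-singleton dimension seems to be ignored rather than rejected; the sketch mirrors that, so `squeezed_shape([1, 4, 1, 4], 1)` leaves the shape unchanged.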