Add an operation to expand the tensor shape with singleton dimensions. This function adds a tensorrt.IShuffleLayer to the network. Given an 'input' of rank N and a sequence of M dimensions, the output tensor produced by this operation (when executed by TensorRT) will have a rank of N+M.
expand_dims(input: Tensor, dim: Union[int, Sequence[int]], shape_cast_dtype=None) -> Tensor
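The shape rule above can be checked without TensorRT. The sketch below is a plain-Python version of the docstring's pseudo-code that computes only the resulting shape. `expanded_shape` is a hypothetical helper, not part of the library, and it assumes the entries of `dim` are distinct, non-negative, and smaller than N+M (the listing does not normalize negative indices either).

```python
from typing import List, Sequence, Union


def expanded_shape(in_shape: Sequence[int],
                   dim: Union[int, Sequence[int]]) -> List[int]:
    """Hypothetical helper: shape after inserting singleton dims at `dim`.

    `dim` indexes positions in the *output* shape, as in expand_dims.
    Assumes distinct, non-negative entries smaller than len(in_shape) + len(dim).
    """
    if isinstance(dim, int):
        dim = (dim, )
    out_rank = len(in_shape) + len(dim)
    new_shape, ii = [], 0
    for jj in range(out_rank):
        if jj in dim:
            new_shape.append(1)              # inserted singleton dimension
        else:
            new_shape.append(in_shape[ii])   # next extent of the input, in order
            ii += 1
    return new_shape


print(expanded_shape([3, 4, 1, 5], [0, 2]))  # [1, 3, 1, 4, 1, 5]
```

Because `dim` refers to output positions, the input extents keep their relative order and only shift right past each inserted 1.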
```python
def expand_dims(input: Tensor,
                dim: Union[int, Sequence[int]],
                shape_cast_dtype=None) -> Tensor:
    '''
    Add an operation to expand the tensor shape with singleton dimensions.

    This function adds a tensorrt.IShuffleLayer to the network. Given an 'input'
    of rank N and a sequence of M dimensions, the output tensor produced by
    this operation (when executed by TensorRT) will have a rank of N+M. Singleton
    dimensions will be inserted at the positions listed in 'dim'.

    The pseudo-code for that operation is:

        new_shape, ii = [], 0
        for jj in range(input.ndim() + len(dim)):
            if jj in dim:
                new_shape.append(1)
            else:
                new_shape.append(input.shape[ii])
                ii += 1

    For example, for a tensor of shape [3, 4, 1, 5]

        expand_dims(input, [0, 2])

    will produce a tensor of shape [1, 3, 1, 4, 1, 5].

    Parameters:
        input : Tensor
            The input tensor to expand.

        dim : Union[int, Sequence[int]]
            The positions in the output tensor at which to insert singleton
            dimensions.

        shape_cast_dtype : optional
            If set, the shape of 'input' is cast to this dtype (through
            shape(input, cast_to_dtype=...)) before the output shape is built.

    Returns:
        The tensor produced by the shuffle layer.
    '''
    # Normalize a single int to a one-element tuple.
    if isinstance(dim, int):
        dim = (dim, )

    out_ndim = len(dim) + input.ndim()

    # Build the output shape: 1 at every position in `dim`, otherwise the next
    # extent of the input, read from the runtime shape tensor.
    input_shape = shape(input, cast_to_dtype=shape_cast_dtype)
    out_shapes = []
    j = 0
    for i in range(out_ndim):
        if i in dim:
            out_shapes.append(1)
        else:
            out_shapes.append(gather(input_shape, 0, j))
            j = j + 1

    out_shape = concat(out_shapes)

    # A 0 in out_shape is a literal zero-sized dimension, not a placeholder.
    return view(input, out_shape, zero_is_placeholder=False)


# NOTE: Jointly added with Apple
```
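To make the layer construction in the listing concrete, here is a NumPy analogue of the same strategy: read the input shape, pick each kept extent by index, splice in literal 1s, concatenate the pieces, and reshape. This is an illustrative sketch that assumes NumPy arrays stand in for TensorRT tensors; `expand_dims_like_listing` is a hypothetical name, the comments mark which TensorRT-LLM call each step imitates, and it is not how the library actually evaluates the op.

```python
import numpy as np


def expand_dims_like_listing(x: np.ndarray, dim) -> np.ndarray:
    """Hypothetical NumPy mirror of the expand_dims listing (illustration only)."""
    if isinstance(dim, int):
        dim = (dim, )
    out_ndim = len(dim) + x.ndim

    # Stands in for shape(input): the input shape as a 1-D integer array.
    input_shape = np.array(x.shape, dtype=np.int64)

    pieces = []
    j = 0
    for i in range(out_ndim):
        if i in dim:
            pieces.append(np.array([1], dtype=np.int64))  # literal singleton
        else:
            pieces.append(input_shape[j:j + 1])  # stands in for gather(input_shape, 0, j)
            j += 1

    out_shape = np.concatenate(pieces)  # stands in for concat(out_shapes)
    return x.reshape(tuple(out_shape))  # stands in for view(input, out_shape, ...)


print(expand_dims_like_listing(np.zeros((3, 4, 1, 5)), [0, 2]).shape)  # (1, 3, 1, 4, 1, 5)
```

Building the target shape from the shape tensor, rather than from static Python ints, is what lets the real function follow dynamic dimensions. NumPy's `reshape` treats 0 as a literal extent, which corresponds to `zero_is_placeholder=False` in the listing (assuming `view` forwards that flag to the shuffle layer): a zero-sized input dimension stays 0 instead of being read as "copy the input dimension at this position".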