hub / github.com/NVIDIA/TensorRT-LLM / expand_dims

Function expand_dims

tensorrt_llm/functional.py:1831–1882

Add an operation to expand the tensor shape with singleton dimensions. This function adds a tensorrt.IShuffleLayer to the network. Given an 'input' of rank N and a sequence of M dimensions, the output tensor produced by this operation (when executed by TensorRT) will have a rank of N+M. Singleton dimensions are inserted at the positions listed in 'dim'.

expand_dims(input: Tensor,
            dim: Union[int, Sequence[int]],
            shape_cast_dtype=None) -> Tensor
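As a rough illustration of the shape rule, the output shape can be modelled in plain Python without TensorRT. The helper `expanded_shape` below is a hypothetical name used only for this sketch; it works on Python lists of ints rather than on network tensors and assumes non-negative positions.

```python
from typing import List, Sequence, Union


def expanded_shape(shape: Sequence[int],
                   dim: Union[int, Sequence[int]]) -> List[int]:
    """Model of the output shape only (hypothetical helper, not library API)."""
    if isinstance(dim, int):
        dim = (dim, )  # a single position is treated as a 1-element sequence
    out, j = [], 0
    for i in range(len(shape) + len(dim)):
        if i in dim:
            out.append(1)         # singleton inserted at output position i
        else:
            out.append(shape[j])  # next input dimension, in order
            j += 1
    return out


print(expanded_shape([3, 4, 1, 5], [0, 2]))  # [1, 3, 1, 4, 1, 5]
print(expanded_shape([3, 4], 1))             # [3, 1, 4]
```

Note that the positions refer to indices in the output shape, which is why `[0, 2]` places the second singleton between the 3 and the 4.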

Source from the content-addressed store, hash-verified

def expand_dims(input: Tensor,
                dim: Union[int, Sequence[int]],
                shape_cast_dtype=None) -> Tensor:
    '''
    Add an operation to expand the tensor shape with singleton dimensions.

    This function adds a tensorrt.IShuffleLayer to the network. Given an 'input'
    of rank N and a sequence of M dimensions, the output tensor produced by
    this operation (when executed by TensorRT) will have a rank of N+M. Singleton
    dimensions will be inserted at the positions listed in 'dim'.

    The pseudo-code for that operation is:

        new_shape, ii = [], 0
        for jj in range(input.rank() + len(dim)):
            if jj in dim:
                new_shape.append(1)
            else:
                new_shape.append(input.shape[ii])
                ii += 1

    For example, for a tensor of shape [3, 4, 1, 5]

        expand_dims(input, [0, 2])

    will produce a tensor of shape [1, 3, 1, 4, 1, 5].

    Parameters:
        input : Tensor
            The input tensor to expand.

        dim : Union[int, Sequence[int]]
            The positions in the output tensor where to insert singleton
            dimensions.

        shape_cast_dtype
            Forwarded to shape() as cast_to_dtype when reading the input shape.

    Returns:
        The tensor produced by the shuffle layer.
    '''
    # A single position is handled as a one-element sequence.
    if isinstance(dim, int):
        dim = (dim, )

    out_ndim = len(dim) + input.ndim()

    input_shape = shape(input, cast_to_dtype=shape_cast_dtype)
    out_shapes = []
    j = 0  # index of the next input dimension to copy
    for i in range(out_ndim):
        if i in dim:
            out_shapes.append(1)
        else:
            out_shapes.append(gather(input_shape, 0, j))
            j = j + 1

    out_shape = concat(out_shapes)

    return view(input, out_shape, zero_is_placeholder=False)
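One detail of the loop in the source is worth a small sketch: each position in `dim` is compared directly against the output index `i`, so a negative position never matches any index. The list-based model below uses two hypothetical helpers, `model_expand` and `normalize_positions`, that are not part of the library. It shows the effect in plain Python and one possible way to normalize positions before the call, assuming the usual convention that a negative position counts from the end of the output shape. How TensorRT itself reports the resulting out-of-range gather is not covered here.

```python
from typing import List, Sequence


def model_expand(shape: Sequence[int], dim: Sequence[int]) -> List[int]:
    """List-based model of the loop in expand_dims (illustration only)."""
    out, j = [], 0
    for i in range(len(shape) + len(dim)):
        if i in dim:
            out.append(1)
        else:
            out.append(shape[j])  # IndexError if too few positions matched
            j += 1
    return out


def normalize_positions(dim: Sequence[int], out_ndim: int) -> List[int]:
    """Map negative positions to non-negative ones (assumed convention)."""
    return [d + out_ndim if d < 0 else d for d in dim]


shape = [3, 4]
try:
    model_expand(shape, [-1])
    negative_ok = True
except IndexError:
    negative_ok = False  # -1 never equals any i in range(3)

# Normalizing against the output rank (input rank + number of positions).
fixed = model_expand(shape, normalize_positions([-1], len(shape) + 1))
print(negative_ok, fixed)  # False [3, 4, 1]
```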

Callers 15

_get_draft_token_indices · Function · 0.90
_get_draft_token_array · Function · 0.90
_get_mask · Function · 0.90
_beam_search_candidates · Function · 0.90
_batch_index_select · Function · 0.90
unsqueeze · Function · 0.85
expand_dims_like · Function · 0.85
cumsum · Function · 0.85
embedding · Function · 0.85
gegelu · Function · 0.85

Calls 6

concat · Function · 0.85
view · Function · 0.85
shape · Function · 0.70
gather · Function · 0.70
ndim · Method · 0.45
append · Method · 0.45

Tested by

no test coverage detected