
Function select

Defined in tensorrt_llm/functional.py, lines 2160–2213 (github.com/NVIDIA/TensorRT-LLM).

Add an operation to select a slice of elements from a tensor. Given an input tensor, this function creates an operation that selects the index-th slice of elements along dimension 'dim' to create a new tensor. The output tensor has the shape of the input with dimension 'dim' removed.
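The semantics in the summary can be checked with a small NumPy model. This is a sketch that assumes NumPy as a stand-in for the TensorRT graph; `select_reference` is a hypothetical helper, not part of TensorRT-LLM. It reproduces both examples from the docstring below.

```python
import numpy as np


def select_reference(x: np.ndarray, dim: int, index: int) -> np.ndarray:
    """Hypothetical NumPy reference for the semantics of `select`.

    Picks the index-th slice along `dim`; the selected dimension is dropped
    from the result, so the output rank is one less than the input rank.
    """
    # np.take with a scalar index removes the indexed axis from the result.
    return np.take(x, index, axis=dim)


# First docstring example: the row at index 1 of a 3x3 matrix.
x = np.array([[4, 2, 5], [2, 1, 2], [4, 7, 1]])
row = select_reference(x, 0, 1)
print(row.shape, row.tolist())  # (3,) [2, 1, 2]

# Second docstring example: only the resulting shape matters here.
y = np.zeros((4, 2, 6, 3))
print(select_reference(y, 2, 4).shape)  # (4, 2, 3)
```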

Signature: select(input: Tensor, dim: int, index: Union[Tensor, int]) -> Tensor
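The `Union[Tensor, int]` type of `index` means an integer is promoted to a one-element tensor before it is used; the listing does this with `constant(int32_array([index]))` and then asserts that the index has rank 1 and exactly one element. The sketch below mirrors that normalization with NumPy arrays in place of TensorRT-LLM tensors; `normalize_index` is a hypothetical helper.

```python
from typing import Union

import numpy as np


def normalize_index(index: Union[np.ndarray, int]) -> np.ndarray:
    """Hypothetical helper mirroring the index handling in `select`.

    A Python int is wrapped into a 1-D int32 array with one element; an
    array must already be 1-D with exactly one element.
    """
    if isinstance(index, int):
        index = np.array([index], dtype=np.int32)
    if index.ndim != 1 or index.shape[0] != 1:
        raise ValueError(
            f"index should be 1-D with a single element, got shape {index.shape}")
    return index


print(normalize_index(4))                    # [4]
print(normalize_index(np.array([2])).shape)  # (1,)
```

The sketch raises `ValueError` instead of using `assert`, so the check still runs when Python is started with `-O`, and its message names both conditions being checked.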


def select(input: Tensor, dim: int, index: Union[Tensor, int]) -> Tensor:
    '''
    Add an operation to select a slice of elements from a tensor.

    Given an input tensor, that function creates an operation that selects the
    index-th slice of elements in the dimension 'dim' to create a new tensor.
    The output tensor has a shape in which the input dimension 'dim' is
    removed.

    The 'index' can either be an integer or a 1D tensor containing a single
    element.

    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape
    [3, 3],

        select(input, 0, 1)

    will create a tensor of shape [3] that contains the [2, 1, 2].

    Regarding the shape of the output tensor, the dimension 'dim' is removed.
    It means that for a tensor of shape [4, 2, 6, 3],

        select(input, 2, 4)

    will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)
    and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).

    That operation maps to the TensorRT IGatherLayer.

    Parameters:
        input : Tensor
            The input tensor to select from.

        dim : int
            The dimension to select from.

        index : Union[Tensor, int]
            The index of the slice in the 'dim' dimension to select.

    Returns:
        The tensor containing the selected slice.
    '''
    if isinstance(index, int):
        index = constant(int32_array([index]))
    assert index.rank() == 1 and index.size(
        0) == 1, f"index should have rank 1, got {index.rank()}"

    new_shape = []
    for i in range(input.rank()):
        if i != dim:
            new_shape.append(shape(input, i))

    layer = default_trtnet().add_gather(input.trt_tensor, index.trt_tensor, dim)
    return _create_tensor(layer.get_output(0), layer).view(concat(new_shape))
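The listing does not drop the axis in the gather itself. With a one-element 1-D index, a default-mode IGatherLayer keeps `dim` in its output with size 1 (per the TensorRT documentation, the output shape is the data shape with the `axis` dimension replaced by the index shape), which is why the function also builds `new_shape` from every dimension except `dim` and applies it with `view`. The following NumPy sketch walks through the same three steps; `select_via_gather` is a hypothetical helper, and `np.take`/`reshape` stand in for `add_gather`/`view`.

```python
import numpy as np


def select_via_gather(x: np.ndarray, dim: int, index: int) -> np.ndarray:
    """Hypothetical NumPy sketch of the gather-then-view steps in `select`."""
    # Step 1: gather with a one-element 1-D index. As with a default-mode
    # gather, `dim` is kept in the output, now with size 1.
    gathered = np.take(x, np.array([index], dtype=np.int32), axis=dim)
    assert gathered.shape[dim] == 1

    # Step 2: build the target shape from every dimension except `dim`,
    # mirroring the new_shape loop in the listing.
    new_shape = [x.shape[i] for i in range(x.ndim) if i != dim]

    # Step 3: reshape to drop the size-1 axis (the listing uses
    # .view(concat(new_shape)) for this).
    return gathered.reshape(new_shape)


y = np.arange(4 * 2 * 6 * 3).reshape(4, 2, 6, 3)
out = select_via_gather(y, 2, 4)
print(out.shape)                           # (4, 2, 3)
print(np.array_equal(out, y[:, :, 4, :]))  # True
```

Like the listing, the sketch compares `i != dim` directly, so it expects a non-negative `dim`; with a negative value every dimension would stay in `new_shape` and the reshape in the sketch would fail.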

Callers (5)

_validate_draft_tokens · Function · 0.90
select · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85

Calls (10)

constant · Function · 0.85
default_trtnet · Function · 0.85
_create_tensor · Function · 0.85
concat · Function · 0.85
shape · Function · 0.70
rank · Method · 0.45
size · Method · 0.45
append · Method · 0.45
view · Method · 0.45
get_output · Method · 0.45

Tested by

No test coverage detected.