Function index_select

github.com/NVIDIA/TensorRT-LLM, tensorrt_llm/functional.py:2216–2275

Add an operation to select slices of elements from a tensor. Given an input tensor, this function creates an operation that selects the slices of elements in dimension 'dim' at the indices listed in 'index' to create a new tensor. The output tensor has the same rank as the input tensor.

(input: Tensor, dim: int, index: Tensor)
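A minimal pure-Python sketch of these semantics follows. It assumes nested lists stand in for tensors and that `dim` is non-negative. The helper `index_select_ref` is hypothetical and is not a TensorRT-LLM API. It builds no TensorRT network and only illustrates which values the output should hold.

```python
import copy
from typing import Any, Sequence


def index_select_ref(x: Any, dim: int, index: Sequence[int]) -> Any:
    """Reference semantics of index_select on nested Python lists.

    Hypothetical helper for illustration; not part of TensorRT-LLM.
    `x` is a nested list with a uniform shape, `dim` is a non-negative
    axis, and `index` is a flat (rank-1) sequence of integer indices.
    """
    if dim == 0:
        # At the target axis: pick whole slices in the order given by
        # `index` (repeats allowed), copying so the result does not alias `x`.
        return [copy.deepcopy(x[i]) for i in index]
    # Above the target axis: keep this level and recurse into each slice.
    return [index_select_ref(sub, dim - 1, index) for sub in x]


# Docstring example: pick rows 0 and 1 of a [3, 3] input -> shape [2, 3].
rows = index_select_ref([[4, 2, 5], [2, 1, 2], [4, 7, 1]], 0, [0, 1])
# rows == [[4, 2, 5], [2, 1, 2]]
```

The comprehension visits `index` in order, so repeated entries copy the same slice several times. This is how the docstring's expansion example works: `index_select_ref([[0, 1], [2, 3]], 1, [0, 0, 0])` gives `[[0, 0, 0], [2, 2, 2]]`. The sketch does not validate indices, and Python's negative-index wraparound applies.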


```python
def index_select(input: Tensor, dim: int, index: Tensor) -> Tensor:
    '''
    Add an operation to select slices of elements from a tensor.

    Given an input tensor, this function creates an operation that selects the
    slices of elements in dimension 'dim' at the indices listed in 'index'
    to create a new tensor. The output tensor has the same rank as the input
    tensor.

    The 'index' is a tensor of rank 1.

    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has shape
    [3, 3],

        index_select(input, 0, [0, 1])

    will create a tensor of shape [2, 3] that contains [[4, 2, 5], [2, 1, 2]].

    Regarding the shape of the output tensor, dimension 'dim' has the same
    size as the 'index' tensor. For example, for an input tensor of shape
    [4, 2, 6, 3],

        index_select(input, 2, [1, 4])

    will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension
    (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd
    dimension is shrunk to 2).

    Note that this operation can also be used to expand a tensor in the 'dim'
    dimension. For example, on input [[0, 1], [2, 3]],

        index_select(input, 1, [0, 0, 0])

    will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].

    This operation maps to the TensorRT IGatherLayer.

    Parameters:
        input : Tensor
            The input tensor to select from.

        dim : int
            The dimension to select from.

        index : Tensor
            The indices of the slices in the 'dim' dimension to select.

    Returns:
        The tensor containing the selected slices.
    '''
    assert index.rank() == 1, f"index should have rank 1, got {index.rank()}"

    # Output shape: same as the input, except that dimension `dim` takes
    # the length of `index`.
    new_shape = []
    for i in range(input.rank()):
        if i != dim:
            new_shape.append(shape(input, i))
        else:
            new_shape.append(shape(index, 0))

    # ... (remainder of the function is not shown in this excerpt)
```
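The loop at the end of the excerpt encodes the output-shape rule. Every dimension is copied from `input` except `dim`, which takes the length of `index`. The sketch below restates that rule with plain Python integers standing in for whatever `shape()` returns. `index_select_shape` is a hypothetical helper. As one deliberate assumption, it maps a negative `dim` to a non-negative axis first: the excerpt compares `i` with `dim` directly and does not normalize a negative value before the loop.

```python
from typing import List, Sequence


def index_select_shape(input_shape: Sequence[int], dim: int, index_len: int) -> List[int]:
    """Output shape of index_select: the input shape with `dim` replaced.

    Hypothetical helper mirroring the shape loop in the excerpt, using
    plain integers instead of the library's shape values.
    """
    rank = len(input_shape)
    if dim < 0:
        # Assumption: accept Python-style negative axes; the excerpt does not
        # normalize `dim` before its shape loop.
        dim += rank
    if not 0 <= dim < rank:
        raise ValueError(f"dim {dim} out of range for rank {rank}")
    # Copy every dimension except `dim`, which becomes the length of `index`.
    return [index_len if i == dim else input_shape[i] for i in range(rank)]


# Docstring example: [4, 2, 6, 3] with index [1, 4] along dim 2.
out_shape = index_select_shape([4, 2, 6, 3], 2, 2)
# out_shape == [4, 2, 2, 3]
```

Because `index_len` may exceed the input's size along `dim`, the same rule covers both shrinking (the [4, 2, 6, 3] case) and expanding (the [2, 2] to [2, 3] case) from the docstring.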

Callers 9

_validate_draft_tokens (function) · 0.90
_unpack_gen_data (function) · 0.90
pad (function) · 0.85
gather_last_token_logits (function) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
_slice_hidden_states (method) · 0.85

Calls 8

default_trtnet (function) · 0.85
_create_tensor (function) · 0.85
concat (function) · 0.85
shape (function) · 0.70
rank (method) · 0.45
append (method) · 0.45
view (method) · 0.45
get_output (method) · 0.45

Tested by

no test coverage detected