hub / github.com/NVIDIA/TensorRT-LLM / topk

Function topk

tensorrt_llm/functional.py:7320–7416

Add a topk operation, as described in the ONNX documentation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk). Note: unlike the ONNX topk op, the TensorRT layer always returns sorted output. Retrieves the top-K largest elements (or the smallest, when largest is False) along a specified axis and returns both their values and their indices.

(input: Tensor,
         k: Union[Tensor, int],
         dim: int,
         largest: bool = True,
         prefer_plugin: bool = True)
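
To make the contract of the signature concrete, here is a minimal pure-Python sketch of the same semantics for a 2-D nested list along its last axis. It is not the TensorRT-LLM implementation: `topk_last_dim` is a hypothetical name, it works on plain lists rather than `Tensor` objects, and its tie-breaking (lower index first) is an assumption of this sketch rather than documented behavior of the library.

```python
from typing import List, Tuple


def topk_last_dim(rows: List[List[float]],
                  k: int,
                  largest: bool = True) -> Tuple[List[List[float]], List[List[int]]]:
    """Illustrative top-k along the last axis of a 2-D nested list.

    Hypothetical helper, not part of TensorRT-LLM. Returns (values, indices),
    each with the last axis reduced to length k and sorted by value.
    """
    values: List[List[float]] = []
    indices: List[List[int]] = []
    for row in rows:
        if not 0 < k <= len(row):
            raise ValueError("k must satisfy 1 <= k <= len(row)")
        # Rank positions by value. sorted() is stable even with reverse=True,
        # so equal values keep the lower index first (an assumption here).
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)[:k]
        values.append([row[i] for i in order])
        indices.append(order)
    return values, indices


logits = [[0.1, 2.5, -1.0, 0.7],
          [3.0, 3.5, 0.2, 9.1]]

top_vals, top_idx = topk_last_dim(logits, k=2)
low_vals, low_idx = topk_last_dim(logits, k=2, largest=False)
```

With `largest=True` the first row yields values `[2.5, 0.7]` at indices `[1, 3]`; with `largest=False` it yields `[-1.0, 0.1]` at indices `[2, 0]`. The indices always refer to positions in the original input, which is what callers need to map scores back to tokens or experts.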

Source from the content-addressed store, hash-verified

def topk(input: Tensor,
         k: Union[Tensor, int],
         dim: int,
         largest: bool = True,
         prefer_plugin: bool = True) -> Tuple[Tensor, Tensor]:
    '''
    Add an topk operation.

    As explained in the ONNX documentation,

        https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk

    NOTE: One distinction from the ONNX topk op, the output is always sorted
    with TensorRT layer.

    Retrieve the top-K largest elements along a specified axis.
    Given an input tensor of shape [a_1, a_2, ..., a_n, r]
    and integer argument k, return two outputs:
    Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the values of the top k elements along the specified axis
    Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the indices of the top k elements (original indices from the input tensor).

    Parameters:
        input : Tensor
            The input tensor.

        k : int
            A single positive value corresponding to the number of top elements to retrieve

        dim: int
            The dimension in which to compute the topk indices.

        largest: bool
            Controls whether to return largest or smallest elements

        prefer_plugin : bool
            Whether to use the topkLastDim plugin if dim is last dim and k is static.


    Returns:
        The tensors (values, indices) produced by this topk operation.
    '''
    dim = dim_resolve_negative(dim, input.ndim())[0]
    if prefer_plugin and dim == input.ndim() - 1 and not isinstance(k, Tensor):
        last_dim = input.size(-1)
        if last_dim == -1:  # dynamic?
            last_dim = shape(input, -1)
        # since we might need to flatten the input to 2d tensor,
        # we need to prepare the output shape
        out_shape = []
        for i in range(input.ndim() - 1):
            out_shape.append(shape(input, i))
        out_shape = concat(out_shape + [k])
        if input.ndim() == 1:
            input_2d = unsqueeze(input,
                                 0)  # special handling of rank-1 dynamic tensor
        elif input.ndim() != 2:
            input_2d = input.view(concat([-1, last_dim]),
                                  zero_is_placeholder=False)
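
The visible part of the body resolves a negative `dim`, checks whether the plugin path applies, and prepares shapes for flattening the input to 2-D. The sketch below mirrors that shape bookkeeping with plain Python integers so it can be followed without a TensorRT network. All three function names are hypothetical, static shapes are assumed throughout, and treating a rank-2 input as passed through unchanged is an inference from the `elif input.ndim() != 2` branch, since the listing is cut off before that case.

```python
from math import prod
from typing import List, Tuple


def resolve_negative_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative axis to its non-negative equivalent."""
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} out of range for rank {ndim}")
    return dim + ndim if dim < 0 else dim


def takes_last_dim_path(ndim: int, dim: int, k_is_static: bool,
                        prefer_plugin: bool = True) -> bool:
    """Condition from the listing: plugin preferred, last axis, static k."""
    return prefer_plugin and resolve_negative_dim(dim, ndim) == ndim - 1 and k_is_static


def plan_flatten(in_shape: List[int], k: int) -> Tuple[List[int], List[int]]:
    """Return (shape_2d, out_shape) for top-k over the last axis.

    shape_2d is what the input is reshaped to; out_shape is the shape the
    results are restored to (leading dims unchanged, last axis replaced by k).
    """
    last_dim = in_shape[-1]
    out_shape = in_shape[:-1] + [k]
    if len(in_shape) == 1:
        shape_2d = [1, last_dim]                    # like unsqueeze(input, 0)
    elif len(in_shape) != 2:
        shape_2d = [prod(in_shape[:-1]), last_dim]  # like view([-1, last_dim])
    else:
        shape_2d = list(in_shape)                   # assumed: already 2-D
    return shape_2d, out_shape


plan_rank3 = plan_flatten([2, 3, 5], k=2)
plan_rank1 = plan_flatten([5], k=2)
```

For an input of shape `[2, 3, 5]` and `k=2`, the plan is a 2-D view of shape `[6, 5]` and a final output shape of `[2, 3, 2]`. Collapsing all leading dimensions into one lets a kernel that only understands (rows, columns) serve inputs of any rank.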

Callers 9

_validate_draft_tokens · Function · 0.90
warp_logits · Function · 0.90
_beam_search_candidates · Function · 0.90
_top_1_logits · Function · 0.90
default_routing · Method · 0.50
renormalize · Method · 0.50
group_limited_greedy · Method · 0.50
mask_and_softmax · Method · 0.50
forward · Method · 0.50

Calls 15

dim_resolve_negative · Function · 0.85
concat · Function · 0.85
unsqueeze · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
dim_to_trt_axes · Function · 0.85
squeeze · Function · 0.85
create_plugin · Method · 0.80
shape · Function · 0.70
ndim · Method · 0.45
size · Method · 0.45

Tested by

no test coverage detected