Add a topk operation, as explained in the ONNX documentation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk). NOTE: Unlike the ONNX TopK op, the outputs of the TensorRT layer are always sorted. Retrieve the top-K largest (or smallest) elements along a specified dimension.
(input: Tensor,
k: Union[Tensor, int],
dim: int,
largest: bool = True,
prefer_plugin: bool = True)
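To make the output contract in the summary concrete, here is a minimal pure-Python sketch of the semantics. `topk_output_shape` and `topk_1d` are hypothetical helpers, not part of the library. The sketch assumes ties are broken in favor of the lower original index; the TensorRT layer's tie-breaking behavior is not stated here.

```python
from typing import List, Sequence, Tuple


def topk_output_shape(in_shape: Sequence[int], k: int, dim: int) -> List[int]:
    """Shape shared by both topk outputs: the input shape with `dim` replaced by k."""
    ndim = len(in_shape)
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} out of range for rank {ndim}")
    dim %= ndim  # resolve a negative dim, e.g. -1 -> ndim - 1
    # This sketch rejects k outside [1, size of dim].
    if not 0 < k <= in_shape[dim]:
        raise ValueError(f"k must be in [1, {in_shape[dim]}], got {k}")
    return list(in_shape[:dim]) + [k] + list(in_shape[dim + 1:])


def topk_1d(row: Sequence[float], k: int,
            largest: bool = True) -> Tuple[List[float], List[int]]:
    """Reference top-k over one row; values come back sorted.

    Assumed tie-breaking: Python's sort is stable even with reverse=True,
    so equal values keep the lower original index first.
    """
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)
    indices = order[:k]  # original positions of the selected elements
    values = [row[i] for i in indices]
    return values, indices
```

For example, `topk_output_shape([2, 3, 5], 2, 1)` gives `[2, 2, 5]`, and `topk_1d([0.5, 2.0, -1.0, 2.0], 2, largest=False)` gives `([-1.0, 0.5], [2, 0])`.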
```python
def topk(input: Tensor,
         k: Union[Tensor, int],
         dim: int,
         largest: bool = True,
         prefer_plugin: bool = True) -> Tuple[Tensor, Tensor]:
    '''
    Add a topk operation.

    As explained in the ONNX documentation,

        https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk

    NOTE: Unlike the ONNX TopK op, the outputs of the TensorRT layer are
    always sorted.

    Retrieve the top-K largest (or, when largest is False, smallest)
    elements along a specified dimension. Given an input tensor of shape
    [a_0, a_1, ..., a_{n-1}] and an integer k, return two outputs:

        values: a tensor of shape [a_0, ..., a_{dim-1}, k, a_{dim+1}, ..., a_{n-1}]
            containing the top k elements along dimension dim.
        indices: a tensor of the same shape containing the indices of those
            elements along dimension dim (original indices from the input tensor).

    Parameters:
        input : Tensor
            The input tensor.

        k : Union[Tensor, int]
            A single positive value: the number of top elements to retrieve.

        dim : int
            The dimension along which to compute the top-k elements. Negative
            values count from the last dimension.

        largest : bool
            If True, return the k largest elements; otherwise, return the k
            smallest.

        prefer_plugin : bool
            Whether to use the topkLastDim plugin when dim is the last
            dimension and k is static (a Python int, not a Tensor).

    Returns:
        The tensors (values, indices) produced by this topk operation.
    '''
    dim = dim_resolve_negative(dim, input.ndim())[0]
    if prefer_plugin and dim == input.ndim() - 1 and not isinstance(k, Tensor):
        last_dim = input.size(-1)
        if last_dim == -1:  # the last dimension is dynamic: read its size at runtime
            last_dim = shape(input, -1)
        # The input may need to be flattened to a 2D tensor, so prepare the
        # output shape up front: every leading dimension followed by k.
        out_shape = []
        for i in range(input.ndim() - 1):
            out_shape.append(shape(input, i))
        out_shape = concat(out_shape + [k])
        if input.ndim() == 1:
            # special handling of a rank-1 (possibly dynamic) tensor: add a leading dim
            input_2d = unsqueeze(input, 0)
        elif input.ndim() != 2:
            # collapse all leading dimensions into one: [-1, last_dim]
            input_2d = input.view(concat([-1, last_dim]),
                                  zero_is_placeholder=False)
```
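The plugin branch above reshapes the input so that the top-k is taken along the last dimension of a 2D tensor. The following pure-Python sketch mimics that strategy on a flat row-major buffer: it collapses all leading dimensions into rows (a rank-1 input becomes a single row), takes the sorted top k of each row, and reports the output shape as the leading dimensions followed by k. It is an illustration under those assumptions, not the topkLastDim plugin itself; `topk_last_dim_flat` is a hypothetical name, and breaking ties by the lower index is an assumption of the sketch.

```python
from math import prod
from typing import List, Sequence, Tuple


def topk_last_dim_flat(buf: Sequence[float], in_shape: Sequence[int], k: int,
                       largest: bool = True) -> Tuple[List[float], List[int], List[int]]:
    """Top-k along the last dim of a row-major tensor stored as a flat list.

    Returns (values, indices, out_shape); values and indices are flat
    row-major buffers of shape out_shape.
    """
    if len(in_shape) == 0:
        raise ValueError("topk needs a tensor of rank >= 1")
    last_dim = in_shape[-1]
    # Rank-1 input behaves like shape [1, last_dim]; higher ranks collapse
    # every leading dim into rows, like a view to [-1, last_dim].
    rows = prod(in_shape[:-1])  # prod of an empty slice is 1
    if len(buf) != rows * last_dim:
        raise ValueError("buffer size does not match shape")
    if not 0 < k <= last_dim:
        raise ValueError(f"k must be in [1, {last_dim}], got {k}")
    values: List[float] = []
    indices: List[int] = []
    for r in range(rows):
        row = buf[r * last_dim:(r + 1) * last_dim]
        # Sorted output; the stable sort keeps the lower index first on ties.
        order = sorted(range(last_dim), key=lambda i: row[i], reverse=largest)[:k]
        values.extend(row[i] for i in order)
        indices.extend(order)  # positions within the row, i.e. along the last dim
    out_shape = list(in_shape[:-1]) + [k]  # leading dims, then k
    return values, indices, out_shape
```

For an input of shape `[2, 1, 3]` holding `[1, 5, 3, 4, 2, 6]` and `k=2`, the sketch returns values `[5, 3, 6, 4]`, indices `[1, 2, 2, 0]`, and output shape `[2, 1, 2]`, matching the leading-dims-then-k shape the code builds in `out_shape`.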