hub / github.com/NVIDIA/TensorRT-LLM / argmax

Function argmax

tensorrt_llm/functional.py:3303–3347

Add an argmax operation. As explained in the ONNX documentation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax), this function creates a layer that computes the indices of the maximum elements of the input tensor along the provided dim. The resulting tensor has the same rank as the input if keepdim is True; if keepdim is False, the reduced dimension is pruned.

(input: Tensor, dim: int, keepdim: bool = False)
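To make the roles of dim and keepdim concrete, here is a small pure-Python sketch of the same semantics on a 2-D nested list. The helper name argmax_2d is hypothetical and is not part of TensorRT-LLM; it only illustrates which indices come back and how the result shape changes, and it makes no claim about how the real layer breaks ties.

```python
from typing import List


def argmax_2d(rows: List[List[float]], dim: int, keepdim: bool = False):
    """Reference argmax for a 2-D nested list (illustrative only)."""
    n_rows, n_cols = len(rows), len(rows[0])
    if dim < 0:
        dim += 2  # resolve a negative dim against rank 2
    if dim == 0:
        # For each column, the row index holding the largest value.
        idx = [max(range(n_rows), key=lambda r: rows[r][c])
               for c in range(n_cols)]
        # keepdim=True keeps a size-1 axis 0: shape (1, n_cols).
        return [idx] if keepdim else idx
    if dim == 1:
        # For each row, the column index holding the largest value.
        idx = [max(range(n_cols), key=lambda c: row[c]) for row in rows]
        # keepdim=True keeps a size-1 axis 1: shape (n_rows, 1).
        return [[i] for i in idx] if keepdim else idx
    raise ValueError("dim must be one of -2, -1, 0, 1 for a 2-D input")


x = [[1.0, 5.0, 2.0],
     [7.0, 0.0, 3.0]]
per_row = argmax_2d(x, dim=1)                 # one index per row
per_row_kept = argmax_2d(x, dim=1, keepdim=True)
per_col = argmax_2d(x, dim=0)                 # one index per column
```

With keepdim=False the result has one fewer axis than the input; with keepdim=True the reduced axis stays in place with length 1.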

Source from the content-addressed store, hash-verified

def argmax(input: Tensor, dim: int, keepdim: bool = False) -> Tensor:
    '''
    Add an argmax operation.

    As explained in the ONNX documentation,

        https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax

    this function creates a layer that computes the indices of the maximum
    elements of the input tensor along the provided dim. The resulting tensor
    has the same rank as the input if keepdim is True. If keepdim is False,
    the reduced dimension is pruned from the resulting tensor.

    Parameters:
        input : Tensor
            The input tensor.

        dim : int
            The dimension along which to compute the argmax indices.

        keepdim : bool
            Whether to keep the dimension along which the reduction is
            performed: True keeps it, False removes it.

    Returns:
        The tensor produced by this argmax operation.
    '''
    # Resolve a negative dim against the input rank; the result is iterable
    # (see the loop below) and is then converted to TensorRT axes.
    dim = dim_resolve_negative(dim, input.ndim())
    axes = dim_to_trt_axes(dim)

    # Argmax is expressed as a TopK layer with k=1 and the MAX operation.
    layer = default_trtnet().add_topk(input.trt_tensor, trt.TopKOperation.MAX,
                                      1, axes)
    # Output 0 of the TopK layer holds the values, output 1 the indices.
    output = layer.get_output(1)

    if keepdim:
        # With k=1 the reduced dimension is already kept with size 1.
        return _create_tensor(output, layer)

    # Otherwise, build the output shape without the reduced dimension.
    output = _create_tensor(output, layer)
    a = list(range(input.ndim()))
    for d in dim:
        a.pop(d)
    indices = constant(int32_array(a))
    output_shape = shape(output)
    new_shape = gather(output_shape, 0, indices)
    return view(output, new_shape)
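The keepdim=False branch above can be sketched in plain Python to show what the final shape looks like. This is only an illustration under stated assumptions: prune_reduced_dims is a hypothetical helper, ordinary lists stand in for the shape, constant and gather tensor operations, and dims is assumed to be the tuple that the loop over dim iterates.

```python
from typing import List, Sequence, Tuple


def prune_reduced_dims(output_shape: Sequence[int], ndim: int,
                       dims: Tuple[int, ...]) -> List[int]:
    """Return output_shape without the reduced axes (illustrative only)."""
    kept = list(range(ndim))   # mirrors: a = list(range(input.ndim()))
    for d in dims:             # mirrors: for d in dim: a.pop(d)
        kept.pop(d)
    # Mirrors gather(output_shape, 0, indices): select the remaining extents.
    return [output_shape[i] for i in kept]


# A rank-3 input reduced along axis 1: the TopK output has shape (4, 1, 6).
new_shape = prune_reduced_dims([4, 1, 6], ndim=3, dims=(1,))
```

Note that pop removes by position, so this sketch, like the loop it mirrors, is only straightforward for a single reduced axis, which matches the dim: int signature.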

Callers 2

categorical_sample · Function · 0.85

Calls 11

dim_resolve_negative · Function · 0.85
dim_to_trt_axes · Function · 0.85
default_trtnet · Function · 0.85
_create_tensor · Function · 0.85
constant · Function · 0.85
view · Function · 0.85
pop · Method · 0.80
shape · Function · 0.70
gather · Function · 0.70
ndim · Method · 0.45
get_output · Method · 0.45

Tested by

no test coverage detected