def argmax(input: Tensor, dim: int, keepdim: bool = False) -> Tensor:
    '''
    Add an argmax operation.

    As explained in the ONNX documentation,

    https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax

    this function creates a layer computing the indices of the maximum
    elements of the input tensor along the provided dimension. The resulting
    tensor has the same rank as the input if keepdim is True. If keepdim is
    False, the reduced dimension is removed from the resulting tensor.

    Parameters:
        input : Tensor
            The input tensor.

        dim : int
            The dimension along which to compute the argmax indices. A
            negative value counts from the last dimension.

        keepdim : bool
            Whether to keep the reduced dimension (with extent 1) in the
            output. If False (the default), that dimension is removed.

    Returns:
        The tensor of indices produced by this argmax operation.
    '''
    dim = dim_resolve_negative(dim, input.ndim())
    axes = dim_to_trt_axes(dim)

    # Top-1 along the reduced axis: output 0 holds the maximum values and
    # output 1 holds their indices, so only output 1 is used.
    layer = default_trtnet().add_topk(input.trt_tensor, trt.TopKOperation.MAX,
                                      1, axes)
    output = layer.get_output(1)

    if keepdim:
        return _create_tensor(output, layer)

    # Drop the reduced (extent-1) dimension: gather every other entry of the
    # output shape and reshape the indices to that shape.
    output = _create_tensor(output, layer)
    a = list(range(input.ndim()))
    for d in dim:
        a.pop(d)
    indices = constant(int32_array(a))
    output_shape = shape(output)
    new_shape = gather(output_shape, 0, indices)
    return view(output, new_shape)
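The docstring's contract can be checked against a small pure-Python reference. Argmax returns the index of the maximum along one dimension, and the reduced dimension is either kept with extent 1 or pruned. This is a sketch, not the TensorRT-LLM code. `argmax_reference` is a hypothetical helper that works on a flat row-major list plus an explicit shape. It resolves ties to the first maximal index, which is an assumption, since the docstring does not say how ties are broken.

```python
from itertools import product
from math import prod
from typing import List, Sequence, Tuple


def argmax_reference(values: Sequence[float], shape: Sequence[int], dim: int,
                     keepdim: bool = False) -> Tuple[List[int], List[int]]:
    """Hypothetical reference argmax over a flat, row-major buffer.

    Returns (indices, output_shape). Ties resolve to the first maximal
    index (an assumption made for this sketch).
    """
    ndim = len(shape)
    if dim < 0:
        dim += ndim  # resolve a negative dim against the input rank
    if not 0 <= dim < ndim:
        raise ValueError(f"dim out of range for rank {ndim}")
    if len(values) != prod(shape) or shape[dim] == 0:
        raise ValueError("values do not match shape, or reduced axis is empty")

    # Row-major strides: the last axis is contiguous.
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]

    # With keepdim, the reduced axis stays in the output with extent 1.
    out_shape = list(shape)
    out_shape[dim] = 1

    indices = []
    # Visit every output position in row-major order; the reduced coordinate is 0.
    for pos in product(*(range(n) for n in out_shape)):
        base = sum(p * s for p, s in zip(pos, strides))
        best = 0
        for k in range(1, shape[dim]):
            # Strict '>' keeps the first index among equal maxima.
            if values[base + k * strides[dim]] > values[base + best * strides[dim]]:
                best = k
        indices.append(best)

    if not keepdim:
        out_shape.pop(dim)  # prune the reduced dimension
    return indices, out_shape


# 2 x 3 input: [[1, 5, 2], [7, 0, 3]]
print(argmax_reference([1, 5, 2, 7, 0, 3], (2, 3), dim=1))
# -> ([1, 0], [2])
print(argmax_reference([1, 5, 2, 7, 0, 3], (2, 3), dim=0, keepdim=True))
# -> ([1, 0, 1], [1, 3])
```

The only difference between the two modes is the output shape. The index values are identical, which is why the keepdim=False path above can be a reshape of the keepdim result.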
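In the keepdim=False branch, the function drops the reduced dimension with shape arithmetic. It builds the list of axes to keep, gathers those entries from the runtime shape of the top-1 result, and reshapes with view. The pure-Python sketch below mirrors only that shape computation. `drop_reduced_dims` is a hypothetical name, and the tensor data itself is not modeled.

```python
from math import prod
from typing import List, Sequence


def drop_reduced_dims(keepdim_shape: Sequence[int], dims: Sequence[int]) -> List[int]:
    """Shape left after removing the reduced axes (the keepdim=False branch).

    Hypothetical helper: it builds the list of kept axes and gathers those
    entries of the shape, like the gather-then-view step in argmax.
    """
    ndim = len(keepdim_shape)
    # Resolve negative axes against the rank.
    resolved = {d + ndim if d < 0 else d for d in dims}
    for d in resolved:
        if not 0 <= d < ndim:
            raise ValueError(f"axis {d} out of range for rank {ndim}")
        if keepdim_shape[d] != 1:
            raise ValueError(f"axis {d} has extent {keepdim_shape[d]}, expected 1")
    # Gather the entries of every axis that was not reduced.
    kept_axes = [i for i in range(ndim) if i not in resolved]
    return [keepdim_shape[i] for i in kept_axes]


# Top-1 indices along axis 1 of a [4, 16, 8] input have shape [4, 1, 8].
top1_shape = [4, 1, 8]
new_shape = drop_reduced_dims(top1_shape, [1])
print(new_shape)  # -> [4, 8]
# Dropping an extent-1 axis keeps the element count, so a reshape (view) suffices.
assert prod(top1_shape) == prod(new_shape)
```

This sketch filters axes through a set rather than popping positions from a list, so it stays correct when several axes are removed at once. Popping by position, as the loop above does, is fine for argmax because only a single dimension is reduced.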
def gelu(x: Tensor) -> Tensor: