github.com/NVIDIA/TensorRT-LLM

Function gather

tensorrt_llm/functional.py:2099–2157


(input: Tensor, dim: int, indices: Union[Tensor, int])


```python
def gather(input: Tensor, dim: int, indices: Union[Tensor, int]) -> Tensor:
    '''
    Add an operation to gather elements from a tensor.

    This function implements the GatherElements operator from the ONNX
    specification as described in

    https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements

    The input and indices arguments must have the same rank >= 1. The operation
    will produce a tensor with the same shape as the indices tensor. The 'dim'
    argument is the axis to gather on.

    As shown in the ONNX description, for a 3D tensor, the output is:

        out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,
        out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,
        out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.

    For example,

        gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])

    will produce [[5, 2], [4, 3]].

        gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])

    will produce [[2], [4]]. See the ONNX documentation for more examples.

    This operation maps to the TensorRT IGatherLayer.

    Parameters:
        input : Tensor
            The input tensor to gather elements from.

        dim : int
            The dimension to gather on.

        indices : Union[Tensor, int]
            The positions in the 'dim' dimension to gather from.

    Returns:
        The tensor containing the gathered elements. It has the same shape as
        the indices tensor.
    '''
    # An int index becomes a rank-1 tensor with one element, so it only
    # passes the rank check below when input also has rank 1.
    if isinstance(indices, int):
        indices = constant(int32_array([indices]))

    # The input and indices tensors must have the same rank.
    assert input.rank() == indices.rank()

    # Element-wise gather (ONNX GatherElements semantics).
    layer = default_trtnet().add_gather_v2(input.trt_tensor,
                                           indices.trt_tensor,
                                           mode=trt.GatherMode.ELEMENT)

    # A negative dim counts from the last dimension.
    if dim < 0:
        dim = input.ndim() + dim
    layer.axis = dim
    # Wrap the layer output as a Tensor.
    return _create_tensor(layer.get_output(0), layer)
```
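To make the element-wise rule in the docstring concrete, here is a minimal NumPy sketch of the same GatherElements semantics. It is not part of TensorRT-LLM and does not call TensorRT; `gather_elements_ref` is a hypothetical name, and the sketch assumes only that NumPy is installed. It mirrors the argument handling of `gather` (an int index becomes a one-element rank-1 array, the ranks must match, a negative `dim` is normalized) and then applies the docstring's `out[...]` formula, generalized from 3D to any rank.

```python
import numpy as np


def gather_elements_ref(data, dim: int, indices) -> np.ndarray:
    """Hypothetical NumPy reference for GatherElements (not a TensorRT-LLM API).

    out has the shape of `indices`, and for every position p:
        out[p] = data[p with coordinate `dim` replaced by indices[p]]
    """
    if isinstance(indices, int):
        # Mirrors gather(): a scalar index becomes a one-element rank-1 array.
        indices = np.array([indices], dtype=np.int32)
    data = np.asarray(data)
    indices = np.asarray(indices)
    # Same rank requirement as the assert in gather().
    assert data.ndim == indices.ndim and data.ndim >= 1
    if dim < 0:
        dim += data.ndim  # same normalization as gather()
    out = np.empty(indices.shape, dtype=data.dtype)
    for pos in np.ndindex(*indices.shape):
        src = list(pos)
        src[dim] = indices[pos]  # swap in the index on the gather axis
        out[pos] = data[tuple(src)]
    return out


# The two docstring examples:
print(gather_elements_ref([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]]).tolist())  # [[5, 2], [4, 3]]
print(gather_elements_ref([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]]).tolist())  # [[2], [4]]
```

When `indices` has the same extent as `data` on every axis other than `dim`, `np.take_along_axis(data, indices, axis=dim)` computes the same result; the explicit loop is kept here because it reads directly as the formula above and also covers indices that are smaller than `data` on the other axes.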

Callers 9

_validate_draft_tokens · Function · 0.90
_get_draft_token_indices · Function · 0.90
_get_draft_token_array · Function · 0.90
_get_mask · Function · 0.90
_batch_index_select · Function · 0.90
expand_dims · Function · 0.70
shape · Function · 0.70
argmax · Function · 0.70
gather_last_token_logits · Function · 0.70

Calls 6

constant · Function · 0.85
default_trtnet · Function · 0.85
_create_tensor · Function · 0.85
rank · Method · 0.45
ndim · Method · 0.45
get_output · Method · 0.45

Tested by

no test coverage detected