```python
def gather(input: Tensor, dim: int, indices: Union[Tensor, int]) -> Tensor:
    '''
    Add an operation to gather elements from a tensor.

    This function implements the GatherElements operator from the ONNX
    specification as described in

    https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements

    The input and indices arguments must have the same rank >= 1. The operation
    produces a tensor with the same shape as the indices tensor. The 'dim'
    argument is the dimension to gather on.

    As shown in the ONNX description, for a 3D tensor, the output is:

        out[i][j][k] = input[indices[i][j][k]][j][k] if dim = 0,
        out[i][j][k] = input[i][indices[i][j][k]][k] if dim = 1,
        out[i][j][k] = input[i][j][indices[i][j][k]] if dim = 2.

    For example,

        gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])

    will produce [[5, 2], [4, 3]], and

        gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])

    will produce [[2], [4]]. See the ONNX documentation for more examples.

    This operation maps to the TensorRT IGatherLayer in element mode.

    Parameters:
        input : Tensor
            The input tensor to gather elements from.

        dim : int
            The dimension to gather on. A negative value counts from the
            last dimension.

        indices : Union[Tensor, int]
            The positions in the 'dim' dimension to gather from. An int is
            converted to a 1D int32 tensor of shape [1], so it can only be
            used with a 1D input.

    Returns:
        The tensor containing the gathered elements. It has the same shape as
        the indices tensor.
    '''
    # A scalar index becomes a 1D int32 constant tensor of shape [1].
    if isinstance(indices, int):
        indices = constant(int32_array([indices]))

    # The input and indices tensors must have the same rank.
    assert input.rank() == indices.rank()

    # ELEMENT mode gives the GatherElements semantics described above.
    layer = default_trtnet().add_gather_v2(input.trt_tensor,
                                           indices.trt_tensor,
                                           mode=trt.GatherMode.ELEMENT)

    # Normalize a negative dimension (e.g. -1 refers to the last dimension).
    if dim < 0:
        dim = input.ndim() + dim
    layer.axis = dim
    return _create_tensor(layer.get_output(0), layer)
```
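To make the GatherElements semantics concrete without building a TensorRT network, here is a minimal pure-NumPy reference sketch. `gather_elements_ref` is a hypothetical helper, not part of TensorRT-LLM or NumPy; it assumes `indices` has the same rank as `data` and that every index lies in `[-s, s-1]` for the size `s` of the gathered dimension, as the ONNX specification requires. It reproduces the two examples from the docstring.

```python
import numpy as np


def gather_elements_ref(data, dim, indices):
    """Hypothetical NumPy reference for ONNX GatherElements (not a TensorRT-LLM API)."""
    data = np.asarray(data)
    indices = np.asarray(indices)
    # Both tensors must have the same rank >= 1, as in gather().
    if data.ndim < 1 or data.ndim != indices.ndim:
        raise ValueError("data and indices must have the same rank >= 1")
    # Normalize a negative dimension the same way gather() does.
    if dim < 0:
        dim += data.ndim
    # The output has the shape of the indices tensor.
    out = np.empty(indices.shape, dtype=data.dtype)
    for pos in np.ndindex(*indices.shape):
        # Copy the output coordinate and replace only its 'dim' component
        # with the index read from the indices tensor.
        src = list(pos)
        src[dim] = indices[pos]
        out[pos] = data[tuple(src)]
    return out


print(gather_elements_ref([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]]).tolist())
# [[5, 2], [4, 3]]
print(gather_elements_ref([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]]).tolist())
# [[2], [4]]
```

Looping over output coordinates mirrors the formulas directly and also allows `indices` to be smaller than `data` along the non-gathered dimensions; when the two shapes are equal, `np.take_along_axis` yields the same result faster.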
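The three 3D formulas in the docstring can also be checked numerically. The sketch below uses a hypothetical helper, `gather_elements_3d`, that applies the formulas literally, and compares its output with NumPy's `np.take_along_axis`, which agrees with GatherElements when `indices` has the same shape as the input. Negative `dim` values are normalized as in `gather()`; the random data is illustrative only.

```python
import numpy as np


def gather_elements_3d(data, dim, indices):
    """Hypothetical helper: the docstring's 3D formulas, written out literally."""
    if dim < 0:
        dim += data.ndim  # e.g. -1 -> 2, as in gather()
    if dim not in (0, 1, 2):
        raise ValueError("dim must be in [-3, 2] for a 3D tensor")
    out = np.empty(indices.shape, dtype=data.dtype)
    n_i, n_j, n_k = indices.shape
    for i in range(n_i):
        for j in range(n_j):
            for k in range(n_k):
                idx = indices[i, j, k]
                if dim == 0:
                    out[i, j, k] = data[idx, j, k]  # input[indices[i][j][k]][j][k]
                elif dim == 1:
                    out[i, j, k] = data[i, idx, k]  # input[i][indices[i][j][k]][k]
                else:
                    out[i, j, k] = data[i, j, idx]  # input[i][j][indices[i][j][k]]
    return out


rng = np.random.default_rng(0)
data = rng.integers(0, 100, size=(2, 3, 4))
for dim in (0, 1, 2, -1):
    # Random in-range indices with the same shape as data.
    indices = rng.integers(0, data.shape[dim], size=data.shape)
    expected = np.take_along_axis(data, indices, axis=dim)
    assert np.array_equal(gather_elements_3d(data, dim, indices), expected)
print("3D formulas agree with np.take_along_axis")
```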