Add an operation to select a slice of elements from a tensor. Given an input tensor, this function creates an operation that selects the index-th slice of elements in the dimension 'dim' to create a new tensor. The output tensor has a shape in which the input dimension 'dim' is removed.
(input: Tensor, dim: int, index: Union[Tensor, int])
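The sketch below checks the shape rule stated above against the two examples from the docstring. It uses NumPy only as a stand-in reference (an assumption; NumPy is not part of this API). `np.take` with a scalar index and an `axis` argument drops that axis, which matches what `select` is documented to do. The helper name `select_reference` is hypothetical.

```python
import numpy as np


def select_reference(x: np.ndarray, dim: int, index: int) -> np.ndarray:
    """Hypothetical NumPy reference for the documented select() semantics.

    Passing a scalar index to np.take removes the 'dim' axis from the result.
    """
    return np.take(x, index, axis=dim)


# Docstring example 1: take slice 1 along dim 0 of a [3, 3] tensor.
x = np.array([[4, 2, 5], [2, 1, 2], [4, 7, 1]])
row = select_reference(x, 0, 1)
print(row.shape, row.tolist())  # (3,) [2, 1, 2]

# Docstring example 2: take slice 4 along dim 2 of a [4, 2, 6, 3] tensor.
y = np.arange(4 * 2 * 6 * 3).reshape(4, 2, 6, 3)
sl = select_reference(y, 2, 4)
print(sl.shape)  # (4, 2, 3)
```

The result is the same as basic indexing with an integer in position 'dim' (for example `y[:, :, 4, :]`).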

def select(input: Tensor, dim: int, index: Union[Tensor, int]) -> Tensor:
    '''
    Add an operation to select a slice of elements from a tensor.

    Given an input tensor, this function creates an operation that selects the
    index-th slice of elements in the dimension 'dim' to create a new tensor.
    The output tensor has a shape in which the input dimension 'dim' is
    removed.

    The 'index' can either be an integer or a 1D tensor containing a single
    element.

    For example, given input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has
    shape [3, 3],

        select(input, 0, 1)

    will create a tensor of shape [3] that contains [2, 1, 2].

    Regarding the shape of the output tensor, the dimension 'dim' is removed.
    For example, for a tensor of shape [4, 2, 6, 3],

        select(input, 2, 4)

    will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)
    and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).

    This operation maps to the TensorRT IGatherLayer.

    Parameters:
        input : Tensor
            The input tensor to select from.

        dim : int
            The dimension to select from.

        index : Union[Tensor, int]
            The index of the slice in the 'dim' dimension to select.

    Returns:
        The tensor containing the selected slice.
    '''
    # Normalize an integer index to a 1D, single-element int32 constant.
    if isinstance(index, int):
        index = constant(int32_array([index]))
    assert index.rank() == 1 and index.size(0) == 1, \
        f"index should be a 1D tensor with one element, got rank {index.rank()}"

    # Output shape: every input dimension except 'dim'.
    new_shape = []
    for i in range(input.rank()):
        if i != dim:
            new_shape.append(shape(input, i))

    # Gathering with a one-element index keeps 'dim' with size 1;
    # the view below removes it.
    layer = default_trtnet().add_gather(input.trt_tensor, index.trt_tensor, dim)
    return _create_tensor(layer.get_output(0), layer).view(concat(new_shape))
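The code above does not drop the dimension in the gather itself: it gathers with a one-element index tensor, which in TensorRT's default gather mode keeps 'dim' with size 1, and then calls `view` with a shape built from every dimension except 'dim'. The NumPy sketch below mimics those steps under that assumption; `gather_then_view` is a hypothetical name, and `np.take` with a 1D index array stands in for `add_gather`.

```python
import numpy as np


def gather_then_view(x: np.ndarray, dim: int, index) -> np.ndarray:
    """Hypothetical NumPy mimic of the steps used by select().

    1. Normalize an int index to a 1D single-element int32 array.
    2. Gather along 'dim'; a 1-element index keeps 'dim' with size 1.
    3. Reshape to the input shape with 'dim' left out.
    """
    if isinstance(index, int):
        index = np.array([index], dtype=np.int32)
    index = np.asarray(index)
    # Same precondition as the assert in select().
    assert index.ndim == 1 and index.shape[0] == 1, \
        f"index should be a 1D tensor with one element, got rank {index.ndim}"

    gathered = np.take(x, index, axis=dim)  # 'dim' now has size 1
    new_shape = [x.shape[i] for i in range(x.ndim) if i != dim]
    return gathered.reshape(new_shape)


y = np.arange(4 * 2 * 6 * 3).reshape(4, 2, 6, 3)
print(np.take(y, np.array([4]), axis=2).shape)  # (4, 2, 1, 3)
out = gather_then_view(y, 2, 4)
print(out.shape)  # (4, 2, 3)
```

Building the target shape from per-dimension sizes, as the original does with `shape(input, i)` and `concat`, presumably keeps the `view` valid when sizes are only known at runtime.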