Add an operation to select slices of elements from a tensor. Given an input tensor, this function creates an operation that selects the slices of elements in the dimension 'dim' at the indices listed in 'index' to create a new tensor. The output tensor has the same rank as the input tensor.
(input: Tensor, dim: int, index: Tensor)
def index_select(input: Tensor, dim: int, index: Tensor) -> Tensor:
    '''
    Add an operation to select slices of elements from a tensor.

    Given an input tensor, this function creates an operation that selects the
    slices of elements in the dimension 'dim' at the indices listed in 'index'
    to create a new tensor. The output tensor has the same rank as the input
    tensor.

    The 'index' is a tensor of rank 1.

    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape
    [3, 3],

        index_select(input, 0, [0, 1])

    will create a tensor of shape [2, 3] that contains [[4, 2, 5], [2, 1, 2]].

    Regarding the shape of the output tensor, the dimension 'dim' has the same
    size as the 'index' tensor. It means that for an input tensor of shape
    [4, 2, 6, 3],

        index_select(input, 2, [1, 4])

    will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension
    (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd
    dimension is shrunk to 2).

    Note that this operation can also be used to expand a tensor in the 'dim'
    dimension. For example, on input [[0, 1], [2, 3]],

        index_select(input, 1, [0, 0, 0])

    will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].

    This operation maps to the TensorRT IGatherLayer.

    Parameters:
        input : Tensor
            The input tensor to select from.

        dim : int
            The dimension to select from.

        index : Tensor
            The indices of the slices in the 'dim' dimension to select.

    Returns:
        The tensor containing the selected slices.
    '''
    assert index.rank() == 1, f"index should have rank 1, got {index.rank()}"

    # Build the output shape: every dimension keeps the input's size, except
    # 'dim', which takes the length of the rank-1 'index' tensor.
    new_shape = []
    for i in range(input.rank()):
        if i != dim:
            new_shape.append(shape(input, i))
        else:
            new_shape.append(shape(index, 0))
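The selection semantics and the output-shape rule described in the docstring can be checked without building a network. The following is a minimal pure-Python sketch over nested lists; `index_select_ref` and `output_shape` are hypothetical helper names introduced here for illustration only and are not part of the library, which builds a TensorRT gather operation instead.

```python
def index_select_ref(data, dim, index):
    """Reference semantics on nested lists (hypothetical helper, not the library API).

    Picks the slices of `data` along dimension `dim` at the positions in `index`.
    """
    if dim == 0:
        # At the target dimension: gather the requested slices, in order.
        return [data[i] for i in index]
    # Otherwise recurse into each sub-list, moving one dimension closer.
    return [index_select_ref(sub, dim - 1, index) for sub in data]


def output_shape(input_shape, dim, index_len):
    """Same rank as the input; only dimension `dim` becomes len(index)."""
    return [index_len if i == dim else size for i, size in enumerate(input_shape)]


if __name__ == "__main__":
    # The three examples from the docstring.
    print(index_select_ref([[4, 2, 5], [2, 1, 2], [4, 7, 1]], 0, [0, 1]))
    print(output_shape([4, 2, 6, 3], 2, 2))
    print(index_select_ref([[0, 1], [2, 3]], 1, [0, 0, 0]))
```

Because repeated indices are allowed, the same rule covers both shrinking and expanding the 'dim' dimension: the output size along 'dim' is always the length of 'index'.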