Add an operation to select elements from a tensor according to a boolean mask tensor. Given an input tensor, this function creates an operation that selects the elements at the positions where the boolean mask tensor is True and gathers them into a new tensor. The output tensor is always 1-D.
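The selection rule can be modelled on plain nested Python lists. The sketch below is a hypothetical reference helper: `masked_select_ref`, `shape_of` and `get` are illustrative names, not part of the library. It assumes, as the implementation's `expand(mask, shape(input))` call suggests, that the mask is broadcast to the input's shape (missing leading dimensions become size 1, and size-1 dimensions are repeated) and that selected elements are emitted in row-major order.

```python
from itertools import product
from typing import Any, List, Sequence, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def get(x: Any, index: Sequence[int]) -> Any:
    """Read one element of a nested list at a full coordinate."""
    for i in index:
        x = x[i]
    return x


def masked_select_ref(values: list, mask: list) -> List[Any]:
    """Hypothetical reference: broadcast `mask` to the shape of `values`,
    then collect the selected elements in row-major order (1-D result)."""
    in_shape = shape_of(values)
    m_shape = shape_of(mask)
    assert len(in_shape) >= 1, "input should have rank >= 1"
    assert len(m_shape) <= len(in_shape), "mask rank exceeds input rank"
    offset = len(in_shape) - len(m_shape)  # leading dims the mask lacks
    for k, m in enumerate(m_shape):
        assert m in (1, in_shape[offset + k]), "mask is not broadcastable"

    out = []
    for idx in product(*(range(n) for n in in_shape)):  # row-major walk
        # Size-1 mask dimensions always read index 0 (broadcasting).
        m_idx = [0 if m_shape[k] == 1 else idx[offset + k]
                 for k in range(len(m_shape))]
        if get(mask, m_idx):
            out.append(get(values, idx))
    return out
```

With input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], a mask of [[True], [False], [True]] selects whole rows and returns [4, 2, 5, 4, 7, 1], matching the docstring example.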
```python
def masked_select(input: Tensor, mask: Tensor) -> Tensor:
    '''
    Add an operation to select elements from a tensor according to a boolean
    mask tensor.

    Given an input tensor, this function creates an operation that selects
    the elements at the positions where the boolean mask tensor is True and
    gathers them into a new tensor. The output tensor is always 1-D.

    The input tensor must have rank >= 1. The shapes of the input tensor and
    the mask tensor do not need to match, but the mask must be broadcastable
    to the shape of the input tensor.

    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has shape
    [3, 3],

        masked_select(input, [[True, False, True], [False, True, False], [True, False, True]])

    will create a tensor of shape [5] that contains [4, 5, 1, 4, 1].

        masked_select(input, [[True], [False], [True]])

    will create a tensor of shape [6] that contains [4, 2, 5, 4, 7, 1].

        masked_select(input, [[False, False, True]])

    will create a tensor of shape [3] that contains [5, 2, 1].

        masked_select(input, [False])

    will create a tensor of shape [0], which is empty.

    This operation is implemented with the NonZero, Shuffle and GatherV2
    layers in TensorRT.

    Parameters:
        input : Tensor
            The input tensor to select from.

        mask : Tensor
            The boolean mask tensor that indicates the elements to select.

    Returns:
        The 1-D tensor containing the selected elements, in row-major order.
    '''
    assert input.rank() >= 1, "input should have rank >= 1"
    # Align the ranks of input and mask, then expand the mask to the input's shape.
    input, mask = broadcast_helper(input, mask)
    expanded_mask = expand(mask, shape(input))

    # NonZero returns the indices of the True elements as a [rank, count] tensor.
    non_zero_layer = default_trtnet().add_non_zero(expanded_mask.trt_tensor)

    # Transpose the indices to [count, rank] so that each row is one full coordinate.
    shuffle_layer = default_trtnet().add_shuffle(non_zero_layer.get_output(0))
    shuffle_layer.second_transpose = (1, 0)

    # GatherND with full coordinates yields a 1-D tensor of shape [count].
    gather_layer = default_trtnet().add_gather_v2(input.trt_tensor,
                                                  shuffle_layer.get_output(0),
                                                  mode=trt.GatherMode.ND)
    return _create_tensor(gather_layer.get_output(0), gather_layer)
```
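The three TensorRT layers can be emulated in plain Python to show why the Shuffle transpose is needed. This is a sketch under stated assumptions: `non_zero`, `transpose` and `gather_nd` are hypothetical stand-ins for the NonZero, Shuffle (`second_transpose = (1, 0)`) and GatherV2 (`GatherMode.ND`) layers. They assume NonZero emits the indices as a [rank, count] array in row-major order, and that the mask has already been expanded to the input's shape, as `expanded_mask` is in the function.

```python
from itertools import product
from typing import Any, List, Sequence, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def get(x: Any, index: Sequence[int]) -> Any:
    """Read one element of a nested list at a full coordinate."""
    for i in index:
        x = x[i]
    return x


def non_zero(mask: list) -> List[List[int]]:
    """Stand-in for NonZero (assumed layout): one row per dimension,
    one column per True element, columns in row-major order."""
    dims = shape_of(mask)
    coords = [idx for idx in product(*(range(n) for n in dims)) if get(mask, idx)]
    return [[c[d] for c in coords] for d in range(len(dims))]


def transpose(m: List[List[int]]) -> List[List[int]]:
    """Stand-in for Shuffle with second_transpose=(1, 0): [rank, count] -> [count, rank]."""
    return [list(row) for row in zip(*m)]


def gather_nd(data: list, indices: List[List[int]]) -> List[Any]:
    """Stand-in for GatherND: each row of `indices` is a full coordinate into
    `data`, so every row picks a single element and the result is 1-D."""
    return [get(data, idx) for idx in indices]


def masked_select_pipeline(data: list, expanded_mask: list) -> List[Any]:
    idx_dn = non_zero(expanded_mask)  # shape [rank, count]
    idx_nd = transpose(idx_dn)        # shape [count, rank]
    return gather_nd(data, idx_nd)    # shape [count]
```

Without the transpose, the [rank, count] indices would not hold one coordinate per row, which is the layout GatherND consumes when the last index dimension spans every dimension of the input.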