Repository: github.com/NVIDIA/TensorRT-LLM

Function masked_select

tensorrt_llm/functional.py:2352–2408

Adds an operation that selects elements from a tensor according to a boolean mask tensor. Given an input tensor, this function creates an operation that picks the elements at the positions where the mask is true and returns them, in row-major order, as a new 1-D tensor. The mask must be broadcastable to the shape of the input.
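As a reference for the semantics only, the following pure-Python sketch mimics what the generated TensorRT subgraph computes, using nested lists instead of Tensor objects. masked_select_ref, shape_of and element are hypothetical helpers written for illustration; they are not part of TensorRT-LLM. The sketch assumes, as the implementation's expand(mask, shape(input)) call suggests, that only the mask is broadcast (leading size-1 dimensions added, size-1 dimensions stretched) to the input's shape, never the input to the mask's.

```python
from itertools import product


def shape_of(x):
    """Shape of a rectangular nested list; a non-list value has shape ()."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def element(x, index):
    """Read x[i0][i1]... for a sequence of indices."""
    for i in index:
        x = x[i]
    return x


def masked_select_ref(inp, mask):
    """Hypothetical pure-Python reference for masked_select semantics."""
    in_shape = shape_of(inp)
    mask_shape = shape_of(mask)
    if len(in_shape) < 1:
        raise ValueError("input should have rank >= 1")
    if len(mask_shape) > len(in_shape):
        # Assumption: only the mask is broadcast, to the input's shape.
        raise ValueError("mask rank must not exceed input rank")

    # Align ranks by prepending size-1 dimensions to the mask shape.
    pad = len(in_shape) - len(mask_shape)
    aligned = (1,) * pad + mask_shape
    for m, n in zip(aligned, in_shape):
        if m not in (1, n):
            raise ValueError(f"mask shape {mask_shape} does not broadcast to {in_shape}")

    selected = []
    # Visit input positions in row-major order, like the docstring examples.
    for index in product(*(range(n) for n in in_shape)):
        # A size-1 mask dimension is broadcast by always reading position 0;
        # the padded leading dimensions are then dropped again.
        mask_index = [0 if m == 1 else i for m, i in zip(aligned, index)][pad:]
        if element(mask, mask_index):
            selected.append(element(inp, index))
    return selected
```

With inp = [[4, 2, 5], [2, 1, 2], [4, 7, 1]], the four masks from the docstring give [4, 5, 1, 4, 1], [4, 2, 5, 4, 7, 1], [5, 2, 1] and [] respectively.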

masked_select(input: Tensor, mask: Tensor) -> Tensor


def masked_select(input: Tensor, mask: Tensor) -> Tensor:
    '''
    Add an operation to select elements from a tensor according to a boolean
    mask tensor.

    Given an input tensor, that function creates an operation that selects
    elements at the indices indicated by the boolean mask tensor to create
    a new tensor. The output tensor is a 1-D tensor.

    The input tensor must have rank >= 1. The shapes of the input tensor and
    the mask tensor don’t need to match, but they must be able to be broadcasted.

    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape
    [3, 3],

        masked_select(input, [[True, False, True], [False, True, False], [True, False, True]])

    will create a tensor of shape [5] that contains the [4, 5, 1, 4, 1].

        masked_select(input, [[True], [False], [True]])

    will create a tensor of shape [6] that contains the [4, 2, 5, 4, 7, 1].

        masked_select(input, [[False, False, True]])

    will create a tensor of shape [3] that contains the [5, 2, 1].

        masked_select(input, [False])

    will create a tensor of shape [0] which is empty.

    That operation is implemented by NonZero, Shuffle and GatherV2 layers
    in TensorRT.

    Parameters:
        input : Tensor
            The input tensor to select from.

        mask : Tensor
            The boolean mask tensor that indicates elements to select.

    Returns:
        The 1-D tensor containing the selected elements.
    '''
    assert input.rank() >= 1, "input should have rank >= 1"
    input, mask = broadcast_helper(input, mask)
    expanded_mask = expand(mask, shape(input))

    non_zero_layer = default_trtnet().add_non_zero(expanded_mask.trt_tensor)

    shuffle_layer = default_trtnet().add_shuffle(non_zero_layer.get_output(0))
    shuffle_layer.second_transpose = (1, 0)

    gather_layer = default_trtnet().add_gather_v2(input.trt_tensor,
                                                  shuffle_layer.get_output(0),
                                                  mode=trt.GatherMode.ND)
    return _create_tensor(gather_layer.get_output(0), gather_layer)
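The three-layer lowering at the end of masked_select can be emulated in plain Python to show why the Shuffle step is needed. Judging from the code, NonZero produces one row per dimension (shape [rank, N]), while GatherND in ND mode expects one full coordinate per row (shape [N, rank]); the second_transpose = (1, 0) converts between the two. The helpers non_zero, transpose_2d, gather_nd and masked_select_lowered are hypothetical stand-ins written for illustration, not TensorRT or TensorRT-LLM APIs, and the sketch assumes the mask has already been expanded to the input's shape, which is the job of broadcast_helper and expand in the real function.

```python
from itertools import product


def shape_of(x):
    """Shape of a rectangular nested list; a non-list value has shape ()."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def element(x, index):
    """Read x[i0][i1]... for a sequence of indices."""
    for i in index:
        x = x[i]
    return x


def non_zero(mask):
    """Emulates NonZero: a [rank, N] table of coordinates of the true entries, row-major."""
    dims = shape_of(mask)
    hits = [idx for idx in product(*(range(n) for n in dims)) if element(mask, idx)]
    return [[idx[d] for idx in hits] for d in range(len(dims))]


def transpose_2d(matrix):
    """Emulates a Shuffle layer with second_transpose = (1, 0)."""
    return [list(column) for column in zip(*matrix)]


def gather_nd(data, indices):
    """Emulates GatherND where every index row addresses a single element."""
    return [element(data, row) for row in indices]


def masked_select_lowered(inp, expanded_mask):
    coords = non_zero(expanded_mask)   # [rank, N]: one row per dimension
    indices = transpose_2d(coords)     # [N, rank]: one coordinate per row
    return gather_nd(inp, indices)     # [N]: the selected elements
```

For the [3, 3] example input and the checkerboard mask from the docstring, non_zero returns [[0, 0, 1, 2, 2], [0, 2, 1, 0, 2]]; after the transpose each row is a (row, column) pair, and gathering yields [4, 5, 1, 4, 1].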

Callers (2)

_get_packed_position_ids · Function · 0.90
encode_text · Method · 0.85

Calls (7)

broadcast_helper · Function · 0.85
expand · Function · 0.85
default_trtnet · Function · 0.85
_create_tensor · Function · 0.85
shape · Function · 0.70
rank · Method · 0.45
get_output · Method · 0.45

Tested by

no test coverage detected