hub / github.com/NVIDIA/TensorRT-LLM / masked_scatter

Function masked_scatter

tensorrt_llm/functional.py:2539–2574

Adds a masked_scatter operation based on the PyTorch definition. See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a description of that function.

masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor

Source from the content-addressed store, hash-verified

def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    '''
    Add the masked_scatter base on PyTorch definition.

    See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a
    description of that function.

    Parameters:
        input : Tensor
            The input tensor.

        mask : Tensor
            The boolean mask tensor that indicates elements to select.

        source: Tensor
            The tensor to copy from
    Returns:
        The tensor containing the source tensor selected by mask.

    '''
    assert input.rank() >= 1, "input should have rank >= 1"
    input, mask = broadcast_helper(input, mask)
    expanded_mask = expand(mask, shape(input))

    non_zero_layer = default_trtnet().add_non_zero(expanded_mask.trt_tensor)

    shuffle_layer = default_trtnet().add_shuffle(non_zero_layer.get_output(0))
    shuffle_layer.second_transpose = (1, 0)
    source = source.view([-1])

    scatter_layer = default_trtnet().add_scatter(input.trt_tensor,
                                                 shuffle_layer.get_output(0),
                                                 source.trt_tensor,
                                                 mode=trt.ScatterMode.ND)

    return _create_tensor(scatter_layer.get_output(0), scatter_layer)
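
The listing above builds the operation from three steps: find the coordinates of the True mask entries (non-zero layer), transpose them into one coordinate row per selected element (shuffle layer), and write the flattened source into the input at those coordinates (ND scatter). The following is a minimal pure-Python sketch of that data flow on nested lists, not the TensorRT-LLM implementation. The helper names (`shape_of`, `get_at`, `set_at`, `flatten`, `masked_scatter_reference`) are hypothetical. It assumes the mask already has the same shape as the input (the broadcast and expand steps are skipped), that the coordinates come out in row-major order, and that the flattened source holds exactly one value per True mask entry, which is how the ND scatter shape rule appears to apply here.

```python
import copy
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list) and nested:
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


def get_at(nested, index):
    """Read the element at a multi-dimensional index."""
    for i in index:
        nested = nested[i]
    return nested


def set_at(nested, index, value):
    """Write the element at a multi-dimensional index."""
    for i in index[:-1]:
        nested = nested[i]
    nested[index[-1]] = value


def flatten(nested):
    """Flatten a nested list in row-major order (stand-in for view([-1]))."""
    if not isinstance(nested, list):
        return [nested]
    flat = []
    for item in nested:
        flat.extend(flatten(item))
    return flat


def masked_scatter_reference(inp, mask, source):
    """Hypothetical reference for the non-zero -> transpose -> ND-scatter flow."""
    dims = shape_of(inp)
    if len(dims) < 1:
        raise ValueError("input should have rank >= 1")

    # Non-zero + transpose: one coordinate tuple per True mask entry,
    # visited in row-major order (assumption of this sketch).
    indices = [idx for idx in product(*(range(d) for d in dims))
               if get_at(mask, idx)]

    # Flatten the source, as source.view([-1]) does in the listing.
    updates = flatten(source)

    # Assumption: ND scatter needs exactly one update per coordinate row.
    if len(updates) != len(indices):
        raise ValueError("source must hold one value per True mask entry")

    out = copy.deepcopy(inp)  # leave the caller's input untouched
    for idx, value in zip(indices, updates):
        set_at(out, idx, value)
    return out


if __name__ == "__main__":
    inp = [[0, 0, 0], [0, 0, 0]]
    mask = [[True, False, True], [False, True, False]]
    print(masked_scatter_reference(inp, mask, [1, 2, 3]))
    # prints [[1, 0, 2], [0, 3, 0]]
```

Note the difference from the PyTorch definition, which only requires the source to have at least as many elements as the mask has True entries. Under the assumptions of this sketch, the source size must match the count exactly.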

Callers

nothing calls this directly

Calls 8

broadcast_helper · Function · 0.85
expand · Function · 0.85
default_trtnet · Function · 0.85
_create_tensor · Function · 0.85
shape · Function · 0.70
rank · Method · 0.45
get_output · Method · 0.45
view · Method · 0.45

Tested by

no test coverage detected