Add the masked_scatter operation based on the PyTorch definition. See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a description of that function. Parameters: input : Tensor The input tensor. mask : Tenso
(input: Tensor, mask: Tensor, source: Tensor)
| 2537 | |
| 2538 | |
| 2539 | def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor: |
| 2540 | ''' |
| 2541 | Add the masked_scatter operation based on the PyTorch definition. |
| 2542 | |
| 2543 | See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a |
| 2544 | description of that function. |
| 2545 | |
| 2546 | Parameters: |
| 2547 | input : Tensor |
| 2548 | The input tensor. |
| 2549 | |
| 2550 | mask : Tensor |
| 2551 | The boolean mask tensor that indicates which elements of input are overwritten. |
| 2552 | |
| 2553 | source : Tensor |
| 2554 | The tensor to copy values from. |
| 2555 | Returns: |
| 2556 | A tensor shaped like input in which the elements where mask is True are replaced, in order, by elements of source. |
| 2557 | |
| 2558 | ''' |
| 2559 | assert input.rank() >= 1, "input should have rank >= 1" |
| 2560 | input, mask = broadcast_helper(input, mask) |
| 2561 | expanded_mask = expand(mask, shape(input)) |
| 2562 | |
| 2563 | non_zero_layer = default_trtnet().add_non_zero(expanded_mask.trt_tensor) |
| 2564 | |
| 2565 | shuffle_layer = default_trtnet().add_shuffle(non_zero_layer.get_output(0)) |
| 2566 | shuffle_layer.second_transpose = (1, 0) |
| 2567 | source = source.view([-1]) |
| 2568 | |
| 2569 | scatter_layer = default_trtnet().add_scatter(input.trt_tensor, |
| 2570 | shuffle_layer.get_output(0), |
| 2571 | source.trt_tensor, |
| 2572 | mode=trt.ScatterMode.ND) |
| 2573 | |
| 2574 | return _create_tensor(scatter_layer.get_output(0), scatter_layer) |
| 2575 | |
| 2576 | |
| 2577 | def concat(inputs: Sequence[Union[Tensor, int]], dim: int = 0) -> Tensor: |
nothing calls this directly
no test coverage detected
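
The listing above builds masked_scatter from three TensorRT layers: a non-zero layer that finds the coordinates of the True mask entries, a shuffle that transposes those coordinates, and an ND scatter that writes the flattened source at them. As a rough illustration only, the same sequence can be sketched in NumPy. The helper name `masked_scatter_reference` is hypothetical and not part of the library. The sketch also assumes PyTorch's convention of consuming only the first N source elements (N being the number of True mask entries), whereas the listing passes the whole flattened source to the scatter layer.

```python
import numpy as np


def masked_scatter_reference(input, mask, source):
    """Hypothetical NumPy sketch of the nonzero -> transpose -> scatter-ND steps."""
    input = np.asarray(input)
    assert input.ndim >= 1, "input should have rank >= 1"

    # Broadcast the mask to the input's shape (stands in for broadcast_helper + expand).
    expanded_mask = np.broadcast_to(np.asarray(mask, dtype=bool), input.shape)

    # Coordinates of True entries, stacked as (rank, N), in row-major order.
    coords = np.stack(np.nonzero(expanded_mask))

    # Transpose to (N, rank): one coordinate tuple per element to overwrite.
    indices = coords.T

    # Flatten the source; assumption: only the first N values are used.
    flat_source = np.asarray(source).reshape(-1)
    n = indices.shape[0]
    if flat_source.size < n:
        raise ValueError("source has fewer elements than True entries in mask")

    # ND scatter: write the i-th source value at the i-th coordinate tuple.
    out = input.copy()
    out[tuple(indices.T)] = flat_source[:n]
    return out


if __name__ == "__main__":
    x = np.zeros((2, 3), dtype=np.int64)
    m = np.array([[True, False, True], [False, True, False]])
    print(masked_scatter_reference(x, m, np.array([10, 20, 30])))
```

Because the function has no detected test coverage, a reference of this kind could serve as the expected-value oracle in a unit test that compares an engine's output against it, assuming the test harness can run the built network.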