def scatter_nd(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    '''
    Add a scatter operation in ND mode: the values in `source` are written
    into a copy of `input` at the positions given by the index tensor `mask`.

    Parameters:
        input: Tensor
            The tensor to be updated. It is not modified in place; the
            result is a new tensor.

        mask: Tensor
            An integer tensor of indices (not a boolean mask, despite its
            name). Each vector along its last dimension addresses a location
            in the leading dimensions of `input`.

        source: Tensor
            The values to be written into `input` at the locations given
            by `mask`.

    Returns:
        A new tensor with the same shape as `input`, in which the values from
        `source` have been written at the locations specified by `mask`.
    '''
    scatter_layer = default_trtnet().add_scatter(input.trt_tensor,
                                                 mask.trt_tensor,
                                                 source.trt_tensor,
                                                 mode=trt.ScatterMode.ND)
    return _create_tensor(scatter_layer.get_output(0), scatter_layer)
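The NumPy sketch below shows what `scatter_nd` is expected to compute. It assumes that `trt.ScatterMode.ND` follows ONNX ScatterND semantics. Let k be the size of the last dimension of `mask`. Each length-k index vector in `mask` selects a single element of `input` when k equals its rank, or a trailing slice when k is smaller. Under that assumption, `source` must have shape `mask.shape[:-1] + input.shape[k:]`. The name `scatter_nd_reference` is hypothetical. It is not part of TensorRT or this module and only mirrors the expected result on the host.

```python
import numpy as np


def scatter_nd_reference(data: np.ndarray, indices: np.ndarray,
                         updates: np.ndarray) -> np.ndarray:
    """Hypothetical host-side reference for ND-mode scatter.

    Assumes ONNX ScatterND semantics: `data` plays the role of `input`,
    `indices` of `mask` and `updates` of `source` in scatter_nd().
    """
    output = np.copy(data)  # the input array itself is left untouched
    k = indices.shape[-1]   # index depth: number of leading dims each vector addresses
    expected = indices.shape[:-1] + data.shape[k:]
    if updates.shape != expected:
        raise ValueError(f"updates shape {updates.shape} != expected {expected}")
    # Visit every index vector; each selects one element (k == data.ndim)
    # or a trailing slice (k < data.ndim) of the output and overwrites it.
    for pos in np.ndindex(indices.shape[:-1]):
        output[tuple(indices[pos])] = updates[pos]
    return output


# Element updates: k == 2 == data.ndim, so each index vector picks one element.
out1 = scatter_nd_reference(np.zeros((2, 3), dtype=np.float32),
                            np.array([[0, 1], [1, 2]]),
                            np.array([5.0, 7.0], dtype=np.float32))
print(out1.tolist())  # → [[0.0, 5.0, 0.0], [0.0, 0.0, 7.0]]

# Slice updates: k == 1, so each index vector replaces a whole row.
out2 = scatter_nd_reference(np.arange(6).reshape(2, 3),
                            np.array([[1]]),
                            np.array([[10, 11, 12]]))
print(out2.tolist())  # → [[0, 1, 2], [10, 11, 12]]
```

In this sketch the last write wins when `mask` contains duplicate index vectors. Do not assume the TensorRT layer resolves duplicates the same way.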


def low_latency_gemm(input: Tensor,