Function scatter_nd

tensorrt_llm/functional.py:7419–7439 in github.com/NVIDIA/TensorRT-LLM


(input: Tensor, mask: Tensor, source: Tensor)

def scatter_nd(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    '''
    Scatter_nd is a tensor operation that writes or updates values in a tensor based on indices.

    Parameters:
        input: Tensor
            The input tensor to be updated.
        mask: Tensor
            A tensor of indices specifying the locations in input to be updated.
        source: Tensor
            A tensor of values to be written or scattered into input.
    Returns:
        New tensor with the same shape as the input tensor,
        where the values from the source tensor are scattered or written into the output tensor
        at the locations specified by the mask tensor.
    '''
    # Add a TensorRT scatter layer in ND mode. Despite its name, `mask` holds
    # indices, not a boolean mask: it is passed as the layer's indices tensor,
    # and `source` is passed as the layer's updates tensor.
    scatter_layer = default_trtnet().add_scatter(input.trt_tensor,
                                                 mask.trt_tensor,
                                                 source.trt_tensor,
                                                 mode=trt.ScatterMode.ND)
    # Wrap the layer's single output as a TensorRT-LLM Tensor.
    return _create_tensor(scatter_layer.get_output(0), scatter_layer)
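
The docstring's description can be made concrete with a small pure-Python model of ND-mode scatter. This is a sketch that assumes ONNX ScatterND-style semantics for `trt.ScatterMode.ND`: the last axis of the index tensor (here `mask`) holds a coordinate prefix into `input`, and `source` supplies one element or slice per coordinate. `scatter_nd_reference` and `_shape` are hypothetical helpers written for illustration only; they are not part of TensorRT-LLM or TensorRT.

```python
import copy
from typing import Any, List


def _shape(x: Any) -> List[int]:
    """Shape of a rectangular nested list (a scalar has shape [])."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def scatter_nd_reference(data, indices, updates):
    """Hypothetical pure-Python model of ND-mode scatter (ScatterND-style).

    data:    nested list of rank r (plays the role of `input`)
    indices: nested list whose last axis has length k <= r (plays `mask`);
             each innermost row is a coordinate prefix into `data`
    updates: nested list of shape indices.shape[:-1] + data.shape[k:]
             (plays `source`)
    Returns a new nested list with the same shape as `data`; `data` itself
    is left unmodified.
    """
    out = copy.deepcopy(data)
    data_shape = _shape(data)
    idx_shape = _shape(indices)
    k = idx_shape[-1]
    if k > len(data_shape):
        raise ValueError("indices.shape[-1] must not exceed rank(data)")
    expected = idx_shape[:-1] + data_shape[k:]
    if _shape(updates) != expected:
        raise ValueError(f"updates shape {_shape(updates)} != expected {expected}")

    def walk(idx_node, upd_node, depth):
        if depth == len(idx_shape) - 1:
            # idx_node is one coordinate row of length k: follow the first
            # k - 1 coordinates, then overwrite the element or slice at the last.
            target = out
            for i in idx_node[:-1]:
                target = target[i]
            target[idx_node[-1]] = copy.deepcopy(upd_node)
            return
        # Recurse over the batch axes shared by indices and updates.
        for sub_idx, sub_upd in zip(idx_node, upd_node):
            walk(sub_idx, sub_upd, depth + 1)

    # Duplicate coordinates are simply applied in order here (last write wins);
    # this sketch does not claim the TensorRT layer behaves the same way.
    walk(indices, updates, 0)
    return out
```

For example, `scatter_nd_reference([1, 2, 3, 4, 5, 6, 7, 8], [[4], [3], [1], [7]], [9, 10, 11, 12])` returns `[1, 11, 3, 10, 9, 6, 7, 12]`. With `k = 1` on a 2-D `data`, each coordinate selects a whole row, so `source` must supply full rows. The deep copy mirrors the docstring's "new tensor" contract: the result is a separate value rather than an in-place edit of `input`.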

Callers (1)

forward_experts (method) · 0.85

Calls (3)

default_trtnet (function) · 0.85
_create_tensor (function) · 0.85
get_output (method) · 0.45

Tested by

no test coverage detected