Repository: github.com/NVIDIA/TensorRT-LLM

Function split

tensorrt_llm/_torch/utils.py:394–402
Signature: (x: torch.Tensor, tp_size: int, idx: int, dim: int = 0) -> torch.Tensor


import torch


def split(x: torch.Tensor,
          tp_size: int,
          idx: int,
          dim: int = 0) -> torch.Tensor:
    # The sharded dimension must divide evenly across the tensor-parallel ranks.
    assert x.shape[dim] % tp_size == 0
    split_size = x.shape[dim] // tp_size
    # With a single rank there is nothing to shard: return the input itself.
    if tp_size == 1:
        return x
    # Return the idx-th equal-size chunk along `dim` (torch.split yields views).
    return torch.split(x, split_size, dim=dim)[idx]
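To make the slicing concrete: for `tp_size > 1`, rank `idx` receives the half-open range [idx * split_size, (idx + 1) * split_size) along `dim`. The sketch below reproduces only that index arithmetic in plain Python; `shard_bounds` is a hypothetical helper written for illustration and is not part of TensorRT-LLM.

```python
def shard_bounds(length: int, tp_size: int, idx: int) -> tuple[int, int]:
    """Hypothetical helper: the [start, stop) range that rank `idx` owns when a
    dimension of size `length` is split evenly across `tp_size` ranks,
    mirroring the arithmetic in `split`."""
    # Same precondition as `split`: the dimension must divide evenly.
    assert length % tp_size == 0
    split_size = length // tp_size
    start = idx * split_size
    return start, start + split_size


# A dimension of 4096 split across 4 ranks: each rank owns 1024 consecutive indices.
for rank in range(4):
    print(rank, shard_bounds(4096, 4, rank))
```

Two details of `split` itself are worth noting. When `tp_size == 1` it returns `x` itself (not a copy) and ignores `idx`. When `tp_size > 1` the result is a view into `x`, so in-place writes to the shard are visible in the original tensor.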

Callers (9)

forward (method) · 0.50
forward (method) · 0.50
forward_experts (method) · 0.50
forward (method) · 0.50
compute_cross_kv (method) · 0.50
forward (method) · 0.50
postprocess (method) · 0.50
forward (method) · 0.50
forward (method) · 0.50

Calls (1)

split (method) · 0.45

Tested by

No test coverage detected.
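Since no tests were found, here is a minimal pytest-style sketch of properties such tests could check: shards reassemble to the input, each shard matches an equivalent `Tensor.narrow` slice, a single rank gets the input back unchanged, and an uneven split is rejected. The test names and cases are assumptions, not existing TensorRT-LLM tests. `split` is copied inline so the file runs on its own, and PyTorch is required.

```python
import torch


def split(x: torch.Tensor, tp_size: int, idx: int, dim: int = 0) -> torch.Tensor:
    # Copy of the function above so this sketch runs standalone.
    assert x.shape[dim] % tp_size == 0
    split_size = x.shape[dim] // tp_size
    if tp_size == 1:
        return x
    return torch.split(x, split_size, dim=dim)[idx]


def test_shards_reassemble_to_input():
    x = torch.arange(24).reshape(4, 6)
    for dim, tp_size in [(0, 2), (0, 4), (1, 3), (1, 6)]:
        shards = [split(x, tp_size, i, dim=dim) for i in range(tp_size)]
        # Every shard has the same size along `dim`.
        assert all(s.shape[dim] == x.shape[dim] // tp_size for s in shards)
        # Concatenating the shards in rank order recovers the original tensor.
        assert torch.equal(torch.cat(shards, dim=dim), x)


def test_shard_matches_narrow():
    x = torch.randn(8, 4)
    # Rank 2 of 4 along dim 0 owns rows 4 and 5.
    assert torch.equal(split(x, 4, 2, dim=0), x.narrow(0, 4, 2))


def test_single_rank_returns_input():
    x = torch.randn(3, 5)
    # tp_size == 1 short-circuits and hands back the same tensor object.
    assert split(x, 1, 0) is x


def test_uneven_split_raises():
    x = torch.randn(5, 2)  # 5 rows cannot be divided across 2 ranks
    try:
        split(x, 2, 0)
    except AssertionError:
        pass
    else:
        raise AssertionError("expected an uneven split to be rejected")
```

Run the file with `pytest`, or call the test functions directly. The uneven-split check relies on `assert`, so it does not fire under `python -O`.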