MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / allreduce

Function allreduce

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py:17–60  ·  view source on GitHub ↗
(
        input: torch.Tensor,
        residual: Optional[torch.Tensor],
        norm_weight: Optional[torch.Tensor],
        scale: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        workspace: Optional[torch.Tensor],
        group: List[int],
        strategy: int,
        op: int,
        eps: float,
        trigger_completion_at_end: bool,
    )

Source from the content-addressed store, hash-verified

15
@torch.library.register_fake("trtllm::allreduce")
def allreduce(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """Fake implementation of ``trtllm::allreduce`` for shape/dtype tracing.

    This is registered through ``torch.library.register_fake``. It performs
    no communication or arithmetic. It only allocates uninitialized tensors
    that describe the outputs of the selected fusion ``op``.

    Only ``input`` (its shape, dtype and device) and ``op`` are read. The
    remaining arguments exist so the signature matches the operator's.

    Returns:
        Output placeholders whose layout depends on ``op``
        (compared against ``AllReduceFusionOp`` members):

        * ``NONE``, or any value not listed below -> ``[out]``
        * ``RESIDUAL_RMS_NORM`` -> ``[norm_out, residual_out]``
        * ``RESIDUAL_RMS_NORM_QUANT_FP8`` -> ``[quant_out, residual_out]``
        * ``RESIDUAL_RMS_NORM_OUT_QUANT_FP8`` ->
          ``[norm_out, quant_out, residual_out]``
        * ``RESIDUAL_RMS_NORM_QUANT_NVFP4`` ->
          ``[quant_fp4, scale_fp4, residual_out]``
        * ``RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4`` ->
          ``[norm_out, quant_fp4, scale_fp4, residual_out]``

        The tensors have these properties:

        * ``out``, ``norm_out`` and ``residual_out`` mirror ``input``.
        * ``quant_out`` has ``input``'s shape with dtype ``float8_e4m3fn``.
        * The NVFP4 payload and scale tensors are ``uint8``, with shapes
          taken from ``fp4_utils.get_fp4_shape``.
    """
    # Deferred import (presumably to avoid a circular import at load time).
    from tensorrt_llm.functional import AllReduceFusionOp

    def _like_input() -> torch.Tensor:
        # Same shape, dtype and device as ``input``.
        return torch.empty_like(input)

    def _fp8_like_input() -> torch.Tensor:
        # Same shape as ``input`` but stored as FP8 (e4m3).
        return torch.empty_like(input, dtype=torch.float8_e4m3fn)

    def _nvfp4_outputs() -> List[torch.Tensor]:
        # Packed FP4 payload followed by its scale factors, both uint8.
        # NOTE(review): 16 is presumably the NVFP4 scaling block size.
        payload_shape, sf_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        return [
            input.new_empty(payload_shape, dtype=torch.uint8),
            input.new_empty(sf_shape, dtype=torch.uint8),
        ]

    if op == int(AllReduceFusionOp.NONE):
        return [_like_input()]

    if op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
        return [_like_input(), _like_input()]

    if op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
        return [_fp8_like_input(), _like_input()]

    if op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
        return [_like_input(), _fp8_like_input(), _like_input()]

    if op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
        return [*_nvfp4_outputs(), _like_input()]

    if op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
        quant_fp4, scale_fp4 = _nvfp4_outputs()
        return [_like_input(), quant_fp4, scale_fp4, _like_input()]

    # Any other fusion op falls back to the single-output layout.
    return [_like_input()]
61
62 @torch.library.register_fake("trtllm::allreduce_pg")
63 def _(

Callers (14; 12 shown)

_ · Function · 0.70
_update_statistic · Method · 0.50
forward · Method · 0.50
forward · Method · 0.50
forward · Method · 0.50
forward · Method · 0.50
forward · Method · 0.50
forward · Method · 0.50
collect_and_bias · Method · 0.50
forward_allreduce · Method · 0.50
forward · Method · 0.50
forward · Method · 0.50

Calls

no outgoing calls

Tested by

no test coverage detected