(
input: torch.Tensor,
residual: Optional[torch.Tensor],
norm_weight: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
workspace: Optional[torch.Tensor],
group: List[int],
strategy: int,
op: int,
eps: float,
trigger_completion_at_end: bool,
)
| 15 | |
@torch.library.register_fake("trtllm::allreduce")
def allreduce(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """Fake implementation registered for the ``trtllm::allreduce`` op.

    Only ``input`` and ``op`` are read here; the remaining parameters are
    accepted but not used by this function. No values are computed: every
    returned tensor is an uninitialized placeholder.

    The returned list depends on ``op`` (compared against the integer
    values of ``AllReduceFusionOp``):

    * ``NONE``: ``[out]``
    * ``RESIDUAL_RMS_NORM``: ``[norm, residual]``
    * ``RESIDUAL_RMS_NORM_QUANT_FP8``: ``[quant_fp8, residual]``
    * ``RESIDUAL_RMS_NORM_OUT_QUANT_FP8``: ``[norm, quant_fp8, residual]``
    * ``RESIDUAL_RMS_NORM_QUANT_NVFP4``: ``[quant_fp4, scale_fp4, residual]``
    * ``RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4``:
      ``[norm, quant_fp4, scale_fp4, residual]``
    * any other value: ``[out]``

    ``out``, ``norm`` and ``residual`` are created with
    ``torch.empty_like(input)``; ``quant_fp8`` is the same with dtype
    ``torch.float8_e4m3fn``; ``quant_fp4`` and ``scale_fp4`` are ``uint8``
    tensors whose shapes come from ``fp4_utils.get_fp4_shape``.
    """
    # Resolved at call time rather than at module import time.
    from tensorrt_llm.functional import AllReduceFusionOp as FusionOp

    def _empty_fp4_pair():
        # 16 is forwarded as the second argument of get_fp4_shape
        # (presumably the scale-factor vector size; TODO confirm).
        packed_shape, sf_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        packed = input.new_empty(packed_shape, dtype=torch.uint8)
        scales = input.new_empty(sf_shape, dtype=torch.uint8)
        return packed, scales

    if op == int(FusionOp.NONE):
        return [torch.empty_like(input)]

    if op == int(FusionOp.RESIDUAL_RMS_NORM):
        normed = torch.empty_like(input)
        residual_result = torch.empty_like(input)
        return [normed, residual_result]

    if op == int(FusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
        fp8_result = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_result = torch.empty_like(input)
        return [fp8_result, residual_result]

    if op == int(FusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
        normed = torch.empty_like(input)
        fp8_result = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_result = torch.empty_like(input)
        return [normed, fp8_result, residual_result]

    if op == int(FusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
        fp4_result, fp4_scales = _empty_fp4_pair()
        residual_result = torch.empty_like(input)
        return [fp4_result, fp4_scales, residual_result]

    if op == int(FusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
        # FP4 buffers are allocated first, but the normalized output leads
        # the returned list.
        fp4_result, fp4_scales = _empty_fp4_pair()
        normed = torch.empty_like(input)
        residual_result = torch.empty_like(input)
        return [normed, fp4_result, fp4_scales, residual_result]

    # Any other op value yields a single tensor shaped like the input.
    return [torch.empty_like(input)]
| 61 | |
| 62 | @torch.library.register_fake("trtllm::allreduce_pg") |
| 63 | def _( |
no outgoing calls
no test coverage detected