The input tensors on the different ranks must have the same shape. The output tensor has the same shape as the input tensor and is replicated across the TP group. Note that, unlike torch.distributed.all_reduce, this is not an in-place operation.
(
self,
input: torch.Tensor,
*,
all_reduce_params: Optional[AllReduceParams] = None,
)
| 750 | self.mnnvl_allreduce = None |
| 751 | |
| 752 | def forward( |
| 753 | self, |
| 754 | input: torch.Tensor, |
| 755 | *, |
| 756 | all_reduce_params: Optional[AllReduceParams] = None, |
| 757 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: |
| 758 | ''' |
| 759 | The input tensors in the different ranks must have the same shape. |
| 760 | The output tensor will have the same shape as the input tensor. |
| 761 | The output tensor will be replicated among the TP group. |
| 762 | Note that it is not an in-place operation like torch.distributed.all_reduce. |
| 763 | |
| 764 | That operation is implemented using a torch op that wraps the NCCL all-reduce |
| 765 | collective operation and custom one-shot/two-shot allreduce kernels. See |
| 766 | https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce |
| 767 | for details. |
| 768 | |
| 769 | Args: |
| 770 | input (Tensor): The input tensor. |
| 771 | all_reduce_params (AllReduceParams): The parameters for the fused ops into the allreduce op. |
| 772 | Returns: |
| 773 | A list of tensors whose outputs depend on the fusion_op. |
| 774 | NONE: [hidden_states] |
| 775 | RESIDUAL_RMS_NORM: [hidden_states, residual] |
| 776 | RESIDUAL_RMS_NORM_QUANT_FP8: [norm_quant, residual] |
| 777 | RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual] |
| 778 | RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual] |
| 779 | RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual] |
| 780 | ''' |
| 781 | if self.mapping.tp_size == 1 or (all_reduce_params is not None |
| 782 | and all_reduce_params.enable_allreduce |
| 783 | == False): |
| 784 | return input |
| 785 | |
| 786 | input = input.contiguous() # Underlying op requires contiguous input |
| 787 | |
| 788 | allreduce_strategy = self.strategy |
| 789 | |
| 790 | if all_reduce_params is None: |
| 791 | all_reduce_params = AllReduceParams() |
| 792 | |
| 793 | # Try Symmetric Memory AllReduce first if available |
| 794 | # Note: Currently only supports NONE fusion op (plain allreduce) |
| 795 | if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE: |
| 796 | symm_mem_output = self.symm_mem_allreduce(input) |
| 797 | if symm_mem_output is not None: |
| 798 | logger.debug( |
| 799 | f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}" |
| 800 | ) |
| 801 | return symm_mem_output |
| 802 | elif self.symm_mem_allreduce and all_reduce_params.fusion_op != AllReduceFusionOp.NONE: |
| 803 | # Log once per rank that we're skipping symm_mem due to fusion |
| 804 | logger.debug_once( |
| 805 | f"Skipping SymmetricMemoryAllReduce for fused operation (fusion_op={all_reduce_params.fusion_op}), using regular allreduce", |
| 806 | key=(self.mapping.tp_rank, all_reduce_params.fusion_op, |
| 807 | "debug_fusion_skip"), |
| 808 | ) |
| 809 |