github.com/NVIDIA/TensorRT-LLM

Method forward

tensorrt_llm/_torch/distributed/ops.py, lines 752–870

The input tensors on all ranks must have the same shape. The output tensor has the same shape as the input tensor and is replicated across the TP group. Note that, unlike torch.distributed.all_reduce, this is not an in-place operation.
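The contract above (same shape on every rank, replicated result, inputs left untouched) can be illustrated with a single-process simulation. This is a minimal sketch, assuming a sum reduction and representing each rank's tensor as a flat Python list; `simulated_all_reduce` is a hypothetical name and the code does not use NCCL or torch.distributed.

```python
from typing import List, Sequence


def simulated_all_reduce(per_rank_inputs: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical single-process stand-in for a sum all-reduce over a TP group.

    per_rank_inputs[r] is the flattened input tensor held by rank r.
    Returns one output per rank; every output holds the element-wise sum.
    """
    if not per_rank_inputs:
        raise ValueError("need at least one rank")
    length = len(per_rank_inputs[0])
    # Every rank must contribute a tensor of the same shape.
    if any(len(t) != length for t in per_rank_inputs):
        raise ValueError("input tensors must have the same shape on every rank")
    # Element-wise sum across ranks; the inputs are only read, never written.
    reduced = [sum(values) for values in zip(*per_rank_inputs)]
    # The result is replicated: each rank receives its own copy.
    return [list(reduced) for _ in per_rank_inputs]


inputs = [[1.0, 2.0], [3.0, 4.0]]  # two TP ranks
outputs = simulated_all_reduce(inputs)
# outputs == [[4.0, 6.0], [4.0, 6.0]] and inputs is unchanged
```

Returning fresh lists rather than mutating `per_rank_inputs` mirrors the "not in-place" note: the caller gets a new tensor and its own input is left as it was.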

forward(
    self,
    input: torch.Tensor,
    *,
    all_reduce_params: Optional[AllReduceParams] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]
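The method's docstring lists a different output layout for each `fusion_op`. As a hedged sketch, that table can be transcribed into a lookup that attaches names to the returned tensors. `FUSION_OUTPUTS` and `name_outputs` are hypothetical helpers, not TensorRT-LLM APIs, and fusion ops are keyed by their member-name strings rather than the real `AllReduceFusionOp` enum. Because the return annotation is `Union[torch.Tensor, Tuple[torch.Tensor, ...]]`, the helper assumes that a bare (non-sequence) result corresponds to a single output.

```python
from typing import Dict, Tuple

# Output order per fusion op, transcribed from the forward() docstring.
# Hypothetical helper: keys are AllReduceFusionOp member names as strings.
FUSION_OUTPUTS: Dict[str, Tuple[str, ...]] = {
    "NONE": ("hidden_states",),
    "RESIDUAL_RMS_NORM": ("hidden_states", "residual"),
    "RESIDUAL_RMS_NORM_QUANT_FP8": ("norm_quant", "residual"),
    "RESIDUAL_RMS_NORM_OUT_QUANT_FP8": ("norm", "norm_quant", "residual"),
    "RESIDUAL_RMS_NORM_QUANT_NVFP4": ("norm_quant_fp4", "scale_factor", "residual"),
    "RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4": ("norm", "norm_quant_fp4", "scale_factor", "residual"),
}


def name_outputs(fusion_op: str, outputs) -> Dict[str, object]:
    """Map forward()'s outputs to the names documented for fusion_op."""
    names = FUSION_OUTPUTS[fusion_op]
    # Assumption: a bare result (not a tuple or list) is a single output.
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)
    if len(outputs) != len(names):
        raise ValueError(
            f"{fusion_op} expects {len(names)} outputs, got {len(outputs)}")
    return dict(zip(names, outputs))
```

For example, `name_outputs("RESIDUAL_RMS_NORM", (h, r))` returns `{"hidden_states": h, "residual": r}`, which makes it harder to swap the normalized output and the residual by position.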


```python
def forward(
    self,
    input: torch.Tensor,
    *,
    all_reduce_params: Optional[AllReduceParams] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    '''
    The input tensors on the different ranks must have the same shape.
    The output tensor will have the same shape as the input tensor.
    The output tensor will be replicated among the TP group.
    Note that it is not an in-place operation like torch.distributed.all_reduce.

    This operation is implemented using a torch op that wraps the NCCL all-reduce
    collective operation and custom one-shot/two-shot allreduce kernels. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce
    for details.

    Args:
        input (Tensor): The input tensor.
        all_reduce_params (AllReduceParams): The parameters for the ops fused into the allreduce op.
    Returns:
        A list of tensors whose contents depend on fusion_op:
        NONE: [hidden_states]
        RESIDUAL_RMS_NORM: [hidden_states, residual]
        RESIDUAL_RMS_NORM_QUANT_FP8: [norm_quant, residual]
        RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual]
        RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual]
        RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual]
    '''
    if self.mapping.tp_size == 1 or (all_reduce_params is not None
                                     and all_reduce_params.enable_allreduce
                                     == False):
        return input

    input = input.contiguous()  # Underlying op requires contiguous input

    allreduce_strategy = self.strategy

    if all_reduce_params is None:
        all_reduce_params = AllReduceParams()

    # Try Symmetric Memory AllReduce first if available
    # Note: Currently only supports NONE fusion op (plain allreduce)
    if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE:
        symm_mem_output = self.symm_mem_allreduce(input)
        if symm_mem_output is not None:
            logger.debug(
                f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}"
            )
            return symm_mem_output
    elif self.symm_mem_allreduce and all_reduce_params.fusion_op != AllReduceFusionOp.NONE:
        # Log once per rank that we're skipping symm_mem due to fusion
        logger.debug_once(
            f"Skipping SymmetricMemoryAllReduce for fused operation (fusion_op={all_reduce_params.fusion_op}), using regular allreduce",
            key=(self.mapping.tp_rank, all_reduce_params.fusion_op,
                 "debug_fusion_skip"),
        )

    # ... (remainder of the method, through line 870, not shown)
```

Calls (4):
- AllReduceParams (class)
- debug_once (method)
- debug (method)
- get (method)