MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / AllReduceParams

Class AllReduceParams

tensorrt_llm/functional.py:3899–3938  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

3897
3898
class AllReduceParams():
    """Settings and optional tensors that describe a single all-reduce call.

    Every constructor argument is stored on the instance under the same
    name. Any fusion op other than ``AllReduceFusionOp.NONE`` requires a
    ``residual`` tensor; the constructor asserts this after the attributes
    have been assigned.
    """

    def __init__(self,
                 strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
                 fusion_op: AllReduceFusionOp = AllReduceFusionOp.NONE,
                 bias: Optional[Tensor] = None,
                 residual: Optional[Tensor] = None,
                 norm_weight: Optional[Tensor] = None,
                 scale: Optional[Tensor] = None,
                 norm_pre_residual_weight: Optional[Tensor] = None,
                 eps: float = 1e-06,
                 enable_allreduce: bool = True,
                 trigger_completion_at_end: bool = True):
        """Record the all-reduce configuration on the instance.

        Args:
            strategy: All-reduce strategy; ``AUTO`` may later be replaced
                by ``update_strategy``.
            fusion_op: Fusion op to apply; anything other than ``NONE``
                requires ``residual``.
            bias: Optional tensor; its presence is reported by ``has_bias``.
            residual: Optional tensor; mandatory when ``fusion_op`` is not
                ``NONE``.
            norm_weight: Optional tensor; its presence is reported by
                ``has_affine``.
            scale: Optional tensor; its presence is reported by
                ``has_scale``.
            norm_pre_residual_weight: Optional tensor, stored as given.
            eps: Float stored as given (presumably the normalization
                epsilon -- confirm against the consumers of this class).
            enable_allreduce: For torch path only, has no effect on TRT
                path.
            trigger_completion_at_end: Flag stored as given.

        Raises:
            AssertionError: If ``fusion_op`` is not ``NONE`` and
                ``residual`` is ``None``.
        """
        # What to run.
        self.strategy, self.fusion_op = strategy, fusion_op

        # Optional tensor operands.
        self.bias, self.residual = bias, residual
        self.norm_weight, self.scale = norm_weight, scale
        self.norm_pre_residual_weight = norm_pre_residual_weight

        self.eps = eps
        # For torch path only, has no effect on TRT path
        self.enable_allreduce = enable_allreduce
        self.trigger_completion_at_end = trigger_completion_at_end

        # A residual tensor is mandatory unless no fusion op was requested.
        assert (residual is not None
                ) or fusion_op == AllReduceFusionOp.NONE.value

    def has_affine(self):
        """Return 1 when ``norm_weight`` is set, otherwise 0."""
        return int(self.norm_weight is not None)

    def has_bias(self):
        """Return 1 when ``bias`` is set, otherwise 0."""
        return int(self.bias is not None)

    def has_scale(self):
        """Return 1 when ``scale`` is set, otherwise 0."""
        return int(self.scale is not None)

    def update_strategy(self):
        """Replace an ``AUTO`` strategy with ``UB`` when user buffers are on.

        The default network's plugin config is consulted only when the
        current strategy is ``AUTO``; any other strategy is left untouched.
        """
        if self.strategy == AllReduceStrategy.AUTO:
            if default_net().plugin_config.user_buffer:
                self.strategy = AllReduceStrategy.UB
3939
3940
3941class MoEAllReduceParams(AllReduceParams):

Callers 15

profile_allreduceFunction · 0.90
test_allreduceMethod · 0.90
test_allreduceMethod · 0.90
test_allreduceMethod · 0.90
forwardMethod · 0.90
forwardMethod · 0.90
funcFunction · 0.85
calc_fused_allreduceFunction · 0.85

Calls

no outgoing calls

Tested by 10

test_allreduceMethod · 0.72
test_allreduceMethod · 0.72
test_allreduceMethod · 0.72
funcFunction · 0.68
calc_fused_allreduceFunction · 0.68
run_moe_allreduce_opFunction · 0.68