MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / forward_experts

Method forward_experts

tensorrt_llm/layers/moe.py:1038–1164  ·  view source on GitHub ↗
(self, hidden_states, token_selected_experts,
                        token_final_scales, lora_layer_params, side_stream_id)

Source from the content-addressed store, hash-verified

1036 return output
1037
1038 def forward_experts(self, hidden_states, token_selected_experts,
1039 token_final_scales, lora_layer_params, side_stream_id):
1040
1041 groupwise_quant_params = MoeGroupwiseQuantParams()
1042 if self.quant_mode.has_fp8_qdq():
1043 assert self.fc.weight.value.dtype == trt.fp8, (
1044 "mlp fc weight dtype should be fp8 in the fp8 quantization mode."
1045 )
1046 assert self.proj.weight.value.dtype == trt.fp8, (
1047 "mlp proj weight dtype should be fp8 in the fp8 quantization mode."
1048 )
1049 hidden_states_quant = hidden_states
1050 if hidden_states_quant.dtype != trt.fp8:
1051 hidden_states_quant = quantize(
1052 hidden_states, self.fc.activation_scaling_factor.value,
1053 'fp8')
1054
1055 dtype_quant = trt.fp8
1056 weight_dtype_quant = trt.fp8
1057
1058 fc1_dequant = self.fc.weights_scaling_factor.value * self.fc.activation_scaling_factor.value
1059 fc2_quant = div(1.0, self.proj.activation_scaling_factor.value)
1060 fc2_dequant = self.proj.weights_scaling_factor.value * self.proj.activation_scaling_factor.value
1061 fc1_act_dequant = self.fc.activation_scaling_factor.value
1062
1063 scale_1 = fc1_dequant
1064 scale_2 = fc2_quant
1065 scale_3 = fc2_dequant
1066 scale_4 = None
1067 scale_5 = fc1_act_dequant
1068 scale_6 = None
1069
1070 output_dtype_quant = self.dtype
1071
1072 if output_dtype_quant == trt.fp8 and scale_4 is None:
1073 raise RuntimeError(
1074 "Cannot output FP8 value without knowing quantization parameter"
1075 )
1076 elif self.quant_mode.has_nvfp4():
1077 # We pass through the weights unchanged, the quantization is done in the plugin
1078 hidden_states_quant = hidden_states
1079 dtype_quant = trt.fp4
1080 weight_dtype_quant = trt.fp4
1081 output_dtype_quant = self.dtype
1082
1083 scale_1 = div(1.0, self.fc.activation_global_scaling_factor.value)
1084 scale_2 = self.fc.weights_block_scaling_factor_interleaved
1085 scale_3 = self.fc.alpha
1086 scale_4 = div(1.0, self.proj.activation_global_scaling_factor.value)
1087 scale_5 = self.proj.weights_block_scaling_factor_interleaved
1088 scale_6 = self.proj.alpha
1089 elif self.quant_mode.has_per_group_scaling():
1090 hidden_states_quant = hidden_states
1091 dtype_quant = trt.fp8 if self.use_w4a8_awq else self.dtype
1092 weight_dtype_quant = self.weight_dtype
1093 output_dtype_quant = self.dtype
1094
1095 scale_1 = None

Callers 1

forward (Method) · 0.95

Calls 6

quantize (Function) · 0.90
_moe_plugin (Function) · 0.85
has_per_group_scaling (Method) · 0.80
has_fp8_qdq (Method) · 0.45
has_nvfp4 (Method) · 0.45

Tested by

no test coverage detected