(self, hidden_states, token_selected_experts,
token_final_scales, lora_layer_params, side_stream_id)
| 1036 | return output |
| 1037 | |
| 1038 | def forward_experts(self, hidden_states, token_selected_experts, |
| 1039 | token_final_scales, lora_layer_params, side_stream_id): |
| 1040 | |
| 1041 | groupwise_quant_params = MoeGroupwiseQuantParams() |
| 1042 | if self.quant_mode.has_fp8_qdq(): |
| 1043 | assert self.fc.weight.value.dtype == trt.fp8, ( |
| 1044 | "mlp fc weight dtype should be fp8 in the fp8 quantization mode." |
| 1045 | ) |
| 1046 | assert self.proj.weight.value.dtype == trt.fp8, ( |
| 1047 | "mlp proj weight dtype should be fp8 in the fp8 quantization mode." |
| 1048 | ) |
| 1049 | hidden_states_quant = hidden_states |
| 1050 | if hidden_states_quant.dtype != trt.fp8: |
| 1051 | hidden_states_quant = quantize( |
| 1052 | hidden_states, self.fc.activation_scaling_factor.value, |
| 1053 | 'fp8') |
| 1054 | |
| 1055 | dtype_quant = trt.fp8 |
| 1056 | weight_dtype_quant = trt.fp8 |
| 1057 | |
| 1058 | fc1_dequant = self.fc.weights_scaling_factor.value * self.fc.activation_scaling_factor.value |
| 1059 | fc2_quant = div(1.0, self.proj.activation_scaling_factor.value) |
| 1060 | fc2_dequant = self.proj.weights_scaling_factor.value * self.proj.activation_scaling_factor.value |
| 1061 | fc1_act_dequant = self.fc.activation_scaling_factor.value |
| 1062 | |
| 1063 | scale_1 = fc1_dequant |
| 1064 | scale_2 = fc2_quant |
| 1065 | scale_3 = fc2_dequant |
| 1066 | scale_4 = None |
| 1067 | scale_5 = fc1_act_dequant |
| 1068 | scale_6 = None |
| 1069 | |
| 1070 | output_dtype_quant = self.dtype |
| 1071 | |
| 1072 | if output_dtype_quant == trt.fp8 and scale_4 is None: |
| 1073 | raise RuntimeError( |
| 1074 | "Cannot output FP8 value without knowing quantization parameter" |
| 1075 | ) |
| 1076 | elif self.quant_mode.has_nvfp4(): |
| 1077 | # We pass through the weights unchanged, the quantization is done in the plugin |
| 1078 | hidden_states_quant = hidden_states |
| 1079 | dtype_quant = trt.fp4 |
| 1080 | weight_dtype_quant = trt.fp4 |
| 1081 | output_dtype_quant = self.dtype |
| 1082 | |
| 1083 | scale_1 = div(1.0, self.fc.activation_global_scaling_factor.value) |
| 1084 | scale_2 = self.fc.weights_block_scaling_factor_interleaved |
| 1085 | scale_3 = self.fc.alpha |
| 1086 | scale_4 = div(1.0, self.proj.activation_global_scaling_factor.value) |
| 1087 | scale_5 = self.proj.weights_block_scaling_factor_interleaved |
| 1088 | scale_6 = self.proj.alpha |
| 1089 | elif self.quant_mode.has_per_group_scaling(): |
| 1090 | hidden_states_quant = hidden_states |
| 1091 | dtype_quant = trt.fp8 if self.use_w4a8_awq else self.dtype |
| 1092 | weight_dtype_quant = self.weight_dtype |
| 1093 | output_dtype_quant = self.dtype |
| 1094 | |
| 1095 | scale_1 = None |
no test coverage detected