(
model: PretrainedModel,
gemm_swiglu_plugin_dtype: Optional[str] = None,
low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
)
| 1151 | |
| 1152 | |
| 1153 | def fuse_gate_mlp( |
| 1154 | model: PretrainedModel, |
| 1155 | gemm_swiglu_plugin_dtype: Optional[str] = None, |
| 1156 | low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None, |
| 1157 | ) -> PretrainedModel: |
| 1158 | from ..quantization.quantize import fp8_quantize |
| 1159 | |
| 1160 | for name, mlp, layer in model.named_modules_with_parent(): |
| 1161 | if isinstance(mlp, GatedMLP): |
| 1162 | init_params = get_init_params(mlp) |
| 1163 | |
| 1164 | hidden_act = init_params["hidden_act"] |
| 1165 | if hidden_act not in ["silu", "gelu"]: |
| 1166 | logger.warning( |
| 1167 | f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping." |
| 1168 | ) |
| 1169 | continue |
| 1170 | |
| 1171 | init_params["inner_layernorm"] = mlp.inner_layernorm is not None |
| 1172 | fused_layer = FusedGatedMLP(**init_params) |
| 1173 | |
| 1174 | fc_name = name + '.fc' |
| 1175 | layer_quant_cfg = model.config._get_quant_cfg(fc_name) |
| 1176 | layer_quant_algo = layer_quant_cfg.quant_algo |
| 1177 | if layer_quant_algo != QuantAlgo.FP8 and layer_quant_algo is not None: |
| 1178 | continue |
| 1179 | |
| 1180 | if isinstance(model.config.quantization.exclude_modules, list) \ |
| 1181 | and fc_name in model.config.quantization.exclude_modules: |
| 1182 | layer_quant_algo = None |
| 1183 | |
| 1184 | if layer_quant_algo == QuantAlgo.FP8: |
| 1185 | fused_layer = fp8_quantize(fused_layer, layer_quant_cfg) |
| 1186 | |
| 1187 | if isinstance(mlp.dtype, str): |
| 1188 | dtype = str_dtype_to_torch(mlp.dtype) |
| 1189 | else: |
| 1190 | dtype = trt_dtype_to_torch(mlp.dtype) |
| 1191 | |
| 1192 | gate_weight = numpy_to_torch(mlp.gate.weight.raw_value) |
| 1193 | fc_weight = numpy_to_torch(mlp.fc.weight.raw_value) |
| 1194 | assert gate_weight.dtype == fc_weight.dtype |
| 1195 | need_qdq = gate_weight.dtype == torch.float8_e4m3fn |
| 1196 | |
| 1197 | gate_weight = gate_weight.to(dtype) |
| 1198 | fc_weight = fc_weight.to(dtype) |
| 1199 | # dequantize if needed |
| 1200 | if need_qdq: |
| 1201 | gate_weight = gate_weight.to(dtype) * numpy_to_torch( |
| 1202 | mlp.gate.weights_scaling_factor.raw_value) |
| 1203 | fc_weight = fc_weight.to(dtype) * numpy_to_torch( |
| 1204 | mlp.fc.weights_scaling_factor.raw_value) |
| 1205 | |
| 1206 | # concat |
| 1207 | fused_weight = torch.cat([gate_weight, fc_weight], dim=0) |
| 1208 | |
| 1209 | fused_weight_scaling_factor = numpy_to_torch( |
| 1210 | max( |
no test coverage detected