MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / fuse_gate_mlp

Function fuse_gate_mlp

tensorrt_llm/models/modeling_utils.py:1153–1299  ·  view source on GitHub ↗
(
    model: PretrainedModel,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
)

Source from the content-addressed store, hash-verified

1151
1152
1153def fuse_gate_mlp(
1154 model: PretrainedModel,
1155 gemm_swiglu_plugin_dtype: Optional[str] = None,
1156 low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
1157) -> PretrainedModel:
1158 from ..quantization.quantize import fp8_quantize
1159
1160 for name, mlp, layer in model.named_modules_with_parent():
1161 if isinstance(mlp, GatedMLP):
1162 init_params = get_init_params(mlp)
1163
1164 hidden_act = init_params["hidden_act"]
1165 if hidden_act not in ["silu", "gelu"]:
1166 logger.warning(
1167 f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping."
1168 )
1169 continue
1170
1171 init_params["inner_layernorm"] = mlp.inner_layernorm is not None
1172 fused_layer = FusedGatedMLP(**init_params)
1173
1174 fc_name = name + '.fc'
1175 layer_quant_cfg = model.config._get_quant_cfg(fc_name)
1176 layer_quant_algo = layer_quant_cfg.quant_algo
1177 if layer_quant_algo != QuantAlgo.FP8 and layer_quant_algo is not None:
1178 continue
1179
1180 if isinstance(model.config.quantization.exclude_modules, list) \
1181 and fc_name in model.config.quantization.exclude_modules:
1182 layer_quant_algo = None
1183
1184 if layer_quant_algo == QuantAlgo.FP8:
1185 fused_layer = fp8_quantize(fused_layer, layer_quant_cfg)
1186
1187 if isinstance(mlp.dtype, str):
1188 dtype = str_dtype_to_torch(mlp.dtype)
1189 else:
1190 dtype = trt_dtype_to_torch(mlp.dtype)
1191
1192 gate_weight = numpy_to_torch(mlp.gate.weight.raw_value)
1193 fc_weight = numpy_to_torch(mlp.fc.weight.raw_value)
1194 assert gate_weight.dtype == fc_weight.dtype
1195 need_qdq = gate_weight.dtype == torch.float8_e4m3fn
1196
1197 gate_weight = gate_weight.to(dtype)
1198 fc_weight = fc_weight.to(dtype)
1199 # dequantize if needed
1200 if need_qdq:
1201 gate_weight = gate_weight.to(dtype) * numpy_to_torch(
1202 mlp.gate.weights_scaling_factor.raw_value)
1203 fc_weight = fc_weight.to(dtype) * numpy_to_torch(
1204 mlp.fc.weights_scaling_factor.raw_value)
1205
1206 # concat
1207 fused_weight = torch.cat([gate_weight, fc_weight], dim=0)
1208
1209 fused_weight_scaling_factor = numpy_to_torch(
1210 max(

Callers 1

optimize_model · Function · 0.85

Calls 14

get_init_params · Function · 0.85
FusedGatedMLP · Class · 0.85
fp8_quantize · Function · 0.85
str_dtype_to_torch · Function · 0.85
numpy_to_torch · Function · 0.85
max · Function · 0.85
Parameter · Class · 0.85
trt_dtype_to_torch · Function · 0.50
warning · Method · 0.45
_get_quant_cfg · Method · 0.45

Tested by

no test coverage detected