MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / fc_gate_plugin

Method fc_gate_plugin

tensorrt_llm/layers/mlp.py:318–389  ·  view source on GitHub ↗
(self, hidden_states, lora_layer_params=None)

Source retrieved from the content-addressed store (hash-verified).

316 self.fused_gate_up_dora = None
317
318 def fc_gate_plugin(self, hidden_states, lora_layer_params=None):
319 # Combine the following pattern
320 #
321 # SiLU(FC(x)) * Gate(x)
322 #
323 # into:
324 #
325 # SwiGLU(FusedFC(x))
326 if default_net(
327 ).plugin_config.low_latency_gemm_swiglu_plugin is not None:
328 p_dtype = default_net().plugin_config.low_latency_gemm_swiglu_plugin
329 else:
330 p_dtype = default_net().plugin_config.gemm_swiglu_plugin
331 use_fp8 = p_dtype == 'fp8'
332 assert use_fp8, "gemm_swiglu_plugin and low_latency_gemm_swiglu_plugin only supports fp8 now"
333
334 if lora_layer_params is not None:
335 mlp_fc_lora_params = lora_layer_params.get_runtime_params(
336 0, "mlp_h_to_4h")
337 mlp_gate_lora_params = lora_layer_params.get_runtime_params(
338 0, "mlp_gate")
339
340 if mlp_fc_lora_params is not None or mlp_gate_lora_params is not None:
341 raise NotImplementedError(
342 f"LoRA of splitting fc and gate is not yet implemented for gemm_swiglu_plugin"
343 )
344
345 if self.hidden_act != 'silu':
346 raise NotImplementedError(
347 f"Activation {self.hidden_act} not yet implemented for gemm_swiglu_plugin"
348 )
349
350 if self.bias:
351 raise NotImplementedError(
352 f"bias not yet implemented for gemm_swiglu_plugin fp8")
353
354 assert isinstance(
355 self.fused_fc,
356 FP8Linear), "fp8 gemm_swiglu only supports fp8 weights"
357 assert isinstance(
358 self.proj,
359 FP8RowLinear), "fp8 gemm_swiglu only supports fp8 weights"
360 assert self.fused_fc.weight.shape == (
361 self.hidden_size, self.ffn_hidden_size * 2 //
362 self.tp_size), "fp8 gemm_swiglu only supports (k, n) weights"
363
364 scale_d0 = (self.fused_fc.weights_scaling_factor.raw_value.item() *
365 self.fused_fc.activation_scaling_factor.raw_value.item())
366 scale_d1 = scale_d0
367 scale_output = 1.0 / self.proj.activation_scaling_factor.raw_value.item(
368 )
369 activation_scaling_factor = cast(
370 self.fused_fc.activation_scaling_factor.value, self.dtype)
371 if hidden_states.dtype != trt.fp8:
372 hidden_states = quantize(hidden_states, activation_scaling_factor,
373 'fp8')
374
375 if default_net(

Callers (1)

forward — Method · 0.95

Calls (7)

quantize — Function · 0.90
default_net — Function · 0.85
cast — Function · 0.85
low_latency_gemm_swiglu — Function · 0.85
gemm_swiglu — Function · 0.85
fc_gate_lora — Function · 0.85
get_runtime_params — Method · 0.80

Tested by

no test coverage detected