(self, hidden_states, lora_layer_params=None)
| 316 | self.fused_gate_up_dora = None |
| 317 | |
| 318 | def fc_gate_plugin(self, hidden_states, lora_layer_params=None): |
| 319 | # Combine the following pattern |
| 320 | # |
| 321 | # SiLU(FC(x)) * Gate(x) |
| 322 | # |
| 323 | # into: |
| 324 | # |
| 325 | # SwiGLU(FusedFC(x)) |
| 326 | if default_net( |
| 327 | ).plugin_config.low_latency_gemm_swiglu_plugin is not None: |
| 328 | p_dtype = default_net().plugin_config.low_latency_gemm_swiglu_plugin |
| 329 | else: |
| 330 | p_dtype = default_net().plugin_config.gemm_swiglu_plugin |
| 331 | use_fp8 = p_dtype == 'fp8' |
| 332 | assert use_fp8, "gemm_swiglu_plugin and low_latency_gemm_swiglu_plugin only supports fp8 now" |
| 333 | |
| 334 | if lora_layer_params is not None: |
| 335 | mlp_fc_lora_params = lora_layer_params.get_runtime_params( |
| 336 | 0, "mlp_h_to_4h") |
| 337 | mlp_gate_lora_params = lora_layer_params.get_runtime_params( |
| 338 | 0, "mlp_gate") |
| 339 | |
| 340 | if mlp_fc_lora_params is not None or mlp_gate_lora_params is not None: |
| 341 | raise NotImplementedError( |
| 342 | f"LoRA of splitting fc and gate is not yet implemented for gemm_swiglu_plugin" |
| 343 | ) |
| 344 | |
| 345 | if self.hidden_act != 'silu': |
| 346 | raise NotImplementedError( |
| 347 | f"Activation {self.hidden_act} not yet implemented for gemm_swiglu_plugin" |
| 348 | ) |
| 349 | |
| 350 | if self.bias: |
| 351 | raise NotImplementedError( |
| 352 | f"bias not yet implemented for gemm_swiglu_plugin fp8") |
| 353 | |
| 354 | assert isinstance( |
| 355 | self.fused_fc, |
| 356 | FP8Linear), "fp8 gemm_swiglu only supports fp8 weights" |
| 357 | assert isinstance( |
| 358 | self.proj, |
| 359 | FP8RowLinear), "fp8 gemm_swiglu only supports fp8 weights" |
| 360 | assert self.fused_fc.weight.shape == ( |
| 361 | self.hidden_size, self.ffn_hidden_size * 2 // |
| 362 | self.tp_size), "fp8 gemm_swiglu only supports (k, n) weights" |
| 363 | |
| 364 | scale_d0 = (self.fused_fc.weights_scaling_factor.raw_value.item() * |
| 365 | self.fused_fc.activation_scaling_factor.raw_value.item()) |
| 366 | scale_d1 = scale_d0 |
| 367 | scale_output = 1.0 / self.proj.activation_scaling_factor.raw_value.item( |
| 368 | ) |
| 369 | activation_scaling_factor = cast( |
| 370 | self.fused_fc.activation_scaling_factor.value, self.dtype) |
| 371 | if hidden_states.dtype != trt.fp8: |
| 372 | hidden_states = quantize(hidden_states, activation_scaling_factor, |
| 373 | 'fp8') |
| 374 | |
| 375 | if default_net( |
no test coverage detected