Run optimization passes on the model. Because some passes depend on others, passes are always run in the order of the function's arguments, which guarantees a fixed execution order.
(
model: PretrainedModel,
use_parallel_embedding: bool = False,
share_embedding_table: bool = False,
use_ootb_moe: bool = False,
use_fused_mlp: bool = False,
gemm_swiglu_plugin_dtype: Optional[str] = None,
low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
use_fused_rg_lru: bool = False,
use_unfused_qkv_gemm: bool = False,
use_prompt_tuning: bool = False,
use_lora: bool = False,
max_lora_rank: Optional[int] = None,
use_fp8_context_fmha: bool = False,
fuse_fp4_quant: bool = False,
use_optimize_cross_qkv: bool = False,
use_dora: bool = False,
)
def optimize_model(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
    fuse_fp4_quant: bool = False,
    use_optimize_cross_qkv: bool = False,
    use_dora: bool = False,
) -> PretrainedModel:
    """Run optimization passes on ``model``.

    There are dependencies between some passes, so passes are always run in
    the order of the arguments below to guarantee the execution order. Each
    enabled pass receives the model returned by the previous one.

    Args:
        model: The model to transform.
        use_parallel_embedding: Apply ``parallelize_embedding``.
        share_embedding_table: Apply ``share_embedding`` so that
            ``lm_head.weight`` and ``vocab_embedding.weight`` point to the
            same tensor.
        use_ootb_moe: Apply ``to_ootb_moe``.
        use_fused_mlp: Apply ``fuse_gate_mlp``.
        gemm_swiglu_plugin_dtype: Forwarded to ``fuse_gate_mlp``; only used
            when ``use_fused_mlp`` is set.
        low_latency_gemm_swiglu_plugin_dtype: Forwarded to ``fuse_gate_mlp``;
            only used when ``use_fused_mlp`` is set.
        use_fused_rg_lru: Apply ``fuse_rg_lru``.
        use_unfused_qkv_gemm: Apply ``unfuse_qkv_gemm``.
        use_prompt_tuning: Apply ``set_prompt_tuning``.
        use_lora: Apply ``add_lora``. When set, the ``optimize_cross_qkv``
            pass is skipped regardless of ``use_optimize_cross_qkv``.
        max_lora_rank: Forwarded to ``add_lora``; only used when ``use_lora``
            is set.
        use_fp8_context_fmha: Apply ``set_fp8_context_fhma``.
        fuse_fp4_quant: Apply ``set_fuse_fp4_quant``.
        use_optimize_cross_qkv: Apply ``optimize_cross_qkv``. Ignored when
            ``use_lora`` is set, because that combination is not supported.
        use_dora: Forwarded to ``add_lora`` as ``with_dora``; ignored when
            ``use_lora`` is not set.

    Returns:
        The model returned by the last enabled pass, or ``model`` itself when
        no pass is enabled.
    """
    # before weight loading
    if use_parallel_embedding:
        model = parallelize_embedding(model)

    if share_embedding_table:
        # If share_embedding_table is enabled, only one copy of the embedding
        # table is stored in the converted checkpoint. This pass is required
        # to make lm_head.weight and vocab_embedding.weight point to the same
        # tensor. However, even if share_embedding_table is not enabled, TRT
        # would still only keep one copy of the table if the weights are
        # identical.
        model = share_embedding(model)

    # After weight loading
    if use_ootb_moe:
        model = to_ootb_moe(model)
    if use_fused_mlp:
        model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype,
                              low_latency_gemm_swiglu_plugin_dtype)
    if use_fused_rg_lru:
        model = fuse_rg_lru(model)
    if use_unfused_qkv_gemm:
        model = unfuse_qkv_gemm(model)
    if use_prompt_tuning:
        model = set_prompt_tuning(model)
    if use_lora:
        model = add_lora(model, max_lora_rank, with_dora=use_dora)
    if use_fp8_context_fmha:
        # NOTE: the helper is spelled "fhma" where it is defined; keep the
        # call name in sync with that definition.
        model = set_fp8_context_fhma(model)
    if fuse_fp4_quant:
        model = set_fuse_fp4_quant(model)
    # Test the flag by truthiness, like every other flag above; the previous
    # `is True` identity check silently skipped truthy non-bool values.
    if use_optimize_cross_qkv and not use_lora:
        # This optimization is not supported when LoRA is used.
        model = optimize_cross_qkv(model)

    return model