MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / optimize_model

Function optimize_model

tensorrt_llm/models/modeling_utils.py:1601–1656  ·  view source on GitHub ↗

Run optimization passes on the model. Some passes depend on others, so the enabled passes always run in the order of the arguments to guarantee a fixed execution order.

(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
    fuse_fp4_quant: bool = False,
    use_optimize_cross_qkv: bool = False,
    use_dora: bool = False,
)

Source from the content-addressed store, hash-verified

1599
1600
def optimize_model(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
    fuse_fp4_quant: bool = False,
    use_optimize_cross_qkv: bool = False,
    use_dora: bool = False,
) -> PretrainedModel:
    """Run optimization passes on ``model``.

    Some passes depend on the result of earlier ones, so the enabled passes
    are always applied in the order their flags appear in the signature.
    Each pass receives the model returned by the previous pass, and the
    model returned by the last enabled pass is returned.

    Args:
        model: The model to transform.
        use_parallel_embedding: Apply ``parallelize_embedding``.
        share_embedding_table: Apply ``share_embedding``.
        use_ootb_moe: Apply ``to_ootb_moe``.
        use_fused_mlp: Apply ``fuse_gate_mlp``.
        gemm_swiglu_plugin_dtype: Forwarded to ``fuse_gate_mlp``; only
            consulted when ``use_fused_mlp`` is set.
        low_latency_gemm_swiglu_plugin_dtype: Forwarded to
            ``fuse_gate_mlp``; only consulted when ``use_fused_mlp`` is set.
        use_fused_rg_lru: Apply ``fuse_rg_lru``.
        use_unfused_qkv_gemm: Apply ``unfuse_qkv_gemm``.
        use_prompt_tuning: Apply ``set_prompt_tuning``.
        use_lora: Apply ``add_lora``.
        max_lora_rank: Forwarded to ``add_lora``; only consulted when
            ``use_lora`` is set.
        use_fp8_context_fmha: Apply ``set_fp8_context_fhma``.
        fuse_fp4_quant: Apply ``set_fuse_fp4_quant``.
        use_optimize_cross_qkv: Apply ``optimize_cross_qkv``. Ignored when
            ``use_lora`` is set, because that combination is not supported.
        use_dora: Forwarded to ``add_lora`` as ``with_dora``; only consulted
            when ``use_lora`` is set.

    Returns:
        The model after all enabled passes have been applied.
    """
    # Before weight loading
    if use_parallel_embedding:
        model = parallelize_embedding(model)

    if share_embedding_table:
        # When share_embedding_table is enabled, only one copy of the
        # embedding table is stored in the converted checkpoint. This pass is
        # required to make lm_head.weight and vocab_embedding.weight point to
        # the same tensor. Even without it, TensorRT keeps only one copy of
        # the table if the weights are identical.
        model = share_embedding(model)

    # After weight loading
    if use_ootb_moe:
        model = to_ootb_moe(model)
    if use_fused_mlp:
        model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype,
                              low_latency_gemm_swiglu_plugin_dtype)
    if use_fused_rg_lru:
        model = fuse_rg_lru(model)
    if use_unfused_qkv_gemm:
        model = unfuse_qkv_gemm(model)
    if use_prompt_tuning:
        model = set_prompt_tuning(model)
    if use_lora:
        model = add_lora(model, max_lora_rank, with_dora=use_dora)
    if use_fp8_context_fmha:
        model = set_fp8_context_fhma(model)
    if fuse_fp4_quant:
        model = set_fuse_fp4_quant(model)
    # optimize_cross_qkv is not supported together with LoRA, so it is
    # skipped when use_lora is set. A truthiness test (instead of the former
    # `is True` identity check) makes bool-like values such as numpy.bool_
    # enable the pass, consistent with every other flag above.
    if use_optimize_cross_qkv and not use_lora:
        model = optimize_cross_qkv(model)

    return model
1657
1658

Callers 5

build_gpt · Function · 0.90
__post_init__ · Method · 0.85
from_hugging_face · Function · 0.85

Calls 11

parallelize_embedding · Function · 0.85
share_embedding · Function · 0.85
to_ootb_moe · Function · 0.85
fuse_gate_mlp · Function · 0.85
fuse_rg_lru · Function · 0.85
unfuse_qkv_gemm · Function · 0.85
set_prompt_tuning · Function · 0.85
add_lora · Function · 0.85
set_fp8_context_fhma · Function · 0.85
set_fuse_fp4_quant · Function · 0.85
optimize_cross_qkv · Function · 0.85

Tested by 1