MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / optimize_model_with_config

Function optimize_model_with_config

tensorrt_llm/builder.py:760–808  ·  view source on GitHub ↗
(model: PretrainedModel,
                               build_config: BuildConfig)

Source from the content-addressed store, hash-verified

758
759
def optimize_model_with_config(model: PretrainedModel,
                               build_config: BuildConfig):
    """Run ``optimize_model`` on ``model`` with options taken from ``build_config``.

    The GemmSwiGLU plugin settings are validated first. If the LoRA plugin is
    configured, ``model.use_lora`` is called with ``build_config.lora_config``
    before the optimization pass. For "EncoderModel" / "DecoderModel"
    architectures, ``precompute_relative_attention_bias`` is called on the
    optimized model before it is returned.

    Args:
        model: Model to optimize. ``model.config.architecture`` and
            ``model.config.quantization.quant_algo`` are read to pick options.
        build_config: Build settings. ``plugin_config``, ``lora_config`` and
            ``max_prompt_embedding_table_size`` are read.

    Returns:
        The model returned by ``optimize_model``.

    Raises:
        RuntimeError: If a GemmSwiGLU plugin is set while ``use_fused_mlp`` is
            falsy, or if neither GemmSwiGLU plugin option is ``"fp8"``.
    """
    plugin_config = build_config.plugin_config
    gemm_swiglu_plugin = plugin_config.gemm_swiglu_plugin
    low_latency_gemm_swiglu_plugin = (
        plugin_config.low_latency_gemm_swiglu_plugin)

    if gemm_swiglu_plugin or low_latency_gemm_swiglu_plugin:
        if not plugin_config.use_fused_mlp:
            raise RuntimeError(
                "GemmSwiGLU plugin requires --use_fused_mlp flag")
        # NOTE(review): this only rejects the configuration when *neither*
        # option is "fp8"; a non-fp8 value on one option is accepted as long
        # as the other one is "fp8" -- confirm that this is intended.
        if (gemm_swiglu_plugin not in ["fp8"]
                and low_latency_gemm_swiglu_plugin not in ["fp8"]):
            # Name each option and separate the values; the fragments are
            # joined by implicit string concatenation, so every fragment
            # except the last needs its own trailing separator.
            raise RuntimeError(
                f"GemmSwiGLU plugin currently has limited support: fp8 only, "
                f"got gemm_swiglu_plugin: {gemm_swiglu_plugin}, "
                f"low_latency_gemm_swiglu_plugin: "
                f"{low_latency_gemm_swiglu_plugin}")

    use_lora = plugin_config.lora_plugin is not None
    if use_lora:
        model.use_lora(build_config.lora_config)

    is_enc_dec = model.config.architecture in ["EncoderModel", "DecoderModel"]
    # FusedMLP does not support RecurrentGemma FP8 currently.
    is_recurrent_gemma = model.config.architecture in [
        "RecurrentGemmaForCausalLM"
    ]
    is_fp8 = model.config.quantization.quant_algo == QuantAlgo.FP8
    model = optimize_model(
        model,
        share_embedding_table=True,
        use_ootb_moe=plugin_config.moe_plugin is None,
        # Fused MLP is disabled for encoder/decoder models and for
        # RecurrentGemma with FP8, regardless of the plugin flag.
        use_fused_mlp=(plugin_config.use_fused_mlp and not is_enc_dec
                       and not (is_recurrent_gemma and is_fp8)),
        gemm_swiglu_plugin_dtype=gemm_swiglu_plugin,
        low_latency_gemm_swiglu_plugin_dtype=low_latency_gemm_swiglu_plugin,
        use_fused_rg_lru=is_recurrent_gemma,
        use_unfused_qkv_gemm=False,
        use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0),
        use_lora=use_lora,
        max_lora_rank=build_config.lora_config.max_lora_rank,
        # FP8 context FMHA is only requested for these quantization modes.
        use_fp8_context_fmha=(model.config.quantization.quant_algo in [
            QuantAlgo.FP8, QuantAlgo.W4A8_AWQ, QuantAlgo.NVFP4
        ] and plugin_config.use_fp8_context_fmha),
        fuse_fp4_quant=plugin_config.fuse_fp4_quant,
        use_optimize_cross_qkv=True,
        use_dora=plugin_config.dora_plugin)

    if is_enc_dec:
        model.precompute_relative_attention_bias(build_config)
    return model
809
810
811def _init_max_seq_len(model_config, build_config):

Callers 2

refit_engineFunction · 0.90
buildFunction · 0.85

Calls 3

optimize_modelFunction · 0.85
use_loraMethod · 0.45

Tested by

no test coverage detected