(model: PretrainedModel,
build_config: BuildConfig)
| 758 | |
| 759 | |
def optimize_model_with_config(model: PretrainedModel,
                               build_config: BuildConfig):
    """Validate plugin settings and run ``optimize_model`` on ``model``.

    The function checks the GemmSwiGLU plugin options, enables LoRA on the
    model when the LoRA plugin is configured, forwards the flags derived from
    ``build_config`` to ``optimize_model`` and, for encoder/decoder
    architectures, precomputes the relative attention bias.

    Args:
        model: The model to optimize. Its ``config.architecture`` and
            ``config.quantization.quant_algo`` are read to pick the flags.
        build_config: Build settings; ``plugin_config``, ``lora_config`` and
            ``max_prompt_embedding_table_size`` are read from it.

    Returns:
        The object returned by ``optimize_model``.

    Raises:
        RuntimeError: If a GemmSwiGLU plugin is enabled without
            ``use_fused_mlp``, or if any enabled GemmSwiGLU plugin is set to
            a value other than ``"fp8"``.
    """
    plugin_config = build_config.plugin_config
    gemm_swiglu_plugin = plugin_config.gemm_swiglu_plugin
    low_latency_gemm_swiglu_plugin = (
        plugin_config.low_latency_gemm_swiglu_plugin)

    # Only the plugins that are actually switched on need validating.
    enabled_swiglu_plugins = {
        name: dtype
        for name, dtype in (
            ("gemm_swiglu_plugin", gemm_swiglu_plugin),
            ("low_latency_gemm_swiglu_plugin",
             low_latency_gemm_swiglu_plugin),
        ) if dtype
    }
    if enabled_swiglu_plugins:
        if not plugin_config.use_fused_mlp:
            raise RuntimeError(
                "GemmSwiGLU plugin requires --use_fused_mlp flag")
        # Each enabled plugin is checked on its own: a valid setting on one
        # plugin must not hide an unsupported setting on the other.
        unsupported = {
            name: dtype
            for name, dtype in enabled_swiglu_plugins.items()
            if dtype != "fp8"
        }
        if unsupported:
            details = ", ".join(f"{name}={dtype!r}"
                                for name, dtype in unsupported.items())
            raise RuntimeError(
                "GemmSwiGLU plugin currently has limited support: fp8 only, "
                f"got: {details}")

    use_lora = plugin_config.lora_plugin is not None
    if use_lora:
        model.use_lora(build_config.lora_config)

    is_enc_dec = model.config.architecture in ["EncoderModel", "DecoderModel"]
    # FusedMLP does not support RecurrentGemma FP8 currently.
    is_recurrent_gemma = model.config.architecture in [
        "RecurrentGemmaForCausalLM"
    ]
    quant_algo = model.config.quantization.quant_algo
    is_fp8 = quant_algo == QuantAlgo.FP8
    model = optimize_model(
        model,
        share_embedding_table=True,
        use_ootb_moe=plugin_config.moe_plugin is None,
        use_fused_mlp=(plugin_config.use_fused_mlp and not is_enc_dec
                       and not (is_recurrent_gemma and is_fp8)),
        gemm_swiglu_plugin_dtype=gemm_swiglu_plugin,
        low_latency_gemm_swiglu_plugin_dtype=low_latency_gemm_swiglu_plugin,
        use_fused_rg_lru=is_recurrent_gemma,
        use_unfused_qkv_gemm=False,
        use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0),
        use_lora=use_lora,
        max_lora_rank=build_config.lora_config.max_lora_rank,
        use_fp8_context_fmha=(quant_algo in [
            QuantAlgo.FP8, QuantAlgo.W4A8_AWQ, QuantAlgo.NVFP4
        ] and plugin_config.use_fp8_context_fmha),
        fuse_fp4_quant=plugin_config.fuse_fp4_quant,
        use_optimize_cross_qkv=True,
        use_dora=plugin_config.dora_plugin)

    if is_enc_dec:
        model.precompute_relative_attention_bias(build_config)
    return model
| 809 | |
| 810 | |
| 811 | def _init_max_seq_len(model_config, build_config): |
no test coverage detected