Build an engine from the given model and the optimization options specified in `build_config`. WARNING: some optimization passes may change the given model object's state in place. This avoids cloning the model, because LLM models normally consume large amounts of memory. Create a new fresh model object if you need to build with different options.
(model: PretrainedModel, build_config: BuildConfig)
| 970 | |
| 971 | |
| 972 | def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: |
| 973 | '''Build engine from given model and optimization options specified in the build_config |
| 974 | WARNING: this function may change the given model object state in some optimization passes |
| 975 | to avoid cloning a model since normally the LLM models consumes large memory. |
| 976 | Create a new fresh model object if you need to build with different options. |
| 977 | ''' |
| 978 | tic = time.time() |
| 979 | # avoid changing the input config |
| 980 | build_config = build_config.model_copy(deep=True) |
| 981 | build_config.plugin_config.dtype = model.config.dtype |
| 982 | build_config.update_kv_cache_type(model.config.architecture) |
| 983 | |
| 984 | _init_max_seq_len(model.config, build_config) |
| 985 | |
| 986 | if build_config.plugin_config.streamingllm: |
| 987 | build_config.plugin_config.use_paged_context_fmha = False |
| 988 | logger.warning( |
| 989 | "Paged Context FMHA is disabled because StreamingLLM is not supported when enabling paged KV context FMHA." |
| 990 | ) |
| 991 | if build_config.plugin_config.reduce_fusion and ( |
| 992 | model.config.mapping.tp_size == 1 or |
| 993 | (model.config.architecture != "LlamaForCausalLM" |
| 994 | and model.config.architecture != "Gemma2ForCausalLM" |
| 995 | and model.config.architecture != "MedusaForCausalLM")): |
| 996 | logger.warning('Overriding reduce_fusion to False') |
| 997 | build_config.plugin_config.reduce_fusion = False |
| 998 | if build_config.plugin_config.user_buffer and not build_config.plugin_config.reduce_fusion: |
| 999 | logger.warning('Overriding user_buffer to False') |
| 1000 | build_config.plugin_config.user_buffer = False |
| 1001 | if build_config.plugin_config.norm_quant_fusion and ( |
| 1002 | build_config.plugin_config.reduce_fusion |
| 1003 | or model.config.architecture != "LlamaForCausalLM" |
| 1004 | or model.config.quantization.quant_algo != QuantAlgo.NVFP4): |
| 1005 | logger.warning('Overriding norm_quant_fusion to False') |
| 1006 | build_config.plugin_config.norm_quant_fusion = False |
| 1007 | |
| 1008 | if model.config.quantization.quant_algo == QuantAlgo.FP8 or \ |
| 1009 | model.config.quantization.kv_cache_quant_algo == QuantAlgo.FP8: |
| 1010 | build_config.strongly_typed = True |
| 1011 | |
| 1012 | if hasattr(model.config, 'max_draft_len'): |
| 1013 | # If model.config has 'max_draft_len' but build_config not specified, |
| 1014 | # use the value of model.config.max_draft_len to set the value of build_config.max_draft_len |
| 1015 | if build_config.max_draft_len == 0: |
| 1016 | build_config.max_draft_len = model.config.max_draft_len |
| 1017 | |
| 1018 | if hasattr(model.config, 'redrafter_num_beams') and hasattr( |
| 1019 | model.config, 'redrafter_draft_len_per_beam'): |
| 1020 | build_config.max_draft_len = model.config.redrafter_num_beams * model.config.redrafter_draft_len_per_beam |
| 1021 | if build_config.speculative_decoding_mode != SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS: |
| 1022 | logger.warning( |
| 1023 | 'speculative_decoding_mode is not EXPLICIT_DRAFT_TOKENS for ReDrafter model. Overwriting speculative_decoding_mode' |
| 1024 | ) |
| 1025 | build_config.speculative_decoding_mode = SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS |
| 1026 | |
| 1027 | if build_config.speculative_decoding_mode != SpeculativeDecodingMode.NONE: |
| 1028 | logger.info( |
| 1029 | f'Increasing max_seq_len ({build_config.max_seq_len}) ' |