hub / github.com/NVIDIA/TensorRT-LLM / build

Function build

tensorrt_llm/builder.py:972–1289

Build an engine from the given model and the optimization options specified in build_config. WARNING: this function may change the state of the given model object during some optimization passes. It does so to avoid cloning the model, since LLM models normally consume a large amount of memory. Create a fresh model object if you need to build with different options.

(model: PretrainedModel, build_config: BuildConfig)
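The contract in the description (the config is copied, the model is not) can be illustrated with a small self-contained sketch. `ToyPluginConfig`, `ToyBuildConfig`, `ToyModel`, and `toy_build` are hypothetical stand-ins invented for this example, not TensorRT-LLM APIs, and `copy.deepcopy` is used here in place of the `build_config.model_copy(deep=True)` call that appears in the source.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyPluginConfig:
    # Hypothetical stand-in for the plugin options carried by a build config.
    reduce_fusion: bool = True


@dataclass
class ToyBuildConfig:
    # Hypothetical stand-in for BuildConfig.
    plugin_config: ToyPluginConfig = field(default_factory=ToyPluginConfig)


@dataclass
class ToyModel:
    # Hypothetical stand-in for a PretrainedModel.
    # `optimized` models the in-place changes made by optimization passes.
    tp_size: int = 1
    optimized: bool = False


def toy_build(model: ToyModel, build_config: ToyBuildConfig) -> ToyBuildConfig:
    # Work on a private deep copy so the caller's config is never modified.
    build_config = copy.deepcopy(build_config)
    if model.tp_size == 1:
        build_config.plugin_config.reduce_fusion = False
    # The model is NOT copied: this change is visible to the caller.
    model.optimized = True
    return build_config


user_config = ToyBuildConfig()
first_model = ToyModel()
effective_config = toy_build(first_model, user_config)

# The caller's config is untouched, but the model has been changed in place,
# so a second build with different options should start from a fresh model.
second_model = ToyModel()
```

In this sketch `user_config` can be reused safely across builds, while `first_model` cannot, which is the reason the docstring asks for a fresh model object per build.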


def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
    '''Build engine from given model and optimization options specified in the build_config
       WARNING: this function may change the given model object state in some optimization passes
       to avoid cloning a model since normally the LLM models consumes large memory.
       Create a new fresh model object if you need to build with different options.
    '''
    tic = time.time()
    # avoid changing the input config
    build_config = build_config.model_copy(deep=True)
    build_config.plugin_config.dtype = model.config.dtype
    build_config.update_kv_cache_type(model.config.architecture)

    _init_max_seq_len(model.config, build_config)

    if build_config.plugin_config.streamingllm:
        build_config.plugin_config.use_paged_context_fmha = False
        logger.warning(
            "Paged Context FMHA is disabled because StreamingLLM is not supported when enabling paged KV context FMHA."
        )
    if build_config.plugin_config.reduce_fusion and (
            model.config.mapping.tp_size == 1 or
        (model.config.architecture != "LlamaForCausalLM"
         and model.config.architecture != "Gemma2ForCausalLM"
         and model.config.architecture != "MedusaForCausalLM")):
        logger.warning('Overriding reduce_fusion to False')
        build_config.plugin_config.reduce_fusion = False
    if build_config.plugin_config.user_buffer and not build_config.plugin_config.reduce_fusion:
        logger.warning('Overriding user_buffer to False')
        build_config.plugin_config.user_buffer = False
    if build_config.plugin_config.norm_quant_fusion and (
            build_config.plugin_config.reduce_fusion
            or model.config.architecture != "LlamaForCausalLM"
            or model.config.quantization.quant_algo != QuantAlgo.NVFP4):
        logger.warning('Overriding norm_quant_fusion to False')
        build_config.plugin_config.norm_quant_fusion = False

    if model.config.quantization.quant_algo == QuantAlgo.FP8 or \
            model.config.quantization.kv_cache_quant_algo == QuantAlgo.FP8:
        build_config.strongly_typed = True

    if hasattr(model.config, 'max_draft_len'):
        # If model.config has 'max_draft_len' but build_config not specified,
        # use the value of model.config.max_draft_len to set the value of build_config.max_draft_len
        if build_config.max_draft_len == 0:
            build_config.max_draft_len = model.config.max_draft_len

    if hasattr(model.config, 'redrafter_num_beams') and hasattr(
            model.config, 'redrafter_draft_len_per_beam'):
        build_config.max_draft_len = model.config.redrafter_num_beams * model.config.redrafter_draft_len_per_beam
        if build_config.speculative_decoding_mode != SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS:
            logger.warning(
                'speculative_decoding_mode is not EXPLICIT_DRAFT_TOKENS for ReDrafter model. Overwriting speculative_decoding_mode'
            )
            build_config.speculative_decoding_mode = SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS

    if build_config.speculative_decoding_mode != SpeculativeDecodingMode.NONE:
        logger.info(
            f'Increasing max_seq_len ({build_config.max_seq_len}) '
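
The three fusion-flag overrides in the listing are order-dependent: `norm_quant_fusion` is checked against the value of `reduce_fusion` after that flag has already been overridden. The sketch below restates only that logic as a pure function so the interaction is easier to see. The function name `resolve_fusion_flags`, the constant `REDUCE_FUSION_ARCHITECTURES`, and the use of the plain string `"NVFP4"` in place of `QuantAlgo.NVFP4` are assumptions made for this example and are not part of TensorRT-LLM.

```python
# Architectures for which the listing keeps reduce_fusion enabled.
REDUCE_FUSION_ARCHITECTURES = {
    "LlamaForCausalLM",
    "Gemma2ForCausalLM",
    "MedusaForCausalLM",
}


def resolve_fusion_flags(reduce_fusion, user_buffer, norm_quant_fusion, *,
                         tp_size, architecture, quant_algo):
    """Return the effective (reduce_fusion, user_buffer, norm_quant_fusion).

    Hypothetical helper that mirrors the checks in the listing, in order.
    `quant_algo` is a plain string standing in for the QuantAlgo enum.
    """
    # reduce_fusion needs tensor parallelism and a supported architecture.
    if reduce_fusion and (tp_size == 1
                          or architecture not in REDUCE_FUSION_ARCHITECTURES):
        reduce_fusion = False
    # user_buffer is only kept when reduce_fusion survived the check above.
    if user_buffer and not reduce_fusion:
        user_buffer = False
    # norm_quant_fusion is dropped when reduce_fusion is still on, or when
    # the model is not Llama, or when the quantization is not NVFP4.
    if norm_quant_fusion and (reduce_fusion
                              or architecture != "LlamaForCausalLM"
                              or quant_algo != "NVFP4"):
        norm_quant_fusion = False
    return reduce_fusion, user_buffer, norm_quant_fusion


# Single-GPU Llama with NVFP4: reduce_fusion is dropped first,
# which in turn lets norm_quant_fusion stay enabled.
single_gpu = resolve_fusion_flags(True, True, True, tp_size=1,
                                  architecture="LlamaForCausalLM",
                                  quant_algo="NVFP4")
# -> (False, False, True)

# Two-way tensor parallel Llama with NVFP4: reduce_fusion stays on,
# so norm_quant_fusion is the flag that gets overridden.
two_gpu = resolve_fusion_flags(True, True, True, tp_size=2,
                               architecture="LlamaForCausalLM",
                               quant_algo="NVFP4")
# -> (True, True, False)
```

Under these rules `reduce_fusion` and `norm_quant_fusion` are never both enabled in the effective config, and `user_buffer` is only ever enabled together with `reduce_fusion`.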

Callers 10

build_gpt · Function · 0.90
enc_dec_build_helper · Function · 0.90
build_from_hf · Function · 0.90
test_save_load · Function · 0.90
test_async_io · Function · 0.90
test_fp8_quantization · Function · 0.90
build_model · Function · 0.90
_build_engine · Method · 0.50

Calls 15

create_builder_config · Method · 0.95
create_network · Method · 0.95
build_engine · Method · 0.95
save_timing_cache · Method · 0.95
_init_max_seq_len · Function · 0.85
Builder · Class · 0.85
net_guard · Function · 0.85
str_dtype_to_trt · Function · 0.85
optimize · Function · 0.85
EngineConfig · Class · 0.85
sum · Function · 0.85

Tested by 5

test_save_load · Function · 0.72
test_async_io · Function · 0.72
test_fp8_quantization · Function · 0.72