MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / _init_max_seq_len

Function _init_max_seq_len

tensorrt_llm/builder.py:811–890  ·  view source on GitHub ↗

If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor. Additional checks ensure that max_seq_len, max_input_len, and max_num_tokens have valid values.

(model_config, build_config)

Source from the content-addressed store, hash-verified

809
810
811def _init_max_seq_len(model_config, build_config):
812 """
813 If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor
814 Additional checks to ensure max_seq_len, max_input_len, and max_num_tokens have valid values.
815 """
816 # Extract rotary scaling which will be used for checks and default value of max_seq_len
817 rotary_scaling = getattr(model_config, "rotary_scaling", None)
818 if rotary_scaling is not None:
819 rotary_type = rotary_scaling.get('type',
820 rotary_scaling.get('rope_type'))
821 rotary_factor = rotary_scaling.get(
822 'factor', 1.0) if rotary_type not in ("su", "longrope",
823 "llama3") else 1
824 else:
825 rotary_factor = 1
826
827 if model_config.architecture == "EncoderModel":
828 if build_config.max_seq_len is None:
829 build_config.max_seq_len = build_config.max_input_len
830 logger.info(
831 f'max_seq_len is not specified for EncoderModel, using --max_input_len.'
832 )
833 assert build_config.max_input_len == build_config.max_seq_len, f"EncoderModel should have same --max_input_len ({build_config.max_input_len}) and --max_seq_len ({build_config.max_seq_len})."
834
835 if build_config.max_seq_len is None:
836 # Step 1: Find the upper bound of max_seq_len
837 deduced_max_seq_len = 2048
838 if model_config.max_position_embeddings is not None:
839 deduced_max_seq_len = model_config.max_position_embeddings
840
841 # Step 2: Scale max_seq_len with rotary scaling
842 if rotary_factor != 1:
843 deduced_max_seq_len = math.ceil(deduced_max_seq_len * rotary_factor)
844 logger.warning(
845 f'max_seq_len is scaled to {deduced_max_seq_len} by rotary scaling {rotary_factor}'
846 )
847
848 # Step 3: Assign the new max_seq_len
849 build_config.max_seq_len = int(deduced_max_seq_len)
850 logger.info(
851 f'max_seq_len is not specified, using deduced value {deduced_max_seq_len}'
852 )
853 else:
854 if not build_config.plugin_config.streamingllm and model_config.max_position_embeddings is not None \
855 and model_config.position_embedding_type != PositionEmbeddingType.relative:
856 if build_config.max_seq_len > model_config.max_position_embeddings * rotary_factor:
857 logger.warning(
858 f'max_seq_len {build_config.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}, '
859 'the model accuracy might be affected')
860
861 if build_config.max_input_len > build_config.max_seq_len:
862 logger.warning(
863 f'max_input_len is {build_config.max_input_len} is larger than max_seq_len {build_config.max_seq_len}, clipping it to max_seq_len'
864 )
865 build_config.max_input_len = build_config.max_seq_len
866
867 # Check and may modify max_num_tokens and opt_num_tokens (need to happen after max_seq_len is deduced)
868 max_num_tokens, opt_num_tokens = check_max_num_tokens(

Callers 1

build (Function) · 0.85

Calls 4

check_max_num_tokens (Function) · 0.85
get (Method) · 0.45
info (Method) · 0.45
warning (Method) · 0.45

Tested by

no test coverage detected