MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / prepare_inputs

Method prepare_inputs

tensorrt_llm/models/modeling_utils.py:822–965  ·  view source on GitHub ↗

@brief: Prepare the input tensors for the model. The given sizes are used to determine the ranges of the dimensions when using TRT dynamic shapes. @return: a list containing values that can be fed into self.forward()

(
        self,
        max_batch_size,
        max_input_len,
        max_seq_len,
        max_num_tokens,
        use_cache,
        max_beam_width: int = 1,
        opt_num_tokens: int = None,
        prompt_embedding_table_size: int = 0,
        position_encoding_2d: bool = False,
        max_draft_len: int = 0,
        speculative_decoding_draft_tokens_external: bool = False,
        spec_decoding_is_generation_length_variable: bool = False,
        gather_context_logits: bool = False,
        lora_target_modules: List[str] = None,
        opt_batch_size: int = 0,
        num_hidden_layers: int = None,
        mrope_rotary_cos_sin_size: int = None,
    )

Source from the content-addressed store, hash-verified

820 self.config.to_json_file(os.path.join(output_dir, 'config.json'))
821
822 def prepare_inputs(
823 self,
824 max_batch_size,
825 max_input_len,
826 max_seq_len,
827 max_num_tokens,
828 use_cache,
829 max_beam_width: int = 1,
830 opt_num_tokens: int = None,
831 prompt_embedding_table_size: int = 0,
832 position_encoding_2d: bool = False,
833 max_draft_len: int = 0,
834 speculative_decoding_draft_tokens_external: bool = False,
835 spec_decoding_is_generation_length_variable: bool = False,
836 gather_context_logits: bool = False,
837 lora_target_modules: List[str] = None,
838 opt_batch_size: int = 0,
839 num_hidden_layers: int = None,
840 mrope_rotary_cos_sin_size: int = None,
841 ):
842 '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
843 ranges of the dimensions of when using TRT dynamic shapes.
844
845 @return: a list contains values which can be fed into the self.forward()
846 '''
847
848 # Prepare inputs
849 remove_input_padding = default_net().plugin_config.remove_input_padding
850 use_gpt_attention_plugin = default_net(
851 ).plugin_config.gpt_attention_plugin
852 use_gemm_plugin = default_net().plugin_config.gemm_plugin
853 paged_kv_cache = default_net().plugin_config.paged_kv_cache
854 tokens_per_block = default_net().plugin_config.tokens_per_block
855 use_lora_plugin = default_net().plugin_config.lora_plugin
856 multiple_profiles = default_net().plugin_config.multiple_profiles
857 streamingllm = default_net().plugin_config.streamingllm
858 pp_reduce_scatter = default_net().plugin_config.pp_reduce_scatter
859
860 kv_cache_type = None
861 if not use_cache:
862 kv_cache_type = KVCacheType.DISABLED
863 else:
864 if paged_kv_cache:
865 kv_cache_type = KVCacheType.PAGED
866 else:
867 kv_cache_type = KVCacheType.CONTINUOUS
868
869 model_inputs = self.prepare_basic_inputs(
870 max_batch_size=max_batch_size,
871 max_beam_width=max_beam_width,
872 max_input_len=max_input_len,
873 max_seq_len=max_seq_len,
874 hidden_size=self.config.hidden_size,
875 num_kv_heads=self.config.num_key_value_heads,
876 head_size=self.config.head_size,
877 num_layers=num_hidden_layers
878 if num_hidden_layers is not None else self.config.num_hidden_layers,
879 kv_dtype=str_dtype_to_trt(self.config.kv_dtype),

Callers

nothing calls this directly

Calls 6

default_net · Function · 0.85
str_dtype_to_trt · Function · 0.85
KeyValueCacheParams · Class · 0.85
AttentionParams · Class · 0.85
LoraParams · Class · 0.85
prepare_basic_inputs · Method · 0.80

Tested by

no test coverage detected