@brief: Prepare input Tensors for the model; the given sizes are used to determine the ranges of the dimensions when using TRT dynamic shapes. @return: a list containing values which can be fed into self.forward()
(
self,
max_batch_size,
max_input_len,
max_seq_len,
max_num_tokens,
use_cache,
max_beam_width: int = 1,
opt_num_tokens: int = None,
prompt_embedding_table_size: int = 0,
position_encoding_2d: bool = False,
max_draft_len: int = 0,
speculative_decoding_draft_tokens_external: bool = False,
spec_decoding_is_generation_length_variable: bool = False,
gather_context_logits: bool = False,
lora_target_modules: List[str] = None,
opt_batch_size: int = 0,
num_hidden_layers: int = None,
mrope_rotary_cos_sin_size: int = None,
)
| 820 | self.config.to_json_file(os.path.join(output_dir, 'config.json')) |
| 821 | |
| 822 | def prepare_inputs( |
| 823 | self, |
| 824 | max_batch_size, |
| 825 | max_input_len, |
| 826 | max_seq_len, |
| 827 | max_num_tokens, |
| 828 | use_cache, |
| 829 | max_beam_width: int = 1, |
| 830 | opt_num_tokens: int = None, |
| 831 | prompt_embedding_table_size: int = 0, |
| 832 | position_encoding_2d: bool = False, |
| 833 | max_draft_len: int = 0, |
| 834 | speculative_decoding_draft_tokens_external: bool = False, |
| 835 | spec_decoding_is_generation_length_variable: bool = False, |
| 836 | gather_context_logits: bool = False, |
| 837 | lora_target_modules: List[str] = None, |
| 838 | opt_batch_size: int = 0, |
| 839 | num_hidden_layers: int = None, |
| 840 | mrope_rotary_cos_sin_size: int = None, |
| 841 | ): |
| 842 | '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the |
| 843 | ranges of the dimensions of when using TRT dynamic shapes. |
| 844 | |
| 845 | @return: a list contains values which can be fed into the self.forward() |
| 846 | ''' |
| 847 | |
| 848 | # Prepare inputs |
| 849 | remove_input_padding = default_net().plugin_config.remove_input_padding |
| 850 | use_gpt_attention_plugin = default_net( |
| 851 | ).plugin_config.gpt_attention_plugin |
| 852 | use_gemm_plugin = default_net().plugin_config.gemm_plugin |
| 853 | paged_kv_cache = default_net().plugin_config.paged_kv_cache |
| 854 | tokens_per_block = default_net().plugin_config.tokens_per_block |
| 855 | use_lora_plugin = default_net().plugin_config.lora_plugin |
| 856 | multiple_profiles = default_net().plugin_config.multiple_profiles |
| 857 | streamingllm = default_net().plugin_config.streamingllm |
| 858 | pp_reduce_scatter = default_net().plugin_config.pp_reduce_scatter |
| 859 | |
| 860 | kv_cache_type = None |
| 861 | if not use_cache: |
| 862 | kv_cache_type = KVCacheType.DISABLED |
| 863 | else: |
| 864 | if paged_kv_cache: |
| 865 | kv_cache_type = KVCacheType.PAGED |
| 866 | else: |
| 867 | kv_cache_type = KVCacheType.CONTINUOUS |
| 868 | |
| 869 | model_inputs = self.prepare_basic_inputs( |
| 870 | max_batch_size=max_batch_size, |
| 871 | max_beam_width=max_beam_width, |
| 872 | max_input_len=max_input_len, |
| 873 | max_seq_len=max_seq_len, |
| 874 | hidden_size=self.config.hidden_size, |
| 875 | num_kv_heads=self.config.num_key_value_heads, |
| 876 | head_size=self.config.head_size, |
| 877 | num_layers=num_hidden_layers |
| 878 | if num_hidden_layers is not None else self.config.num_hidden_layers, |
| 879 | kv_dtype=str_dtype_to_trt(self.config.kv_dtype), |
nothing calls this directly
no test coverage detected