hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/models/modeling_utils.py:1037–1150
(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                mrope_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                spec_decoding_params=None)

Source from the content-addressed store, hash-verified

        self.position_embedding_type = config.position_embedding_type

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                mrope_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                spec_decoding_params=None):

        # fill attention params.
        attention_params = Attention.fill_attention_params(
            self, attention_params)

        # split the sequence for context parallelism
        if self.config.mapping.cp_size > 1:
            if len(input_ids.shape) == 1:
                # input shape is [-1]
                input_ids, cp_join_index = cp_split_plugin(
                    input_ids,
                    attention_params.host_request_types,
                    attention_params.host_context_lengths,
                    self.config.mapping.cp_size,
                    self.config.mapping.cp_rank,
                )
            else:
                assert False, "Context parallelism with non-remove-padding is not supported yet."

        is_gemma_2_cg = self.config.has_config_group(Gemma2ConfigGroup)
        is_gemma_3_cg = self.config.has_config_group(Gemma3ConfigGroup)

        kwargs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'kv_cache_params': kv_cache_params,
            'attention_params': attention_params,
        }
        if lora_params is not None:
            kwargs['lora_params'] = lora_params
        if hidden_states is not None:
            kwargs['hidden_states'] = hidden_states
        if prompt_embedding_table is not None:
            kwargs['prompt_embedding_table'] = prompt_embedding_table
        if prompt_tasks is not None:
            kwargs['prompt_tasks'] = prompt_tasks
        if prompt_vocab_size is not None:
            kwargs['prompt_vocab_size'] = prompt_vocab_size

        if spec_decoding_params is not None:
            kwargs['spec_decoding_params'] = spec_decoding_params
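
The listing above always places six arguments in kwargs and adds the remaining ones only when the caller supplied a non-None value. A minimal standalone sketch of that pattern follows. The names build_forward_kwargs and toy_transformer are hypothetical and exist only for illustration; they are not part of TensorRT-LLM. One plausible benefit, assumed here rather than stated in the source, is that an inner model whose signature lacks some optional parameter can still be called.

```python
from typing import Any, Dict


def build_forward_kwargs(required: Dict[str, Any], **optional: Any) -> Dict[str, Any]:
    """Copy the required kwargs, then add only the optional ones that are not None."""
    kwargs = dict(required)
    for name, value in optional.items():
        if value is not None:
            kwargs[name] = value
    return kwargs


def toy_transformer(input_ids, position_ids=None, use_cache=False, lora_params=None):
    # Hypothetical inner model: note that it does NOT accept prompt_tasks.
    return {"input_ids": input_ids, "lora": lora_params is not None}


kwargs = build_forward_kwargs(
    {"input_ids": [1, 2, 3], "position_ids": None, "use_cache": False},
    lora_params=None,   # dropped because it is None
    prompt_tasks=None,  # dropped, so toy_transformer never sees this name
)
result = toy_transformer(**kwargs)
```

Required entries are kept even when their value is None (position_ids here), which mirrors how the listing fills the initial dict unconditionally.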

Callers

nothing calls this directly

Calls 13

cp_split_plugin · Function · 0.85
view · Function · 0.85
index_select · Function · 0.85
gather_last_token_logits · Function · 0.85
default_net · Function · 0.85
fill_attention_params · Method · 0.80
has_config_group · Method · 0.80
get_config_group · Method · 0.80
mark_output · Method · 0.80
pp_layers · Method · 0.80
allgather · Function · 0.50
forward · Method · 0.45
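
The call list includes cp_split_plugin, allgather and index_select, and the source shows cp_split_plugin returning a cp_join_index alongside the split input_ids. One plausible reading is a split, gather, reorder round trip across context-parallel ranks. The pure-Python sketch below illustrates that idea under a loud assumption: each request's tokens are cut into contiguous per-rank chunks. This is not the plugin's actual algorithm, and owned_positions and build_join_index are hypothetical helper names.

```python
from typing import List


def owned_positions(context_lengths: List[int], cp_size: int, cp_rank: int) -> List[int]:
    """Positions in the packed 1-D sequence assigned to cp_rank (assumed contiguous chunks)."""
    owned: List[int] = []
    start = 0
    for length in context_lengths:
        chunk = -(-length // cp_size)  # ceiling division
        lo = min(cp_rank * chunk, length)
        hi = min(lo + chunk, length)
        owned.extend(range(start + lo, start + hi))
        start += length
    return owned


def build_join_index(context_lengths: List[int], cp_size: int) -> List[int]:
    """join_index[p] is where original position p sits after a rank-ordered gather."""
    gathered_positions: List[int] = []
    for rank in range(cp_size):
        gathered_positions.extend(owned_positions(context_lengths, cp_size, rank))
    join_index = [0] * len(gathered_positions)
    for i, p in enumerate(gathered_positions):
        join_index[p] = i
    return join_index


tokens = [10, 11, 12, 20, 21, 22, 23, 24]  # two packed requests
context_lengths = [3, 5]
cp_size = 2

# Each rank keeps only its own slice of every request.
per_rank = [[tokens[p] for p in owned_positions(context_lengths, cp_size, r)]
            for r in range(cp_size)]

# Stand-in for an all-gather: concatenate the per-rank slices in rank order.
gathered = [t for chunk in per_rank for t in chunk]

# Stand-in for index_select: reorder back to the original packed layout.
join_index = build_join_index(context_lengths, cp_size)
restored = [gathered[i] for i in join_index]
```

After the gather the tokens are grouped by rank rather than by request, which is why a reorder index is needed to recover the original layout.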

Tested by

no test coverage detected