github.com/NVIDIA/TensorRT-LLM

Method forward

tensorrt_llm/models/enc_dec/model.py, lines 1215–1339
(self,
                decoder_input_ids: Tensor,
                encoder_output: Tensor,
                position_ids=None,
                token_type_ids=None,
                use_cache=False,
                attention_mask_params=None,
                last_token_ids=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                lora_params: LoraParams = None,
                cross_kv_cache_gen: Optional[Tensor] = None,
                cross_kv_reuse: Optional[Tensor] = None,
                language_adapter_routings: Optional[Tensor] = None)
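
The signature accepts both `decoder_input_ids` and `hidden_states`, but a given pipeline-parallel (PP) stage uses only one of them. The first PP rank embeds the token ids, and every later rank receives activations from the previous rank through `recv`. The pure-Python sketch below mimics that dispatch on toy lists. `PPMapping`, `stage_input`, and the `embed`/`fake_recv` helpers are hypothetical stand-ins, not TensorRT-LLM APIs, and treating `hidden_states` as a shape template for the receive is an assumption of the sketch.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

Rows = List[List[float]]  # toy stand-in for a [tokens, hidden] tensor


@dataclass
class PPMapping:
    """Hypothetical stand-in for the pipeline-parallel part of a mapping."""
    pp_rank: int
    pp_size: int

    def is_first_pp_rank(self) -> bool:
        return self.pp_rank == 0

    def prev_pp_rank(self) -> int:
        return self.pp_rank - 1


def stage_input(mapping: PPMapping,
                decoder_input_ids: Optional[List[int]],
                hidden_states: Optional[Rows],
                embed: Callable[[List[int]], Rows],
                recv: Callable[[Rows, int], Rows]) -> Rows:
    """Return the activations a PP stage starts from (sketch)."""
    if mapping.is_first_pp_rank():
        # First stage: token ids are mandatory and are embedded locally.
        assert decoder_input_ids is not None, "first PP rank needs token ids"
        return embed(decoder_input_ids)
    # Later stages: in this sketch hidden_states is only a shape template;
    # the actual values come from the previous stage.
    assert hidden_states is not None, "later PP ranks need hidden_states"
    return recv(hidden_states, mapping.prev_pp_rank())


# Toy usage: a 2-entry embedding table and a fake receive.
table = {0: [0.1, 0.2], 1: [0.3, 0.4]}


def embed(ids: List[int]) -> Rows:
    return [table[i] for i in ids]


def fake_recv(buf: Rows, src: int) -> Rows:
    # Fill a buffer shaped like `buf` with the source rank id.
    return [[float(src)] * len(row) for row in buf]


first = stage_input(PPMapping(0, 2), [1, 0], None, embed, fake_recv)
second = stage_input(PPMapping(1, 2), None, [[0.0, 0.0]], embed, fake_recv)
print(first)   # [[0.3, 0.4], [0.1, 0.2]]
print(second)  # [[0.0, 0.0]]
```

Keeping the rank check in one place means each stage asserts only on the input it actually consumes, which mirrors the paired `assert isinstance(...)` checks at the top of `forward`.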

```python
    def forward(self,
                decoder_input_ids: Tensor,
                encoder_output: Tensor,
                position_ids=None,
                token_type_ids=None,
                use_cache=False,
                attention_mask_params=None,
                last_token_ids=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                lora_params: LoraParams = None,
                cross_kv_cache_gen: Optional[Tensor] = None,
                cross_kv_reuse: Optional[Tensor] = None,
                language_adapter_routings: Optional[Tensor] = None):
        if self.mapping.is_first_pp_rank():
            assert isinstance(decoder_input_ids, Tensor)
        else:
            assert isinstance(hidden_states, Tensor)

        # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
        if self.mapping.is_first_pp_rank():
            hidden_states = self.transformer.embedding(decoder_input_ids,
                                                       position_ids,
                                                       token_type_ids)
            self.register_network_output('embedding_layer_output',
                                         hidden_states)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

        kv_cache_params.fill_none_tensor_list(len(self.transformer.layers))

        if use_cache:
            presents = []

        for i, (decoder_layer, past) in enumerate(
                zip(self.transformer.layers, kv_cache_params.past_key_value)):

            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(i)

            hidden_states = decoder_layer(
                hidden_states,
                encoder_output=encoder_output,
                attention_mask_params=attention_mask_params,
                use_cache=use_cache,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=past,
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    cache_indirection=kv_cache_params.cache_indirection,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    # ... excerpt truncated here; the method continues to line 1339
```
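
Two details of the layer loop are easy to miss: `fill_none_tensor_list` runs first so that `kv_cache_params.past_key_value` can be zipped against the layers even when no per-layer cache was passed in, and LoRA weights are looked up per layer only when `lora_ranks` is set. The following sketch reproduces that control flow in plain Python. `KVParams`, `LoRA`, `run_layers`, and `make_layer` are hypothetical stand-ins, and the behavior given to `fill_none_tensor_list` (one `None` per layer when `past_key_value` is absent) is our reading of its name and call site, not something the excerpt shows.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class KVParams:
    """Hypothetical stand-in for KeyValueCacheParams."""
    past_key_value: Any = None
    host_past_key_value_lengths: Any = None

    def fill_none_tensor_list(self, list_size: int) -> None:
        # Assumed behaviour: with no cache supplied, use one None per layer
        # so that zip() pairs every layer with a (possibly empty) slot.
        if self.past_key_value is None:
            self.past_key_value = tuple([None] * list_size)


@dataclass
class LoRA:
    """Hypothetical stand-in for LoraParams; maps layer index -> rank."""
    lora_ranks: Optional[Dict[int, int]] = None

    def get_layer_params(self, i: int) -> Dict[str, int]:
        return {"layer": i, "rank": self.lora_ranks[i]}


Layer = Callable[[Any, KVParams, Optional[Dict[str, int]]], Any]


def run_layers(layers: List[Layer], hidden: Any, kv: KVParams,
               lora: Optional[LoRA] = None) -> Any:
    kv.fill_none_tensor_list(len(layers))
    for i, (layer, past) in enumerate(zip(layers, kv.past_key_value)):
        layer_lora = None
        if lora is not None and lora.lora_ranks is not None:
            layer_lora = lora.get_layer_params(i)
        # Per-layer params: this layer's cache slot plus shared host metadata.
        layer_kv = KVParams(
            past_key_value=past,
            host_past_key_value_lengths=kv.host_past_key_value_lengths)
        hidden = layer(hidden, layer_kv, layer_lora)
    return hidden


# Toy usage: each "layer" adds 1 and records what it received.
seen = []


def make_layer(idx: int) -> Layer:
    def layer(h, layer_kv, layer_lora):
        seen.append((idx, layer_kv.past_key_value, layer_lora))
        return h + 1
    return layer


out = run_layers([make_layer(0), make_layer(1)], 0, KVParams(),
                 LoRA({0: 8, 1: 16}))
print(out)   # 2
print(seen)  # [(0, None, {'layer': 0, 'rank': 8}), (1, None, {'layer': 1, 'rank': 16})]
```

Without the fill step, `zip(layers, None)` would raise a `TypeError`, so padding the cache list up front keeps a single loop shape for the first (cache-less) pass and later passes alike.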

Callers

nothing calls this directly

Calls (15)

recv (function) · 0.90
KeyValueCacheParams (class) · 0.90
gather_last_token_logits (function) · 0.90
default_net (function) · 0.90
send (function) · 0.90
embedding (method) · 0.80
fill_none_tensor_list (method) · 0.80
get_layer_params (method) · 0.80
mark_output (method) · 0.80
pp_layers (method) · 0.80
is_first_pp_rank (method) · 0.45
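
Several callees in this list (`send`, `pp_layers`, `gather_last_token_logits`, `mark_output`) are used in the part of the method the excerpt cuts off. `gather_last_token_logits` presumably uses the `last_token_ids` argument to keep only each sequence's final position before logits are computed. The sketch below illustrates that selection in plain Python under an assumption not verified against this file: with packed (padding-free) input, `last_token_ids` holds cumulative 1-based end offsets, and with padded input it holds per-sequence lengths. Both function names are hypothetical.

```python
from typing import List, Sequence

Row = List[float]  # one hidden-state vector


def gather_last_packed(hidden: Sequence[Row],
                       last_token_ids: Sequence[int]) -> List[Row]:
    """Packed layout: all sequences concatenated into one [tokens, hidden] array.

    Assumes last_token_ids are cumulative 1-based end offsets, so the
    final token of sequence b sits at row last_token_ids[b] - 1.
    """
    return [list(hidden[t - 1]) for t in last_token_ids]


def gather_last_padded(hidden: Sequence[Sequence[Row]],
                       last_token_ids: Sequence[int]) -> List[Row]:
    """Padded layout: [batch, max_seq, hidden]; last_token_ids are lengths."""
    return [list(seq[t - 1]) for seq, t in zip(hidden, last_token_ids)]


# Two sequences of lengths 3 and 2.
packed = [[0.0], [1.0], [2.0], [10.0], [11.0]]
padded = [[[0.0], [1.0], [2.0]],
          [[10.0], [11.0], [0.0]]]  # last row of sequence 2 is padding

print(gather_last_packed(packed, [3, 5]))  # [[2.0], [11.0]]
print(gather_last_padded(padded, [3, 2]))  # [[2.0], [11.0]]
```

In this sketch either layout yields one row per sequence, so any projection applied afterwards touches `batch` rows rather than every token.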

Tested by

no test coverage detected