MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/models/enc_dec/model.py:279–342  ·  view source on GitHub ↗
(self,
                hidden_states: Tensor,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None,
                lora_layer_params=None,
                language_adapter_routings: Optional[Tensor] = None)

Source from the content-addressed store, hash-verified

277 quant_mode=quant_mode)
278
279 def forward(self,
280 hidden_states: Tensor,
281 attention_mask=None,
282 input_lengths=None,
283 max_input_length=None,
284 lora_layer_params=None,
285 language_adapter_routings: Optional[Tensor] = None):
286 assert isinstance(hidden_states, Tensor)
287
288 # self attention
289 residual = hidden_states * self.residual_scaling
290
291 if self.layernorm_position == LayerNormPositionType.pre_layernorm:
292 hidden_states = self.attention_layernorm(hidden_states)
293
294 attention_output = self.attention(hidden_states,
295 attention_mask=attention_mask,
296 input_lengths=input_lengths,
297 max_input_length=max_input_length,
298 lora_layer_params=lora_layer_params)
299
300 self.register_network_output('attention_output', attention_output)
301
302 hidden_states = residual + attention_output
303
304 if self.fp16_clamping:
305 hidden_states = maximum(-64000.0, hidden_states)
306 hidden_states = minimum(64000.0, hidden_states)
307
308 if self.layernorm_position == LayerNormPositionType.post_layernorm:
309 hidden_states = self.attention_layernorm(hidden_states)
310
311 # MLP
312 residual = hidden_states * self.residual_scaling
313
314 if self.layernorm_position == LayerNormPositionType.pre_layernorm:
315 hidden_states = self.mlp_layernorm(hidden_states)
316
317 hidden_states = self.mlp(hidden_states,
318 lora_layer_params=lora_layer_params)
319
320 self.register_network_output('mlp_output', hidden_states)
321
322 hidden_states = residual + hidden_states
323
324 if self.fp16_clamping:
325 hidden_states = maximum(-64000.0, hidden_states)
326 hidden_states = minimum(64000.0, hidden_states)
327
328 if self.layernorm_position == LayerNormPositionType.post_layernorm:
329 hidden_states = self.mlp_layernorm(hidden_states)
330
331 # MT Specific: adapters
332 if self.adapter:
333 residual = hidden_states
334 hidden_states = self.adapter_layer_norm(hidden_states)
335 hidden_states = self.adapter.layers(
336 hidden_states, static_routing_input=language_adapter_routings)

Callers

nothing calls this directly

Calls 1

Tested by

no test coverage detected