Repository: github.com/NVIDIA/TensorRT-LLM

Method forward

tensorrt_llm/quantization/layers.py, lines 2582–2807 (only the opening part of the method, through line 2639, is reproduced below)


        self.use_lora = False

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        spec_decoding_params=None,
        mrope_params=None,
        encoder_output=None,
        position_embedding=None,
        norm_before_bmm1=False,
        lora_layer_params=None,
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
        qkv = self.qkv(hidden_states)

        alibi_slopes = None
        if self.position_embedding_type == PositionEmbeddingType.alibi:
            alibi_slopes = self.alibi_slopes.value
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin or default_net(
            ).plugin_config.inflight_batching_gpt_attention_plugin:
                dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling(
                ) else hidden_states[0].dtype
                if dtype == trt.int8:
                    dtype = trt.float16
            alibi_slopes = cast(alibi_slopes, dtype)

        if spec_decoding_params is None:
            spec_decoding_params = SpecDecodingParams()

        if mrope_params is None:
            mrope_params = MropeParams()

        if default_net().plugin_config.gpt_attention_plugin:

            assert attention_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin,
                default_net().plugin_config.remove_input_padding, use_cache)
            if use_cache:
                assert kv_cache_params.is_valid(
                    default_net().plugin_config.gpt_attention_plugin)
            assert self.attention_mask_type == AttentionMaskType.causal, \
                'Plugin only support masked MHA.'
            if self.kv_cache_scaling_factor is not None:
                kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
                kv_quant_orig_scale = self.kv_cache_scaling_factor.value
            else:
                kv_orig_quant_scale = None
                kv_quant_orig_scale = None
            if self.position_embedding_type.is_rope():
                rotary_inv_freq = attention_params.rotary_inv_freq
                rotary_cos_sin = attention_params.embed_positions_for_gpt_attention
            else:
                rotary_inv_freq = None
                rotary_cos_sin = None
            # ... (excerpt ends here; the method continues through line 2807)
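The ALiBi branch in the listing picks the dtype the slopes are cast to: float32 by default, the activation dtype when a GPT attention plugin is enabled, and float16 in place of int8. The framework-free sketch below mirrors only that decision. `FakeTensor`, `alibi_slopes_dtype`, and the string dtypes are hypothetical stand-ins for TensorRT-LLM tensors and `trt.*` dtypes. The tuple case assumes that, without static activation scaling, `hidden_states` is a sequence whose first element is the activation tensor; this is inferred only from the `hidden_states[0]` access.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a TensorRT-LLM Tensor; only dtype matters here."""
    dtype: str


HiddenStates = Union[FakeTensor, Tuple[FakeTensor, ...]]


def alibi_slopes_dtype(hidden_states: HiddenStates,
                       plugin_enabled: bool,
                       act_static_scaling: bool) -> str:
    """Mirror the dtype choice made before `cast(alibi_slopes, dtype)`."""
    dtype = "float32"  # default when neither attention plugin is enabled
    if plugin_enabled:
        # Static activation scaling: read the dtype of hidden_states directly.
        # Otherwise: read hidden_states[0] (assumed to be a tuple here).
        dtype = (hidden_states.dtype if act_static_scaling
                 else hidden_states[0].dtype)
        if dtype == "int8":
            # int8 is never used for the slopes; float16 is substituted.
            dtype = "float16"
    return dtype


print(alibi_slopes_dtype(FakeTensor("int8"), True, True))  # float16
print(alibi_slopes_dtype((FakeTensor("float16"), FakeTensor("float32")),
                         True, False))  # float16
print(alibi_slopes_dtype(FakeTensor("int8"), False, True))  # float32
```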
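The listing passes two KV-cache scales: `kv_orig_quant_scale` (taken from the reciprocal factor) and `kv_quant_orig_scale` (the factor itself). Reading the names as directions (original to quantized, quantized to original), a plausible use is the round trip below. This is a sketch under that assumption: it models an int8 cache with a single per-tensor scale and made-up values, and it does not reproduce how the attention plugin applies these scales internally.

```python
# Hypothetical sketch: how a (scale, reciprocal scale) pair is typically used
# for an int8 KV cache. Names mirror the listing; the numbers are made up.

def quantize_int8(values, orig_quant_scale):
    """original -> quantized: multiply by the reciprocal scale, round, clamp."""
    return [max(-128, min(127, round(v * orig_quant_scale))) for v in values]


def dequantize_int8(codes, quant_orig_scale):
    """quantized -> original: multiply by the scaling factor itself."""
    return [q * quant_orig_scale for q in codes]


kv_cache_scaling_factor = 0.05                       # hypothetical calibrated value
kv_orig_quant_scale = 1.0 / kv_cache_scaling_factor  # role of the reciprocal factor
kv_quant_orig_scale = kv_cache_scaling_factor

keys = [0.1, -0.42, 1.0, 7.0]
codes = quantize_int8(keys, kv_orig_quant_scale)
print(codes)  # [2, -8, 20, 127]  (7.0 saturates at the int8 limit)
restored = dequantize_int8(codes, kv_quant_orig_scale)
```

Storing both the factor and its reciprocal lets each direction be a single multiply instead of a division on the hot path.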
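For RoPE models, the listing forwards `attention_params.rotary_inv_freq` and `attention_params.embed_positions_for_gpt_attention` to the plugin. The sketch below shows what such values usually contain in standard rotary embeddings: inverse frequencies `base ** (-2i / d)` and per-position cos/sin values. The base of 10000, the interleaved pair layout, and the helper names are assumptions. Models differ (some rotate the two halves of the vector instead), and TensorRT-LLM's precomputed table layout is not reproduced here.

```python
import math


def rotary_inv_freq(rotary_dim, base=10000.0):
    """inv_freq[i] = base ** (-2i / rotary_dim) for i in [0, rotary_dim / 2)."""
    return [1.0 / (base ** (2 * i / rotary_dim)) for i in range(rotary_dim // 2)]


def rotary_cos_sin(num_positions, inv_freq):
    """For each position p, the cos and sin of p * f for every frequency f."""
    table = []
    for p in range(num_positions):
        angles = [p * f for f in inv_freq]
        table.append(([math.cos(a) for a in angles], [math.sin(a) for a in angles]))
    return table


def apply_rope(x, cos, sin):
    """Rotate interleaved pairs (x[2i], x[2i+1]) by the angle of frequency i."""
    out = []
    for i, (c, s) in enumerate(zip(cos, sin)):
        x0, x1 = x[2 * i], x[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


inv_freq = rotary_inv_freq(4)          # approximately [1.0, 0.01]
table = rotary_cos_sin(8, inv_freq)
q = [1.0, 0.0, 0.5, -0.5]
q_at_3 = apply_rope(q, *table[3])      # rotation preserves the vector's norm
```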

Callers

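No direct callers detected. Like other TensorRT-LLM modules, this forward is normally invoked through the module's __call__ rather than by name.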
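The indexer reports no direct callers, which is typical for `forward` methods: a static tool looking for `.forward(` call sites misses calls made through `__call__`. As far as we know, TensorRT-LLM's module base class dispatches this way, as PyTorch does. The toy classes below are hypothetical and only illustrate the pattern.

```python
class Module:
    """Hypothetical minimal base class: calling an instance dispatches to forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class ToyAttention(Module):
    """Hypothetical layer; a real TensorRT-LLM layer would emit graph ops here."""

    def forward(self, hidden_states, use_cache=False):
        return [2 * h for h in hidden_states], use_cache


layer = ToyAttention()
out, used_cache = layer([1, 2, 3], use_cache=True)  # forward() is never named here
print(out, used_cache)  # [2, 4, 6] True
```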

Calls (15 detected, 12 listed; indexer confidence 0.85 each)

default_net (function)
cast (function)
SpecDecodingParams (class)
MropeParams (class)
gpt_attention (function)
concat (function)
constant (function)
slice (function)
precision (function)
matmul (function)
where (function)
softmax (function)
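Besides `gpt_attention`, the call list includes `matmul`, `where`, and `softmax`. We assume these belong to the part of the method not shown above, where attention is assembled from primitive ops when the GPT attention plugin is disabled. The pure-Python sketch below shows that computation for one head: scaled dot products, an additive ALiBi distance penalty, a causal mask, and a softmax. The slope formula is the standard one for power-of-two head counts. `alibi_slopes` and `causal_alibi_attention` are hypothetical helpers. The sketch ignores quantization scales and the `norm_before_bmm1` option.

```python
import math


def alibi_slopes(num_heads):
    """ALiBi slopes for a power-of-two head count: 2 ** (-8 * (h + 1) / H)."""
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]


def causal_alibi_attention(q, k, v, slope):
    """Single-head attention; q, k, v are lists of row vectors (seq_len x dim)."""
    seq_len, head_dim = len(q), len(q[0])
    out = []
    for i in range(seq_len):
        scores = []
        for j in range(seq_len):
            if j > i:
                scores.append(-math.inf)  # causal mask: future keys are excluded
            else:
                dot = sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(head_dim)
                scores.append(dot + slope * (j - i))  # linear penalty on distance
        m = max(scores)                               # subtract max for stability
        weights = [math.exp(s - m) for s in scores]   # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / total
                    for d in range(len(v[0]))])
    return out


slopes = alibi_slopes(8)   # 0.5, 0.25, ..., 2 ** -8
q = k = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
flat = causal_alibi_attention(q, k, v, slope=0.0)          # plain causal average
biased = causal_alibi_attention(q, k, v, slope=slopes[0])  # favors nearby keys
```

With identical (zero) queries and keys, the unbiased run averages all visible values, while the ALiBi run shifts weight toward the most recent positions.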

Tested by

No test coverage detected.