Repository: github.com/NVIDIA/TensorRT-LLM

Method forward

tensorrt_llm/quantization/layers.py, lines 1956–2075

Source excerpt, lines 1954–2013 of the file (the method continues through line 2075; the remainder is not shown):

        self.use_lora = False

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        spec_decoding_params=None,
        encoder_output=None,
        position_embedding=None,
        norm_before_bmm1=False,
        lora_layer_params=None,
        all_reduce_params: Optional[AllReduceParams] = None,
    ):
        assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
        qkv = self.qkv(hidden_states)

        alibi_slopes = None
        if self.position_embedding_type == PositionEmbeddingType.alibi:
            alibi_slopes = self.alibi_slopes.value
            dtype = trt.float32
            if default_net().plugin_config.gpt_attention_plugin or default_net(
            ).plugin_config.inflight_batching_gpt_attention_plugin:
                dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling(
                ) else hidden_states[0].dtype
                if dtype == trt.int8:
                    dtype = trt.float16
            alibi_slopes = cast(alibi_slopes, dtype)

        if spec_decoding_params is None:
            spec_decoding_params = SpecDecodingParams()

        assert default_net().plugin_config.gpt_attention_plugin

        assert attention_params.is_valid(
            default_net().plugin_config.gpt_attention_plugin,
            default_net().plugin_config.remove_input_padding, use_cache)
        if use_cache:
            assert kv_cache_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin)
        assert self.attention_mask_type == AttentionMaskType.causal, \
            'Plugin only support masked MHA.'
        if self.kv_cache_scaling_factor is not None:
            kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
            kv_quant_orig_scale = self.kv_cache_scaling_factor.value
        else:
            kv_orig_quant_scale = None
            kv_quant_orig_scale = None
        if self.position_embedding_type.is_rope():
            rotary_inv_freq = attention_params.rotary_inv_freq
            rotary_cos_sin = attention_params.embed_positions_for_gpt_attention
        else:
            rotary_inv_freq = None
            rotary_cos_sin = None
        context, past_key_value = gpt_attention(
            qkv=qkv,
            past_key_value=kv_cache_params.get_first_past_key_value(),
            sequence_length=attention_params.sequence_length,
            # ... (excerpt ends here; remaining arguments not shown)
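
The ALiBi branch in the excerpt reads precomputed slopes from `self.alibi_slopes` and casts them to a target dtype: float32 by default, or, when either attention-plugin flag in `plugin_config` is set, the activation dtype (taken from `hidden_states` under static activation scaling and from `hidden_states[0]` otherwise), with int8 promoted to float16. The excerpt does not show how the slopes are initialized. The sketch below assumes the standard ALiBi slopes for a power-of-two head count, 2^(-8h/n) for h = 1..n, and uses a hypothetical `DType` enum in place of the `trt` dtypes; `alibi_slopes` and `select_alibi_dtype` are illustrative helpers, not TensorRT-LLM APIs.

```python
from enum import Enum


class DType(Enum):
    """Hypothetical stand-in for the tensorrt dtypes referenced in the excerpt."""
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    INT8 = "int8"


def alibi_slopes(num_heads: int) -> list[float]:
    """Standard ALiBi slopes 2 ** (-8h / n), h = 1..n (assumed, not taken from the layer).

    Only power-of-two head counts are handled in this sketch.
    """
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    return [2.0 ** (-8.0 * h / num_heads) for h in range(1, num_heads + 1)]


def select_alibi_dtype(plugin_enabled: bool, activation_dtype: DType) -> DType:
    """Mirror the excerpt's branch: float32 by default; with an attention plugin,
    follow the activation dtype, but promote int8 to float16."""
    dtype = DType.FLOAT32
    if plugin_enabled:
        dtype = activation_dtype
        if dtype == DType.INT8:
            dtype = DType.FLOAT16
    return dtype


# Example: int8 activations with the attention plugin enabled.
slopes = alibi_slopes(8)
target = select_alibi_dtype(plugin_enabled=True, activation_dtype=DType.INT8)
print(target.value, slopes)
```

The int8 promotion is presumably there because fractional slopes such as 0.5 or 1/256 would all become 0 as integers.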

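The excerpt forwards `attention_params.rotary_inv_freq` and `attention_params.embed_positions_for_gpt_attention` (as `rotary_inv_freq` and `rotary_cos_sin`) whenever `self.position_embedding_type.is_rope()` is true, and uses `None` otherwise; the values themselves are built elsewhere. As a hedged sketch of what such inputs typically hold, the code below computes standard RoPE inverse frequencies, inv_freq[i] = base^(-2i/d), caches (cos, sin) per position, and applies the cached angles to one vector. The base of 10000, the table layout, and the interleaved pairing of dimensions are assumptions (some models rotate split halves instead), and none of the function names are TensorRT-LLM APIs.

```python
import math


def rope_inv_freq(rotary_dim: int, base: float = 10000.0) -> list[float]:
    """Standard RoPE inverse frequencies: base ** (-2i / rotary_dim), i = 0..d/2-1."""
    if rotary_dim <= 0 or rotary_dim % 2:
        raise ValueError("rotary_dim must be a positive even number")
    return [base ** (-2.0 * i / rotary_dim) for i in range(rotary_dim // 2)]


def rope_cos_sin(max_positions: int,
                 inv_freq: list[float]) -> list[list[tuple[float, float]]]:
    """Precomputed (cos, sin) of position * inv_freq[i] for every position.

    The nested-list layout is an assumption made for readability.
    """
    return [[(math.cos(p * f), math.sin(p * f)) for f in inv_freq]
            for p in range(max_positions)]


def apply_rope(x: list[float], position: int,
               table: list[list[tuple[float, float]]]) -> list[float]:
    """Rotate interleaved pairs (x[2i], x[2i+1]) by the cached angle for `position`."""
    out: list[float] = []
    for i, (c, s) in enumerate(table[position]):
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out


inv_freq = rope_inv_freq(rotary_dim=4)
table = rope_cos_sin(max_positions=16, inv_freq=inv_freq)
rotated = apply_rope([1.0, 0.0, 0.5, -0.5], position=3, table=table)
```

Caching cos/sin per position trades a small table for not recomputing trigonometric functions at every decoding step; each pair rotation preserves the vector's norm.
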
Callers

nothing calls this directly

Calls (10 reported, 8 listed)

default_net · Function · 0.85
cast · Function · 0.85
SpecDecodingParams · Class · 0.85
gpt_attention · Function · 0.85
quantize_fp8_per_token · Function · 0.85
is_rope · Method · 0.80
is_valid · Method · 0.45
has_fp8_rowwise · Method · 0.45
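
Two scale conventions show up around this method. In the excerpt, `kv_orig_quant_scale` is read from `self.kv_cache_rcp_scaling_factor` (a reciprocal) and `kv_quant_orig_scale` from `self.kv_cache_scaling_factor`; the names suggest that the first maps original values into the quantized domain and the second maps them back. The call list above also includes `quantize_fp8_per_token` and `has_fp8_rowwise`, which belong to the part of the method not shown and, judging by their names, deal with per-token (row-wise) FP8 scaling. The plain-Python sketch below illustrates both ideas under stated assumptions: symmetric quantization, an int8 range of ±127 for the per-tensor pair, and the FP8 E4M3 maximum of 448 for per-token scales. It simulates scaling and clamping only, not FP8 rounding, and its helper names are hypothetical, not TensorRT-LLM APIs.

```python
INT8_MAX = 127
FP8_E4M3_MAX = 448.0  # largest finite magnitude in the FP8 E4M3 format


def quantize(values: list[float], orig_quant_scale: float) -> list[int]:
    """Original -> int8 domain: round(x * orig_quant_scale), clamped to [-127, 127].

    Python's round() breaks ties to even.
    """
    return [max(-INT8_MAX, min(INT8_MAX, round(v * orig_quant_scale))) for v in values]


def dequantize(codes: list[int], quant_orig_scale: float) -> list[float]:
    """int8 -> original domain: q * quant_orig_scale."""
    return [q * quant_orig_scale for q in codes]


def per_token_scale(rows: list[list[float]], fmt_max: float = FP8_E4M3_MAX
                    ) -> tuple[list[float], list[list[float]]]:
    """Hypothetical per-token scaling: one dequantization scale per row, amax / fmt_max,
    plus each row divided by its own scale and clamped to [-fmt_max, fmt_max].
    Only scaling is simulated; real FP8 rounding is not."""
    scales, scaled = [], []
    for row in rows:
        amax = max(abs(v) for v in row)
        s = amax / fmt_max if amax > 0 else 1.0
        scales.append(s)
        scaled.append([max(-fmt_max, min(fmt_max, v / s)) for v in row])
    return scales, scaled


# A per-tensor pair in the excerpt's style: quant->orig = s, orig->quant = 1/s.
s = 0.02
codes = quantize([0.5, -1.0, 3.0], orig_quant_scale=1.0 / s)  # 3.0 saturates at 127
restored = dequantize(codes, quant_orig_scale=s)

# Per-token scales: each row gets its own scale.
scales, scaled = per_token_scale([[1.0, -2.0], [0.01, 0.02]])
```

Per-token scales let a token with small activations use the full FP8 range instead of sharing one scale with the largest outlier in the tensor.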

Tested by

no test coverage detected