(
self,
hidden_states: Tensor,
attention_mask=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
spec_decoding_params=None,
encoder_output=None,
position_embedding=None,
norm_before_bmm1=False,
lora_layer_params=None,
all_reduce_params: Optional[AllReduceParams] = None,
)
| 1954 | self.use_lora = False |
| 1955 | |
| 1956 | def forward( |
| 1957 | self, |
| 1958 | hidden_states: Tensor, |
| 1959 | attention_mask=None, |
| 1960 | use_cache=False, |
| 1961 | kv_cache_params=None, |
| 1962 | attention_params=None, |
| 1963 | spec_decoding_params=None, |
| 1964 | encoder_output=None, |
| 1965 | position_embedding=None, |
| 1966 | norm_before_bmm1=False, |
| 1967 | lora_layer_params=None, |
| 1968 | all_reduce_params: Optional[AllReduceParams] = None, |
| 1969 | ): |
| 1970 | assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now" |
| 1971 | qkv = self.qkv(hidden_states) |
| 1972 | |
| 1973 | alibi_slopes = None |
| 1974 | if self.position_embedding_type == PositionEmbeddingType.alibi: |
| 1975 | alibi_slopes = self.alibi_slopes.value |
| 1976 | dtype = trt.float32 |
| 1977 | if default_net().plugin_config.gpt_attention_plugin or default_net( |
| 1978 | ).plugin_config.inflight_batching_gpt_attention_plugin: |
| 1979 | dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling( |
| 1980 | ) else hidden_states[0].dtype |
| 1981 | if dtype == trt.int8: |
| 1982 | dtype = trt.float16 |
| 1983 | alibi_slopes = cast(alibi_slopes, dtype) |
| 1984 | |
| 1985 | if spec_decoding_params is None: |
| 1986 | spec_decoding_params = SpecDecodingParams() |
| 1987 | |
| 1988 | assert default_net().plugin_config.gpt_attention_plugin |
| 1989 | |
| 1990 | assert attention_params.is_valid( |
| 1991 | default_net().plugin_config.gpt_attention_plugin, |
| 1992 | default_net().plugin_config.remove_input_padding, use_cache) |
| 1993 | if use_cache: |
| 1994 | assert kv_cache_params.is_valid( |
| 1995 | default_net().plugin_config.gpt_attention_plugin) |
| 1996 | assert self.attention_mask_type == AttentionMaskType.causal, \ |
| 1997 | 'Plugin only support masked MHA.' |
| 1998 | if self.kv_cache_scaling_factor is not None: |
| 1999 | kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value |
| 2000 | kv_quant_orig_scale = self.kv_cache_scaling_factor.value |
| 2001 | else: |
| 2002 | kv_orig_quant_scale = None |
| 2003 | kv_quant_orig_scale = None |
| 2004 | if self.position_embedding_type.is_rope(): |
| 2005 | rotary_inv_freq = attention_params.rotary_inv_freq |
| 2006 | rotary_cos_sin = attention_params.embed_positions_for_gpt_attention |
| 2007 | else: |
| 2008 | rotary_inv_freq = None |
| 2009 | rotary_cos_sin = None |
| 2010 | context, past_key_value = gpt_attention( |
| 2011 | qkv=qkv, |
| 2012 | past_key_value=kv_cache_params.get_first_past_key_value(), |
| 2013 | sequence_length=attention_params.sequence_length, |
Nothing calls this directly.
No test coverage was detected.