(
self,
hidden_states: Tensor,
attention_mask=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
spec_decoding_params=None,
mrope_params=None,
encoder_output=None,
position_embedding=None,
norm_before_bmm1=False,
lora_layer_params=None,
all_reduce_params: Optional[AllReduceParams] = None,
)
| 2580 | self.use_lora = False |
| 2581 | |
| 2582 | def forward( |
| 2583 | self, |
| 2584 | hidden_states: Tensor, |
| 2585 | attention_mask=None, |
| 2586 | use_cache=False, |
| 2587 | kv_cache_params=None, |
| 2588 | attention_params=None, |
| 2589 | spec_decoding_params=None, |
| 2590 | mrope_params=None, |
| 2591 | encoder_output=None, |
| 2592 | position_embedding=None, |
| 2593 | norm_before_bmm1=False, |
| 2594 | lora_layer_params=None, |
| 2595 | all_reduce_params: Optional[AllReduceParams] = None, |
| 2596 | ): |
| 2597 | assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now" |
| 2598 | qkv = self.qkv(hidden_states) |
| 2599 | |
| 2600 | alibi_slopes = None |
| 2601 | if self.position_embedding_type == PositionEmbeddingType.alibi: |
| 2602 | alibi_slopes = self.alibi_slopes.value |
| 2603 | dtype = trt.float32 |
| 2604 | if default_net().plugin_config.gpt_attention_plugin or default_net( |
| 2605 | ).plugin_config.inflight_batching_gpt_attention_plugin: |
| 2606 | dtype = hidden_states.dtype if self.quant_mode.has_act_static_scaling( |
| 2607 | ) else hidden_states[0].dtype |
| 2608 | if dtype == trt.int8: |
| 2609 | dtype = trt.float16 |
| 2610 | alibi_slopes = cast(alibi_slopes, dtype) |
| 2611 | |
| 2612 | if spec_decoding_params is None: |
| 2613 | spec_decoding_params = SpecDecodingParams() |
| 2614 | |
| 2615 | if mrope_params is None: |
| 2616 | mrope_params = MropeParams() |
| 2617 | |
| 2618 | if default_net().plugin_config.gpt_attention_plugin: |
| 2619 | |
| 2620 | assert attention_params.is_valid( |
| 2621 | default_net().plugin_config.gpt_attention_plugin, |
| 2622 | default_net().plugin_config.remove_input_padding, use_cache) |
| 2623 | if use_cache: |
| 2624 | assert kv_cache_params.is_valid( |
| 2625 | default_net().plugin_config.gpt_attention_plugin) |
| 2626 | assert self.attention_mask_type == AttentionMaskType.causal, \ |
| 2627 | 'Plugin only support masked MHA.' |
| 2628 | if self.kv_cache_scaling_factor is not None: |
| 2629 | kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value |
| 2630 | kv_quant_orig_scale = self.kv_cache_scaling_factor.value |
| 2631 | else: |
| 2632 | kv_orig_quant_scale = None |
| 2633 | kv_quant_orig_scale = None |
| 2634 | if self.position_embedding_type.is_rope(): |
| 2635 | rotary_inv_freq = attention_params.rotary_inv_freq |
| 2636 | rotary_cos_sin = attention_params.embed_positions_for_gpt_attention |
| 2637 | else: |
| 2638 | rotary_inv_freq = None |
| 2639 | rotary_cos_sin = None |
Nothing calls `forward` directly.
No test coverage was detected for it.