(self,
hidden_states,
use_cache=False,
attention_mask=None,
kv_cache_params=None,
attention_params=None,
mrope_params=None,
position_ids=None,
lora_params=None,
spec_decoding_params=None,
vision_token_mask=None)
| 579 | super().__init__([cls(config, idx) for idx in self.layer_list]) |
| 580 | |
| 581 | def forward(self, |
| 582 | hidden_states, |
| 583 | use_cache=False, |
| 584 | attention_mask=None, |
| 585 | kv_cache_params=None, |
| 586 | attention_params=None, |
| 587 | mrope_params=None, |
| 588 | position_ids=None, |
| 589 | lora_params=None, |
| 590 | spec_decoding_params=None, |
| 591 | vision_token_mask=None): |
| 592 | kv_cache_params.fill_none_tensor_list(len(self.layer_list)) |
| 593 | |
| 594 | if use_cache: |
| 595 | presents = [] |
| 596 | |
| 597 | for layer_idx, (layer, past) in enumerate( |
| 598 | zip(self, kv_cache_params.past_key_value)): |
| 599 | |
| 600 | lora_layer_params = None |
| 601 | if lora_params is not None and lora_params.lora_ranks is not None: |
| 602 | lora_layer_params = lora_params.get_layer_params(layer_idx) |
| 603 | |
| 604 | kwargs = {} |
| 605 | if position_ids is not None: |
| 606 | kwargs['position_ids'] = position_ids |
| 607 | if vision_token_mask is not None: |
| 608 | kwargs['vision_token_mask'] = vision_token_mask |
| 609 | if lora_layer_params is not None: |
| 610 | kwargs['lora_layer_params'] = lora_layer_params |
| 611 | if spec_decoding_params is not None: |
| 612 | kwargs['spec_decoding_params'] = spec_decoding_params |
| 613 | if mrope_params is not None: |
| 614 | kwargs['mrope_params'] = mrope_params |
| 615 | |
| 616 | if default_net().plugin_config.reduce_fusion: |
| 617 | if layer_idx + self.layer_list[0] < self.layer_list[-1]: |
| 618 | qkv_activation_scaling_factor = None |
| 619 | if default_net().plugin_config.user_buffer: |
| 620 | qkv_linear = self[layer_idx + 1].attention.qkv |
| 621 | if self.quant_mode.has_fp8_qdq(): |
| 622 | qkv_activation_scaling_factor = constant( |
| 623 | qkv_linear.activation_scaling_factor.raw_value. |
| 624 | copy()) |
| 625 | elif self.quant_mode.has_nvfp4(): |
| 626 | qkv_activation_scaling_factor = constant( |
| 627 | qkv_linear.activation_global_scaling_factor. |
| 628 | raw_value.copy()) |
| 629 | kwargs['next_layer_input_layernorm_args'] = ( |
| 630 | self[layer_idx + 1].input_layernorm.weight.value, |
| 631 | self[layer_idx + 1].input_layernorm.eps, |
| 632 | qkv_activation_scaling_factor) |
| 633 | else: |
| 634 | kwargs['next_layer_input_layernorm_args'] = None |
| 635 | elif default_net().plugin_config.norm_quant_fusion: |
| 636 | if layer_idx < self.layer_list[-1] - self.layer_list[0]: |
| 637 | try: |
| 638 | activation_scaling_factor = constant( |
no test coverage detected