(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = True,
deterministic: bool = True,
)
| 702 | |
| 703 | @nn.compact |
| 704 | def __call__( |
| 705 | self, |
| 706 | hidden_states: jnp.ndarray, |
| 707 | attention_mask: jnp.ndarray, |
| 708 | encoder_hidden_states: Optional[jnp.ndarray] = None, |
| 709 | encoder_attention_mask: Optional[jnp.ndarray] = None, |
| 710 | init_cache: bool = False, |
| 711 | output_attentions: bool = True, |
| 712 | deterministic: bool = True, |
| 713 | ) -> Tuple[jnp.ndarray]: |
| 714 | |
| 715 | if self.config.use_scan: |
| 716 | hidden_states = hidden_states[0] |
| 717 | |
| 718 | res_gain = ( |
| 719 | deepnet_gain["decoder"]["alpha"](self.config) |
| 720 | if self.config.use_deepnet_scaling |
| 721 | else 1 |
| 722 | ) |
| 723 | |
| 724 | embed_dim = self.config.d_model |
| 725 | residual = hidden_states |
| 726 | |
| 727 | # Self Attention |
| 728 | if self.config.ln_positions in ["normformer", "cogview", "preln"]: |
| 729 | hidden_states = norm( |
| 730 | self.config.ln_type, |
| 731 | dtype=self.dtype, |
| 732 | epsilon=1e-05, |
| 733 | use_scale=self.config.force_ln_scale, |
| 734 | )(hidden_states) |
| 735 | hidden_states, attn_weights = FlaxBartAttention( |
| 736 | config=self.config, |
| 737 | embed_dim=embed_dim, |
| 738 | num_heads=self.config.decoder_attention_heads, |
| 739 | dropout=self.config.attention_dropout, |
| 740 | causal=True, |
| 741 | bias=self.config.use_bias, |
| 742 | dtype=self.dtype, |
| 743 | is_encoder=False, |
| 744 | q_length=self.config.image_length, |
| 745 | k_length=self.config.image_length, |
| 746 | )( |
| 747 | hidden_states=hidden_states, |
| 748 | attention_mask=attention_mask, |
| 749 | init_cache=init_cache, |
| 750 | ) |
| 751 | |
| 752 | if self.config.ln_positions in ["normformer", "swinv2", "cogview"]: |
| 753 | hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)( |
| 754 | hidden_states |
| 755 | ) |
| 756 | hidden_states = nn.Dropout(rate=self.config.dropout)( |
| 757 | hidden_states, deterministic=deterministic |
| 758 | ) |
| 759 | hidden_states = residual * res_gain + hidden_states |
| 760 | if self.config.ln_positions in ["postln"]: |
| 761 | hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)( |
Nothing calls this directly.
No test coverage was detected.