MCP · copy
hub / github.com/borisdayma/dalle-mini / __call__

Method __call__

src/dalle_mini/model/modeling.py:704–842  ·  view source on GitHub ↗
(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    )

Source from the content-addressed store, hash-verified

702
703 @nn.compact
704 def __call__(
705 self,
706 hidden_states: jnp.ndarray,
707 attention_mask: jnp.ndarray,
708 encoder_hidden_states: Optional[jnp.ndarray] = None,
709 encoder_attention_mask: Optional[jnp.ndarray] = None,
710 init_cache: bool = False,
711 output_attentions: bool = True,
712 deterministic: bool = True,
713 ) -> Tuple[jnp.ndarray]:
714
715 if self.config.use_scan:
716 hidden_states = hidden_states[0]
717
718 res_gain = (
719 deepnet_gain["decoder"]["alpha"](self.config)
720 if self.config.use_deepnet_scaling
721 else 1
722 )
723
724 embed_dim = self.config.d_model
725 residual = hidden_states
726
727 # Self Attention
728 if self.config.ln_positions in ["normformer", "cogview", "preln"]:
729 hidden_states = norm(
730 self.config.ln_type,
731 dtype=self.dtype,
732 epsilon=1e-05,
733 use_scale=self.config.force_ln_scale,
734 )(hidden_states)
735 hidden_states, attn_weights = FlaxBartAttention(
736 config=self.config,
737 embed_dim=embed_dim,
738 num_heads=self.config.decoder_attention_heads,
739 dropout=self.config.attention_dropout,
740 causal=True,
741 bias=self.config.use_bias,
742 dtype=self.dtype,
743 is_encoder=False,
744 q_length=self.config.image_length,
745 k_length=self.config.image_length,
746 )(
747 hidden_states=hidden_states,
748 attention_mask=attention_mask,
749 init_cache=init_cache,
750 )
751
752 if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
753 hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
754 hidden_states
755 )
756 hidden_states = nn.Dropout(rate=self.config.dropout)(
757 hidden_states, deterministic=deterministic
758 )
759 hidden_states = residual * res_gain + hidden_states
760 if self.config.ln_positions in ["postln"]:
761 hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(

Callers

nothing calls this directly

Calls 4

FlaxBartAttention (class) · 0.85
GLU (class) · 0.85
FFN (class) · 0.85
norm (function) · 0.70

Tested by

no test coverage detected