| 898 | return FlaxAutoencoderKLOutput(latent_dist=posterior) |
| 899 | |
def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
    """Decode a latent tensor back to sample space.

    Args:
        latents: 4-D latent array. If its last axis does not equal
            ``self.config.latent_channels`` it is treated as channels-first and
            permuted with axes ``(0, 2, 3, 1)`` before decoding.
        deterministic: Passed through unchanged to ``self.decoder``.
        return_dict: When ``True`` (default) wrap the result in
            ``FlaxDecoderOutput``; otherwise return a one-element tuple.

    Returns:
        ``FlaxDecoderOutput(sample=...)`` or ``(sample,)``, where ``sample`` is
        the decoder output permuted with axes ``(0, 3, 1, 2)``.
    """
    already_channels_last = latents.shape[-1] == self.config.latent_channels
    if not already_channels_last:
        # Move the channel axis to the end; the conv stack presumably expects
        # channels-last input (Flax convention) — TODO confirm.
        latents = jnp.transpose(latents, (0, 2, 3, 1))

    decoded = self.decoder(self.post_quant_conv(latents), deterministic=deterministic)

    # Hand the caller a channels-first sample regardless of the input layout.
    sample = jnp.transpose(decoded, (0, 3, 1, 2))

    return FlaxDecoderOutput(sample=sample) if return_dict else (sample,)
| 913 | |
| 914 | def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): |
| 915 | posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) |