(self, sample, deterministic: bool = True, return_dict: bool = True)
| 886 | return self.init(rngs, sample)["params"] |
| 887 | |
def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
    """Encode ``sample`` into a posterior distribution over latents.

    The input is permuted with axes ``(0, 2, 3, 1)``, run through
    ``self.encoder`` and then ``self.quant_conv``; the result is wrapped
    in a ``FlaxDiagonalGaussianDistribution``.

    Args:
        sample: 4-D array; its axis 1 is moved to the last position
            before encoding.
        deterministic: Forwarded unchanged to ``self.encoder``.
        return_dict: When true, wrap the posterior in a
            ``FlaxAutoencoderKLOutput``; otherwise return a 1-tuple.

    Returns:
        ``FlaxAutoencoderKLOutput(latent_dist=posterior)`` if
        ``return_dict`` is true, else ``(posterior,)``.
    """
    # Move axis 1 to the end (presumably channels-first -> channels-last
    # for the Flax layers -- TODO confirm against the encoder definition).
    permuted = jnp.transpose(sample, (0, 2, 3, 1))

    encoded = self.encoder(permuted, deterministic=deterministic)
    latent_dist = FlaxDiagonalGaussianDistribution(self.quant_conv(encoded))

    if return_dict:
        return FlaxAutoencoderKLOutput(latent_dist=latent_dist)
    return (latent_dist,)
| 899 | |
| 900 | def decode(self, latents, deterministic: bool = True, return_dict: bool = True): |
| 901 | if latents.shape[-1] != self.config.latent_channels: |
no test coverage detected