DDPM discretization.
Signature: `discretize(self, x, t)`
| 156 | return jax.vmap(logp_fn)(z) |
| 157 | |
| 158 | def discretize(self, x, t): |
| 159 | """DDPM discretization.""" |
| 160 | timestep = (t * (self.N - 1) / self.T).astype(jnp.int32) |
| 161 | beta = self.discrete_betas[timestep] |
| 162 | alpha = self.alphas[timestep] |
| 163 | sqrt_beta = jnp.sqrt(beta) |
| 164 | f = batch_mul(jnp.sqrt(alpha), x) - x |
| 165 | G = sqrt_beta |
| 166 | return f, G |
| 167 | |
| 168 | |
| 169 | class subVPSDE(SDE): |
Nothing calls this directly.
No test coverage detected.