(self, x, t)
| 192 | return drift, diffusion |
| 193 | |
def marginal_prob(self, x, t):
    """Return the mean and standard deviation of the marginal p_t(x(t) | x(0)).

    Args:
      x: The starting data x(0). Presumably shaped (batch, ...) -- TODO confirm.
      t: Diffusion time(s). Presumably shaped (batch,), so that `batch_mul`
        scales each example by its own coefficient -- TODO confirm.

    Returns:
      A tuple ``(mean, std)`` with ``mean = exp(c) * x`` (per batch element via
      `batch_mul`) and ``std = 1 - exp(2c)``, where ``c`` is
      ``log_mean_coeff`` below.
    """
    # c(t) = -1/4 t^2 (beta_1 - beta_0) - 1/2 t beta_0, i.e.
    # -1/2 * integral_0^t beta(s) ds for the linear schedule
    # beta(s) = beta_0 + s (beta_1 - beta_0).
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = batch_mul(jnp.exp(log_mean_coeff), x)
    # 1 - exp(2c) is evaluated as -expm1(2c). For small t, exp(2c) is within a
    # few ulps of 1, so the plain subtraction cancels most significant digits.
    # expm1 keeps full relative precision exactly where std is smallest.
    # No square root is taken. This form appears to match the sub-VP SDE marginal.
    # NOTE(review): if this class is meant to be the plain VP SDE, std should be
    # sqrt(1 - exp(2c)) instead -- confirm against this class's sde().
    std = -jnp.expm1(2. * log_mean_coeff)
    return mean, std
| 199 | |
def prior_sampling(self, rng, shape):
    """Draw a sample from the prior: i.i.d. standard normal noise.

    Args:
      rng: A JAX PRNG key, passed as the `key` of `jax.random.normal`.
      shape: Shape of the array to draw.

    Returns:
      An array of the requested shape with standard normal entries, using
      `jax.random.normal`'s default dtype.
    """
    noise = jax.random.normal(key=rng, shape=shape)
    return noise
nothing calls this directly
no test coverage detected