(self, x, t)
| 141 | return drift, diffusion |
| 142 | |
def marginal_prob(self, x, t):
    """Return the mean and standard deviation of the marginal p_t(x(t) | x(0)).

    With c(t) = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0, the
    marginal has mean exp(c(t)) * x and standard deviation
    sqrt(1 - exp(2 * c(t))). The expression for c(t) equals
    -0.5 * integral_0^t beta(s) ds for the linear schedule
    beta(s) = beta_0 + s * (beta_1 - beta_0).

    Args:
      x: clean samples x(0). NOTE(review): presumably batched along the
        leading axis and paired element-wise with `t` by `batch_mul`
        (defined elsewhere) — confirm against its definition.
      t: time(s) at which to evaluate the marginal.

    Returns:
      A `(mean, std)` tuple.
    """
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = batch_mul(jnp.exp(log_mean_coeff), x)
    # `1 - exp(2c)` cancels catastrophically when c is close to 0 (small t);
    # `-expm1(2c)` evaluates the same quantity without that precision loss.
    std = jnp.sqrt(-jnp.expm1(2. * log_mean_coeff))
    return mean, std
| 148 | |
def prior_sampling(self, rng, shape):
    """Sample from the prior distribution: independent standard normals.

    Args:
      rng: JAX PRNG key used for sampling.
      shape: shape of the array to draw.

    Returns:
      An array of `shape` with entries drawn from N(0, 1).
    """
    noise = jax.random.normal(rng, shape)
    return noise
nothing calls this directly
no test coverage detected