(self, x, t)
| 135 | return 1 |
| 136 | |
def sde(self, x, t):
    """Return the ``(drift, diffusion)`` pair for state ``x`` at time ``t``.

    The noise rate is linear in ``t``: it equals ``self.beta_0`` at ``t = 0``
    and ``self.beta_1`` at ``t = 1``.

    Args:
        x: Current state.
        t: Time value(s) at which to evaluate the coefficients.

    Returns:
        A tuple ``(drift, diffusion)`` where ``drift`` is
        ``-0.5 * batch_mul(beta, x)`` and ``diffusion`` is ``sqrt(beta)``,
        with ``beta`` the noise rate at ``t``.
    """
    # Linear schedule: beta_0 + t * (beta_1 - beta_0).
    beta_span = self.beta_1 - self.beta_0
    beta_now = self.beta_0 + t * beta_span
    # NOTE(review): batch_mul is defined outside this block; presumably it
    # scales each batch element of x by the matching entry of beta_now -- confirm.
    return -0.5 * batch_mul(beta_now, x), jnp.sqrt(beta_now)
| 142 | |
| 143 | def marginal_prob(self, x, t): |
| 144 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 |
nothing calls this directly
no test coverage detected