(self, x, t)
| 185 | return 1 |
| 186 | |
def sde(self, x, t):
    """Return the ``(drift, diffusion)`` coefficients of the forward SDE at time ``t``.

    The noise rate is linear in time:
    ``beta(t) = beta_0 + t * (beta_1 - beta_0)``.

    With ``B(t) = beta_0 * t + 0.5 * (beta_1 - beta_0) * t**2`` (the integral
    of ``beta`` from 0 to ``t``), the coefficients are:

      drift     = -0.5 * beta(t) * x
      diffusion = sqrt(beta(t) * (1 - exp(-2 * B(t))))

    The ``1 - exp(...)`` factor is 0 at ``t = 0``, so the diffusion vanishes
    there. NOTE(review): this looks like the sub-VP SDE of Song et al. (2021);
    confirm against the enclosing class, which is outside this view.

    Args:
      x: Current state. NOTE(review): presumably a batch whose leading axis
        lines up with ``t``. ``batch_mul`` is defined elsewhere; confirm how
        it broadcasts ``beta_t`` over ``x``.
      t: Time value(s) at which to evaluate the coefficients.

    Returns:
      A tuple ``(drift, diffusion)``.
    """
    slope = self.beta_1 - self.beta_0
    beta_t = self.beta_0 + t * slope
    drift_term = -0.5 * batch_mul(beta_t, x)

    # This exponent equals -2 * B(t), i.e. minus twice the integrated beta.
    neg_two_integral = -2 * self.beta_0 * t - slope * t ** 2
    shrink = 1. - jnp.exp(neg_two_integral)

    return drift_term, jnp.sqrt(beta_t * shrink)
| 193 | |
| 194 | def marginal_prob(self, x, t): |
| 195 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 |
nothing calls this directly
no test coverage detected