Create the drift and diffusion functions for the reverse SDE/ODE.
(self, x, t)
| 93 | return T |
| 94 | |
def sde(self, x, t):
    """Evaluate the drift and diffusion terms of the reverse SDE/ODE at (x, t).

    `sde_fn`, `score_fn` and `batch_mul` are resolved from the surrounding
    scope; they are not defined in this block.

    Args:
        x: State passed unchanged to `sde_fn` and `score_fn`.
        t: Time passed unchanged to `sde_fn` and `score_fn`.

    Returns:
        A `(drift, diffusion)` tuple. When `self.probability_flow` is truthy
        the score correction is halved and the returned diffusion is all
        zeros (the ODE case); otherwise the full score correction is applied
        and the diffusion from `sde_fn` is returned unchanged.
    """
    forward_drift, forward_diffusion = sde_fn(x, t)

    # The ODE variant applies only half of the score correction.
    score_weight = 0.5 if self.probability_flow else 1.
    weighted_score = score_fn(x, t) * score_weight

    # NOTE(review): batch_mul presumably broadcasts the per-sample squared
    # diffusion across the remaining axes of the score -- confirm at its
    # definition.
    reverse_drift = forward_drift - batch_mul(forward_diffusion ** 2, weighted_score)

    if self.probability_flow:
        # ODEs carry no stochastic term, so the diffusion is zeroed out.
        return reverse_drift, jnp.zeros_like(forward_diffusion)
    return reverse_drift, forward_diffusion
| 103 | |
| 104 | def discretize(self, x, t): |
| 105 | """Create discretized iteration rules for the reverse diffusion sampler.""" |