(self, rng, x, t)
| 203 | super().__init__(sde, score_fn, probability_flow) |
| 204 | |
def update_fn(self, rng, x, t):
  """Advance `x` by one Euler-Maruyama step of `self.rsde`.

  The step size is fixed at `-1 / self.rsde.N`. The negative sign means the
  update moves backward in t. The noise term is scaled by the square root
  of the step's magnitude.

  Args:
    rng: PRNG key passed to `random.normal` to draw the Gaussian noise.
    x: current state. The noise is drawn with `x.shape`.
    t: current time, forwarded to `self.rsde.sde`.

  Returns:
    A pair `(x_next, x_mean)`:
    - `x_next` is the stochastic update.
    - `x_mean` is the same update without the noise term.
  """
  step = -1. / self.rsde.N
  # Draw the Gaussian noise first, with the same shape as the state.
  noise = random.normal(rng, x.shape)
  f, g = self.rsde.sde(x, t)
  # Deterministic drift part of the step.
  mean = x + f * step
  # NOTE(review): batch_mul presumably broadcasts `g` (one value per batch
  # element) against the noise tensor; verify against its definition.
  scaled_noise = jnp.sqrt(-step) * noise
  x_next = mean + batch_mul(g, scaled_noise)
  return x_next, mean
| 212 | |
| 213 | |
| 214 | @register_predictor(name='reverse_diffusion') |