Create discretized iteration rules for the reverse diffusion sampler.
`discretize(self, x, t)`
| 102 | return drift, diffusion |
| 103 | |
def discretize(self, x, t):
    """Build the discretized update rule used by the reverse diffusion sampler.

    Starts from the forward-process discretization returned by
    ``discretize_fn`` and corrects its drift with the score:

        rev_f = f - G**2 * score_fn(x, t) * c
        c = 0.5 if self.probability_flow else 1.0

    When ``self.probability_flow`` is set, the returned diffusion term is
    replaced by zeros; otherwise the forward ``G`` is passed through unchanged.

    ``discretize_fn``, ``score_fn`` and ``batch_mul`` are not defined in this
    block; they are resolved from the enclosing scope.

    Args:
      x: current sample (shape/dtype not visible here; presumably batched,
        since the G**2 factor is applied through ``batch_mul`` — TODO confirm).
      t: current time value(s) passed through to ``discretize_fn``/``score_fn``.

    Returns:
      Tuple ``(rev_f, rev_G)``: the reverse-time drift term and diffusion term.
    """
    drift_step, noise_scale = discretize_fn(x, t)
    score = score_fn(x, t)
    use_flow = self.probability_flow

    # The probability-flow variant applies only half of the score correction.
    weighted_score = score * (0.5 if use_flow else 1.)
    reverse_drift = drift_step - batch_mul(noise_scale ** 2, weighted_score)

    if use_flow:
        # Probability-flow mode: diffusion term is zeroed out.
        reverse_noise = jnp.zeros_like(noise_scale)
    else:
        reverse_noise = noise_scale
    return reverse_drift, reverse_noise
| 110 | |
| 111 | return RSDE() |
| 112 |