hub / github.com/LuChengTHU/dpm-solver / update_fn

Method update_fn

examples/score_sde_jax/sampling.py:205–211
(self, rng, x, t)

Source from the content-addressed store, hash-verified

203      super().__init__(sde, score_fn, probability_flow)
204
205    def update_fn(self, rng, x, t):
206      dt = -1. / self.rsde.N
207      z = random.normal(rng, x.shape)
208      drift, diffusion = self.rsde.sde(x, t)
209      x_mean = x + drift * dt
210      x = x_mean + batch_mul(diffusion, jnp.sqrt(-dt) * z)
211      return x, x_mean
212
213
214  @register_predictor(name='reverse_diffusion')
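
The step listed above can be tried in isolation. Below is a minimal sketch, assuming a toy reverse SDE: `ToyReverseSDE`, its coefficients, and the free-function form of `update_fn` are hypothetical stand-ins, and `batch_mul` is re-implemented with `jax.vmap` as an assumption about what the repository helper does. It shows the two returned quantities: `x_mean = x + drift * dt`, the noise-free update, and `x`, which adds `diffusion * sqrt(-dt) * z` with `z` drawn from a standard normal. `dt` is negative, consistent with stepping a reverse-time SDE.

```python
import jax
import jax.numpy as jnp
from jax import random


def batch_mul(a, b):
  # Assumed behavior: multiply each example b[i] by its per-example scalar a[i].
  return jax.vmap(lambda a_i, b_i: a_i * b_i)(a, b)


class ToyReverseSDE:
  """Hypothetical stand-in for `self.rsde`; not the repository's class."""

  def __init__(self, N=1000):
    self.N = N  # number of discretization steps

  def sde(self, x, t):
    # Toy coefficients chosen only for illustration.
    drift = -0.5 * x          # same shape as x
    diffusion = jnp.sqrt(t)   # one scalar per batch element
    return drift, diffusion


def update_fn(rsde, rng, x, t):
  # Same arithmetic as the listed method, with `rsde` passed explicitly.
  dt = -1. / rsde.N
  z = random.normal(rng, x.shape)
  drift, diffusion = rsde.sde(x, t)
  x_mean = x + drift * dt
  x = x_mean + batch_mul(diffusion, jnp.sqrt(-dt) * z)
  return x, x_mean


rng = random.PRNGKey(0)
x = jnp.ones((4, 3))       # batch of 4 samples, 3 features each
t = jnp.full((4,), 0.5)    # one time value per sample
x_next, x_mean = update_fn(ToyReverseSDE(N=1000), rng, x, t)
```

Returning `x_mean` alongside `x` lets a caller use the noise-free value, for example at the final step of a sampling loop, if that is what the surrounding sampler expects.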

Callers

nothing calls this directly

Calls 2

batch_mul · Function · 0.90
sde · Method · 0.45

Tested by

no test coverage detected