Repository: github.com/LuChengTHU/dpm-solver

Method discretize

examples/score_sde_jax/sde_lib.py, lines 158–166

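DDPM discretization: returns the drift f and the diffusion coefficient G for one discrete step of the variance-preserving (VP) SDE, so that the step reproduces the DDPM forward Markov chain.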
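As a worked form of the method, write $\beta_i$ for discrete_betas[i] and assume, as in the upstream score_sde VP SDE, that alphas[i] $= \alpha_i = 1 - \beta_i$ (the alphas table itself is not shown on this page). Plugging the returned pair into one step $x_{i+1} = x_i + f_i(x_i) + G_i z_i$ then gives the DDPM forward transition. The floor reflects the int32 cast, which truncates toward zero and therefore equals the floor for nonnegative t.

```latex
\begin{aligned}
i &= \left\lfloor \frac{t\,(N-1)}{T} \right\rfloor, \qquad \alpha_i = 1 - \beta_i \ \text{(assumed)}, \\
f_i(x) &= \sqrt{\alpha_i}\,x - x, \qquad G_i = \sqrt{\beta_i}, \\
x_{i+1} &= x_i + f_i(x_i) + G_i z_i
         = \sqrt{1 - \beta_i}\,x_i + \sqrt{\beta_i}\,z_i,
         \qquad z_i \sim \mathcal{N}(0, I).
\end{aligned}
```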

(self, x, t): x is a batch of samples with a leading batch axis; t holds one time in [0, T] per example.


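    return jax.vmap(logp_fn)(z)

  def discretize(self, x, t):
    """DDPM discretization."""
    # Map continuous time t in [0, T] to an integer step index in [0, N - 1];
    # astype truncates, so this is floor(t * (N - 1) / T) for t >= 0.
    timestep = (t * (self.N - 1) / self.T).astype(jnp.int32)
    # Per-example beta_i and alpha_i, looked up from the precomputed
    # discrete_betas and alphas tables.
    beta = self.discrete_betas[timestep]
    alpha = self.alphas[timestep]
    sqrt_beta = jnp.sqrt(beta)
    # Drift: f = sqrt(alpha_i) * x - x, scaling each example by its own alpha_i.
    f = batch_mul(jnp.sqrt(alpha), x) - x
    # Diffusion: G = sqrt(beta_i), one scalar per example.
    G = sqrt_beta
    return f, G


class subVPSDE(SDE):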
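To check the algebra numerically, here is a self-contained JAX sketch. ToyVPSDE is a hypothetical stand-in that carries only the attributes discretize reads. Its linear beta schedule (beta_min=0.1, beta_max=20, N=1000, T=1) and the relation alphas = 1 - discrete_betas are assumptions modeled on the upstream score_sde defaults. batch_mul is an assumed re-implementation of the helper listed under Calls. The sketch compares x + f + G z with the direct DDPM step sqrt(1 - beta_i) x + sqrt(beta_i) z.

```python
import jax
import jax.numpy as jnp


def batch_mul(a, b):
  # Assumed re-implementation of the batch_mul helper: scale b[i] by a[i].
  return jax.vmap(lambda a_i, b_i: a_i * b_i)(a, b)


class ToyVPSDE:
  """Hypothetical minimal VP SDE with only what discretize() reads."""

  def __init__(self, beta_min=0.1, beta_max=20.0, N=1000):
    self.N = N
    # Assumed linear schedule of N discrete betas.
    self.discrete_betas = jnp.linspace(beta_min / N, beta_max / N, N)
    # Assumed relation alpha_i = 1 - beta_i.
    self.alphas = 1.0 - self.discrete_betas

  @property
  def T(self):
    return 1.0

  def discretize(self, x, t):
    """DDPM discretization, mirroring the listed method."""
    timestep = (t * (self.N - 1) / self.T).astype(jnp.int32)
    beta = self.discrete_betas[timestep]
    alpha = self.alphas[timestep]
    f = batch_mul(jnp.sqrt(alpha), x) - x
    G = jnp.sqrt(beta)
    return f, G


sde = ToyVPSDE()
key_x, key_z = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 8, 8, 3))  # batch of 4 small "images"
z = jax.random.normal(key_z, x.shape)       # Gaussian noise for the step
t = jnp.array([0.0, 0.25, 0.5, 1.0])        # one time per example

f, G = sde.discretize(x, t)
x_next = x + f + batch_mul(G, z)            # one discrete step built from (f, G)

# The same step written directly as the DDPM forward transition.
idx = (t * (sde.N - 1)).astype(jnp.int32)
beta = sde.discrete_betas[idx]
x_ddpm = batch_mul(jnp.sqrt(1.0 - beta), x) + batch_mul(jnp.sqrt(beta), z)

print(idx, bool(jnp.allclose(x_next, x_ddpm, atol=1e-5)))
```

The int32 cast buckets times: every t with i <= t(N - 1)/T < i + 1 shares index i, so t = 0.25 maps to index 249 rather than 250 when N = 1000.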

Callers

nothing calls this directly

Calls 1

batch_mul · Function · 0.90
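The body of batch_mul is not shown on this page. The sketch below assumes it is a jax.vmap over the leading batch axis that multiplies each example by its own scalar, which is how discretize uses it (sqrt(alpha) carries one value per example). It checks that this assumed form matches plain broadcasting with the scalar vector reshaped to (batch, 1, ..., 1).

```python
import jax
import jax.numpy as jnp


def batch_mul(a, b):
  # Assumed definition: vmap over the leading (batch) axis, so a[i] scales b[i].
  return jax.vmap(lambda a_i, b_i: a_i * b_i)(a, b)


a = jnp.array([1.0, 2.0, 3.0])  # one scalar per example
b = jnp.ones((3, 2, 2))         # batch of three 2x2 arrays
out = batch_mul(a, b)

# Equivalent broadcasting form: append singleton axes to a.
out_broadcast = a.reshape(-1, *([1] * (b.ndim - 1))) * b
print(out[:, 0, 0])
```

The vmap form works for any rank of b without manual reshaping; the broadcast form is the usual alternative when the rank is known.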

Tested by

no test coverage detected