Repository: github.com/LuChengTHU/dpm-solver

Method marginal_prob

examples/score_sde_jax/sde_lib.py, lines 143–147
(self, x, t)


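```python
# Excerpt of examples/score_sde_jax/sde_lib.py, lines 141-150 (inside the class body).
# `jax`, `jnp` (jax.numpy) and `batch_mul` come from the module's imports.
    return drift, diffusion  # last line of the preceding method

  def marginal_prob(self, x, t):
    # log alpha(t) = -1/2 * integral_0^t beta(s) ds, with beta(s) = beta_0 + s * (beta_1 - beta_0)
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    # Scale each example x[i] by its own alpha(t[i])
    mean = batch_mul(jnp.exp(log_mean_coeff), x)
    # Standard deviation of the perturbation: sqrt(1 - alpha(t)^2)
    std = jnp.sqrt(1 - jnp.exp(2. * log_mean_coeff))
    return mean, std

  def prior_sampling(self, rng, shape):
    # Draw from the standard normal prior N(0, I)
    return jax.random.normal(rng, shape)
```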
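Reading the expression for log_mean_coeff, it matches minus one half of the integral of a linear noise schedule β(s) = β₀ + s(β₁ − β₀), with β₀ = self.beta_0 and β₁ = self.beta_1. The linear form is inferred from the code rather than stated on this page. Under that reading, and treating x as the clean sample x(0), the method returns the mean and standard deviation of the Gaussian distribution of x(t) given x(0):

$$
\log\alpha(t) = -\frac{1}{2}\int_0^t \beta(s)\,ds = -\frac{1}{4}t^2(\beta_1-\beta_0) - \frac{1}{2}t\,\beta_0
$$

$$
\text{mean} = \alpha(t)\,x, \qquad \text{std} = \sqrt{1-\alpha(t)^2} = \sqrt{1-e^{2\log\alpha(t)}}
$$

At t = 0 this gives mean = x and std = 0. As t grows, α(t) shrinks and std approaches 1, which is consistent with prior_sampling drawing from a standard normal.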

Callers

nothing calls this directly

Calls 1

batch_mul (Function) · 0.90
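The page lists batch_mul as the only callee but does not show its body. The sketch below assumes it multiplies each batch element of its second argument by the matching scalar of its first. marginal_prob needs exactly this when t, and therefore log_mean_coeff, has one entry per example. Implementing it with jax.vmap is an assumption made for illustration, not the repository's code.

```python
import jax
import jax.numpy as jnp


def batch_mul(a, b):
  """Hypothetical sketch: return c with c[i] = a[i] * b[i] for each batch index i.

  jax.vmap maps over the leading (batch) axis of both arguments, so inside the
  lambda `a_i` is a scalar and `b_i` is one example; ordinary broadcasting applies.
  """
  return jax.vmap(lambda a_i, b_i: a_i * b_i)(a, b)


# Usage: per-example scaling of a batch of two "images" shaped (2, 4, 4, 3).
coeffs = jnp.array([1.0, 0.5])
x = jnp.ones((2, 4, 4, 3))
scaled = batch_mul(coeffs, x)
print(scaled.shape)  # (2, 4, 4, 3)
```

Mapping over the batch axis avoids hard-coding the rank of x. A manual reshape such as `a[:, None, None, None] * b` would only work for 4-D inputs.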

Tested by

no test coverage detected
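Since no test covers the method, a property check might look like the following. vp_marginal_prob is a hypothetical standalone re-statement of the method, not part of the repository. The values beta_0 = 0.1 and beta_1 = 20.0 are assumed for illustration, and explicit broadcasting stands in for batch_mul.

```python
import jax.numpy as jnp


def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20.0):
  """Hypothetical standalone version of marginal_prob (illustrative betas)."""
  log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
  # Reshape (B,) -> (B, 1, ..., 1) so alpha(t) broadcasts against x of shape (B, ...).
  coeff = jnp.exp(log_mean_coeff).reshape((-1,) + (1,) * (x.ndim - 1))
  mean = coeff * x
  std = jnp.sqrt(1 - jnp.exp(2. * log_mean_coeff))
  return mean, std


x = jnp.ones((3, 2))
t = jnp.array([0.0, 0.5, 1.0])
mean, std = vp_marginal_prob(x, t)

# At t = 0 the distribution collapses onto the data: mean == x and std == 0.
assert jnp.allclose(mean[0], x[0]) and float(std[0]) == 0.0
# Variance preserving: for unit-valued data, alpha(t)^2 + std(t)^2 == 1.
assert jnp.allclose(mean[:, 0] ** 2 + std ** 2, 1.0)
# At t = 1, with these betas, almost no signal remains and std is close to 1.
assert float(std[2]) > 0.99
```

The checks encode two properties of the closed form: mean = x and std = 0 at t = 0, and α(t)² + std(t)² = 1 for every t.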