hub / github.com/LuChengTHU/dpm-solver / marginal_prob

Method marginal_prob

examples/score_sde_jax/sde_lib.py:194–198
(self, x, t)

Source from the content-addressed store, hash-verified

192      return drift, diffusion
193
194    def marginal_prob(self, x, t):
195      log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
196      mean = batch_mul(jnp.exp(log_mean_coeff), x)
197      std = 1 - jnp.exp(2. * log_mean_coeff)
198      return mean, std
199
200    def prior_sampling(self, rng, shape):
201      return jax.random.normal(rng, shape)

Callers

nothing calls this directly

Calls 1

batch_mul · Function · 0.90

Tested by

no test coverage detected