Repository: github.com/LuChengTHU/dpm-solver

Function score_fn

examples/score_sde_jax/models/utils.py:214–233
Signature: score_fn(x, t, rng=None)


```python
# Excerpt from inside a factory function. Assumes `jax.numpy as jnp`, `sde_lib`
# and `batch_mul` are in scope; `sde`, `model_fn`, `continuous` and
# `return_state` are closed over from the enclosing function (not shown).
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
  def score_fn(x, t, rng=None):
    # Scale neural network output by standard deviation and flip sign
    if continuous or isinstance(sde, sde_lib.subVPSDE):
      # For VP-trained models, t=0 corresponds to the lowest noise level
      # The maximum value of time embedding is assumed to be 999 for
      # continuously-trained models.
      labels = t * 999
      model, state = model_fn(x, labels, rng)
      std = sde.marginal_prob(jnp.zeros_like(x), t)[1]
    else:
      # For VP-trained models, t=0 corresponds to the lowest noise level
      labels = t * (sde.N - 1)
      model, state = model_fn(x, labels, rng)
      std = sde.sqrt_1m_alphas_cumprod[labels.astype(jnp.int32)]

    score = batch_mul(-model, 1. / std)
    if return_state:
      return score, state
    else:
      return score

elif isinstance(sde, sde_lib.VESDE):
  def score_fn(x, t, rng=None):
    ...  # VESDE branch truncated in this excerpt
```
Callers (6)

sde · method · 0.50
discretize · method · 0.50
total_grad_fn · function · 0.50
loss_fn · function · 0.50
loop_body · method · 0.50
loop_body · method · 0.50
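Judging by their names, the `sde`, `discretize` and `loop_body` callers most likely consume the score during reverse-time sampling. A minimal sketch of that use, assuming the standard reverse-time SDE drift f(x, t) - g(t)^2 * score for a VP SDE with a linear beta schedule and a reverse Euler-Maruyama update; `vp_sde`, `reverse_sde` and `euler_maruyama_step` are hypothetical names, not the repository's API, and `x` is assumed to have shape (batch, dim).

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear-schedule VP SDE


def vp_sde(x, t):
    """Forward VP SDE: f(x, t) = -0.5 * beta(t) * x, g(t) = sqrt(beta(t))."""
    beta_t = BETA_MIN + t * (BETA_MAX - BETA_MIN)
    drift = -0.5 * beta_t[:, None] * x
    diffusion = np.sqrt(beta_t)
    return drift, diffusion


def reverse_sde(score_fn, x, t):
    """Reverse-time SDE: drift f - g^2 * score, same diffusion g."""
    drift, diffusion = vp_sde(x, t)
    drift = drift - (diffusion**2)[:, None] * score_fn(x, t)
    return drift, diffusion


def euler_maruyama_step(score_fn, x, t, dt, rng):
    """One reverse-time Euler-Maruyama step from time t to t - dt (dt > 0)."""
    drift, diffusion = reverse_sde(score_fn, x, t)
    noise = rng.standard_normal(x.shape)
    return x - drift * dt + diffusion[:, None] * np.sqrt(dt) * noise
```

As a sanity check, if the data were standard normal, every VP marginal would also be N(0, I) and the exact score would be `-x`, so the reverse drift reduces to `0.5 * beta(t) * x`.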

Calls (3)

batch_mul · function · 0.90
model_fn · function · 0.70
marginal_prob · method · 0.45
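As background for the `marginal_prob` call and the final `batch_mul(-model, 1. / std)`, the standard score-SDE parameterization for the VP and sub-VP SDEs with a linear schedule is sketched below. The symbols alpha(t), beta_min and beta_max are notation introduced here, and whether `sde_lib` uses exactly this schedule is an assumption. Under this reading, `model_fn` predicts the noise epsilon, and dividing by sigma(t) with a sign flip yields the score.

```latex
\begin{aligned}
\beta(t) &= \beta_{\min} + t\,(\beta_{\max} - \beta_{\min}), \\
\log \alpha(t) &= -\tfrac{1}{4}\, t^2 (\beta_{\max} - \beta_{\min}) - \tfrac{1}{2}\, t\, \beta_{\min}, \\
\sigma_{\mathrm{VP}}(t) &= \sqrt{1 - e^{2 \log \alpha(t)}}, \qquad
\sigma_{\mathrm{subVP}}(t) = 1 - e^{2 \log \alpha(t)}, \\
x_t &= \alpha(t)\, x_0 + \sigma(t)\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \\
\nabla_{x_t} \log p_{0t}(x_t \mid x_0) &= -\frac{x_t - \alpha(t)\, x_0}{\sigma(t)^2} = -\frac{\epsilon}{\sigma(t)}, \\
s_\theta(x_t, t) &= -\frac{\epsilon_\theta(x_t, t)}{\sigma(t)}.
\end{aligned}
```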

Tested by

no test coverage detected