Repository: github.com/LuChengTHU/dpm-solver

Function grad_fn

examples/score_sde_jax/models/utils.py, lines 321–327
Signature: grad_fn(data, ve_noise_scale, labels)
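Reading the source listing on this page, grad_fn returns the gradient, with respect to the input batch, of the classifier's summed log-likelihood of the given labels. Writing x_i for row i of `data`, y_i for `labels[i]`, B for `labels.shape[0]`, sigma_i for the noise level from `ve_noise_scale` (assumed here to be one value per example; a shared scalar works the same way) and z for the output of `logit_fn`, it computes:

```latex
g = \nabla_{\mathbf{x}} \sum_{i=1}^{B} \log p_\phi(y_i \mid \mathbf{x}_i, \sigma_i),
\qquad
\log p_\phi(y \mid \mathbf{x}, \sigma) = z_y - \log \sum_{k} \exp(z_k),
\qquad
\mathbf{z} = \mathrm{logit\_fn}(\mathbf{x}, \sigma)
```

Because the i-th summand depends only on x_i, row i of g is the per-example gradient of log p(y_i | x_i, sigma_i). In class-conditional sampling this is the term Bayes' rule adds to the unconditional score: grad_x log p(x | y) = grad_x log p(x) + grad_x log p(y | x), since log p(y) does not depend on x.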

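  # Excerpt: body of the enclosing factory function, whose `def` line is not shown.
  # `logit_fn` is not defined in this excerpt; it comes from an enclosing scope.
  """Create the gradient function for the classifier in use of class-conditional sampling. """

  def grad_fn(data, ve_noise_scale, labels):
    def prob_fn(data):
      # Logits of the noise-conditional classifier at noise scale `ve_noise_scale`.
      logits = logit_fn(data, ve_noise_scale)
      # Select log p(labels[i] | data[i]) for each example and sum over the batch.
      # Despite its name, `prob` is a log-probability.
      prob = jax.nn.log_softmax(logits, axis=-1)[jnp.arange(labels.shape[0]), labels].sum()
      return prob

    # Gradient of the summed log-probability with respect to the input batch.
    return jax.grad(prob_fn)(data)

  return grad_fn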
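The following is a minimal, self-contained sketch of the same pattern, assuming a toy linear classifier in place of the real noise-conditional `logit_fn`. The names `make_classifier_grad_fn`, `toy_logit_fn` and `W` are hypothetical and do not come from the repository. For linear logits the gradient has a closed form, which makes it easy to check the `jax.grad` result.

```python
import jax
import jax.numpy as jnp


def make_classifier_grad_fn(logit_fn):
  """Hypothetical factory mirroring the excerpt: returns grad_fn(data, ve_noise_scale, labels)."""

  def grad_fn(data, ve_noise_scale, labels):
    def log_prob_fn(data):
      logits = logit_fn(data, ve_noise_scale)
      # Pick log p(y_i | x_i) for each example's label and sum over the batch.
      # Summing is safe: example i's term depends only on data[i], so the
      # gradient w.r.t. data[i] equals that example's own gradient.
      log_probs = jax.nn.log_softmax(logits, axis=-1)
      return log_probs[jnp.arange(labels.shape[0]), labels].sum()

    return jax.grad(log_prob_fn)(data)

  return grad_fn


# Toy classifier (assumption): logits = data @ W / (1 + sigma), one sigma per example.
W = jnp.array([[1.0, -2.0, 0.5],
               [0.0, 3.0, -1.0]])


def toy_logit_fn(data, ve_noise_scale):
  return data @ W / (1.0 + ve_noise_scale[:, None])


grad_fn = make_classifier_grad_fn(toy_logit_fn)

data = jnp.array([[0.2, -0.1], [1.0, 0.5]])
sigma = jnp.array([0.5, 2.0])
labels = jnp.array([2, 0])

g = grad_fn(data, sigma, labels)

# Closed form for linear logits z = x W / c:
#   d/dx log_softmax(z)[y] = (W[:, y] - W @ softmax(z)) / c
probs = jax.nn.softmax(toy_logit_fn(data, sigma), axis=-1)
expected = (W[:, labels].T - probs @ W.T) / (1.0 + sigma[:, None])
print(bool(jnp.allclose(g, expected, atol=1e-5)))  # True
```

The result has the same shape as `data`, which is what a sampler needs in order to add it to a score estimate elementwise.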

Callers (1)

step_fn · Function · 0.85
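How `step_fn` consumes the gradient is not shown on this page. A common pattern in class-conditional score-based sampling is to add the classifier gradient to the unconditional score and take a Langevin-type step. The sketch below illustrates that pattern only; `guided_langevin_step`, `score_fn`, `classifier_grad_fn` and the toy Gaussian prior are hypothetical stand-ins, not the repository's `step_fn`.

```python
import jax
import jax.numpy as jnp


def guided_langevin_step(rng, x, sigma, labels, score_fn, classifier_grad_fn, step_size):
  """One Langevin step on the class-conditional score (hypothetical helper)."""
  # Bayes' rule: grad_x log p(x | y) = grad_x log p(x) + grad_x log p(y | x).
  grad = score_fn(x, sigma) + classifier_grad_fn(x, sigma, labels)
  noise = jax.random.normal(rng, x.shape)
  return x + step_size * grad + jnp.sqrt(2.0 * step_size) * noise


# Toy stand-ins (assumptions): a Gaussian prior score and a linear classifier.
W = jnp.array([[1.0, -2.0, 0.5],
               [0.0, 3.0, -1.0]])


def score_fn(x, sigma):
  # Score of N(0, (1 + sigma^2) I), evaluated per example.
  return -x / (1.0 + sigma[:, None] ** 2)


def classifier_grad_fn(x, sigma, labels):
  def log_prob(x):
    logits = x @ W / (1.0 + sigma[:, None])
    return jax.nn.log_softmax(logits, axis=-1)[jnp.arange(labels.shape[0]), labels].sum()
  return jax.grad(log_prob)(x)


rng = jax.random.PRNGKey(0)
x = jnp.zeros((4, 2))
sigma = jnp.full((4,), 1.0)
labels = jnp.array([0, 1, 2, 1])
x_next = guided_langevin_step(rng, x, sigma, labels, score_fn, classifier_grad_fn, step_size=0.01)
print(x_next.shape)  # (4, 2)
```

Implementations often also multiply the classifier term by a guidance weight to strengthen or weaken conditioning; that weight is omitted here.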

Calls

no outgoing calls

Tested by

no test coverage detected