Function: grad_fn(data, ve_noise_scale, labels)
Source from the content-addressed store, hash-verified
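```python
# Source lines 319-329. The enclosing factory's `def` line is not shown;
# `logit_fn` is bound in that enclosing scope.
    """Create the gradient function for the classifier in use of class-conditional sampling. """

    def grad_fn(data, ve_noise_scale, labels):
        def prob_fn(data):
            logits = logit_fn(data, ve_noise_scale)
            # Despite its name, `prob` is a log-probability: each example's
            # log-softmax score for its own label, summed over the batch.
            prob = jax.nn.log_softmax(logits, axis=-1)[jnp.arange(labels.shape[0]), labels].sum()
            return prob

        # Summing before differentiating still yields per-example gradients
        # (same shape as `data`) provided logit_fn treats examples independently.
        return jax.grad(prob_fn)(data)

    return grad_fn
```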
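What grad_fn returns can be written out explicitly. Here f stands for logit_fn, f(x, σ)_i for row i of its output logits, σ for ve_noise_scale, and (y_i) for labels. The last line is the standard Bayes'-rule identity by which a classifier gradient enters class-conditional score-based sampling; it is included as context and is not taken from this source.

```latex
L(x) = \sum_{i=1}^{B} \log \operatorname{softmax}\big(f(x, \sigma)_i\big)_{y_i}

\text{If } f(x, \sigma)_i \text{ depends only on } x_i:\quad
\frac{\partial L}{\partial x_i} = \nabla_{x_i} \log p_\sigma(y_i \mid x_i)

\nabla_x \log p_\sigma(x \mid y) = \nabla_x \log p_\sigma(x) + \nabla_x \log p_\sigma(y \mid x)
```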
Tested by: no test coverage detected.
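A minimal test sketch for the logic above follows. It assumes a toy noise-conditional linear classifier, logits_i = (x_i W + b) / σ_i. The factory name make_classifier_grad_fn, toy_logit_fn, W, b and the sample data are all hypothetical and not part of the source. The test checks three things: the output has the same shape as data, it matches the closed-form gradient W(onehot(y_i) - softmax(z_i)) / σ_i for this linear model, and it agrees with per-example gradients computed via jax.vmap.

```python
import jax
import jax.numpy as jnp


def make_classifier_grad_fn(logit_fn):
    """Hypothetical factory mirroring the excerpt; returns grad_fn."""

    def grad_fn(data, ve_noise_scale, labels):
        def prob_fn(data):
            logits = logit_fn(data, ve_noise_scale)
            # Sum over the batch of each example's log-probability of its label.
            return jax.nn.log_softmax(logits, axis=-1)[jnp.arange(labels.shape[0]), labels].sum()

        return jax.grad(prob_fn)(data)

    return grad_fn


# Hypothetical toy classifier: linear logits scaled by each example's noise level.
W = jnp.array([[1.0, -2.0, 0.5],
               [0.3, 0.7, -1.0]])  # shape (D=2, C=3)
b = jnp.array([0.1, 0.0, -0.2])


def toy_logit_fn(data, ve_noise_scale):
    return (data @ W + b) / ve_noise_scale[:, None]


grad_fn = make_classifier_grad_fn(toy_logit_fn)

data = jnp.array([[0.5, -1.0], [2.0, 0.3], [-0.4, 0.8]])  # batch B=3, D=2
sigmas = jnp.array([1.0, 2.0, 0.5])
labels = jnp.array([0, 2, 1])

g = grad_fn(data, sigmas, labels)

# Closed form for the linear model: d/dx_i log softmax(z_i)[y_i] = W (onehot - p_i) / sigma_i.
probs = jax.nn.softmax(toy_logit_fn(data, sigmas), axis=-1)
onehot = jax.nn.one_hot(labels, W.shape[1])
expected = ((onehot - probs) @ W.T) / sigmas[:, None]


# Per-example gradients, one example at a time, vectorized with vmap.
def single_log_prob(x, sigma, y):
    logits = toy_logit_fn(x[None], sigma[None])[0]
    return jax.nn.log_softmax(logits)[y]


per_example = jax.vmap(jax.grad(single_log_prob))(data, sigmas, labels)

assert g.shape == data.shape
assert jnp.allclose(g, expected, atol=1e-5)
assert jnp.allclose(g, per_example, atol=1e-5)
```

The vmap comparison only holds because the toy classifier processes each example independently. A classifier whose layers couple examples within a batch (for example, batch statistics computed at inference time) would make the summed gradient differ from true per-example gradients.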