(x, t, rng=None)
| 212 | |
| 213 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): |
def score_fn(x, t, rng=None):
    """Score estimate for a VP / sub-VP SDE model at samples `x` and times `t`.

    The network output is divided by the marginal standard deviation at `t`
    and negated, i.e. score = -model / std (scaled per batch element via
    `batch_mul`).

    Args:
      x: Noisy samples fed to `model_fn`.
      t: Time values; assumed to be one scalar per batch element in [0, 1]
        (TODO confirm against callers). For VP-trained models, t=0 is the
        lowest noise level.
      rng: Optional PRNG key forwarded unchanged to `model_fn`.

    Returns:
      `(score, state)` when the enclosing `return_state` flag is set,
      otherwise just `score`.
    """
    if continuous or isinstance(sde, sde_lib.subVPSDE):
        # Continuously-trained models (and any sub-VP SDE) use a time
        # embedding whose maximum value is assumed to be 999.
        labels = t * 999
        std = sde.marginal_prob(jnp.zeros_like(x), t)[1]
    else:
        # Discrete VP models: map t onto the [0, N-1] grid and read the
        # precomputed sqrt(1 - alpha_bar) table at that index.
        labels = t * (sde.N - 1)
        # NOTE(review): astype(int32) truncates toward zero, so a label that
        # lands just below an integer due to float error selects the previous
        # table entry. Confirm callers always pass t on the exact grid.
        std = sde.sqrt_1m_alphas_cumprod[labels.astype(jnp.int32)]

    # One network evaluation serves both label conventions.
    model, state = model_fn(x, labels, rng)
    score = batch_mul(-model, 1. / std)
    return (score, state) if return_state else score
| 234 | |
| 235 | elif isinstance(sde, sde_lib.VESDE): |
| 236 | def score_fn(x, t, rng=None): |
no test coverage detected