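Compute the loss function. Args: rng: A JAX random state. params: A dictionary that contains trainable parameters of the score-based model. states: A dictionary that contains mutable states of the score-based model. batch: A mini-batch of training data. Returns: loss: A scalar that represents the average loss value across the mini-batch. new_model_state: A dictionary that contains the mutated states of the score-based model.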
(rng, params, states, batch)
| 82 | reduce_op = jnp.mean if reduce_mean else lambda *args, **kwargs: 0.5 * jnp.sum(*args, **kwargs) |
| 83 | |
def loss_fn(rng, params, states, batch):
    """Compute the loss on one mini-batch.

    Args:
      rng: A JAX random state.
      params: A dictionary that contains trainable parameters of the score-based model.
      states: A dictionary that contains mutable states of the score-based model.
      batch: A mini-batch of training data; the samples are read from `batch['image']`.

    Returns:
      loss: A scalar, the mean of the per-example losses over the mini-batch.
      new_model_state: The mutated model states returned by the score function.
    """
    score_fn = mutils.get_score_fn(
        sde, model, params, states,
        train=train, continuous=continuous, return_state=True)
    images = batch['image']
    batch_size = images.shape[0]

    # Split one subkey per random consumer, chained in the order: time draw,
    # noise draw, score-function call.
    rng, time_key = random.split(rng)
    rng, noise_key = random.split(rng)
    _, model_key = random.split(rng)

    # One time value per example, drawn uniformly between `eps` and `sde.T`.
    t = random.uniform(time_key, (batch_size,), minval=eps, maxval=sde.T)
    z = random.normal(noise_key, images.shape)

    # Perturb the data with the noise `z`, scaled per example by `std`.
    mean, std = sde.marginal_prob(images, t)
    perturbed = mean + batch_mul(std, z)
    score, new_model_state = score_fn(perturbed, t, rng=model_key)

    if likelihood_weighting:
        residual = score + batch_mul(z, 1. / std)
    else:
        residual = batch_mul(score, std) + z

    # Collapse every non-batch axis so `reduce_op` yields one value per example.
    squared = jnp.square(residual)
    per_example = reduce_op(squared.reshape((squared.shape[0], -1)), axis=-1)

    if likelihood_weighting:
        # Weight each example by the square of the second value returned by `sde.sde`.
        g2 = sde.sde(jnp.zeros_like(images), t)[1] ** 2
        per_example = per_example * g2

    return jnp.mean(per_example), new_model_state
| 120 | |
| 121 | return loss_fn |
| 122 |