Repository: github.com/LuChengTHU/dpm-solver

Function loss_fn

examples/score_sde_jax/losses.py:84–119

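Compute the loss function. Args: rng: A JAX random state. params: A dictionary that contains trainable parameters of the score-based model. states: A dictionary that contains mutable states of the score-based model. batch: A mini-batch of training data. Returns: loss, a scalar giving the average loss across the mini-batch, and new_model_state, a dictionary that contains the mutated states of the score-based model.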

Signature: loss_fn(rng, params, states, batch)
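Because `loss_fn` returns a `(loss, new_model_state)` pair, it has the shape JAX expects for `jax.value_and_grad(..., has_aux=True)`: the scalar is differentiated and the state is passed through as auxiliary output. The sketch below is illustrative only. `toy_loss_fn`, its single parameter `w`, and the `steps` counter are hypothetical stand-ins rather than code from this repository, and the sketch does not reproduce how `step_fn` actually calls `loss_fn`.

```python
import jax
import jax.numpy as jnp
from jax import random


def toy_loss_fn(rng, params, states, batch):
  """Hypothetical stand-in with the same signature and return structure as loss_fn."""
  data = batch['image']
  rng, step_rng = random.split(rng)
  z = random.normal(step_rng, data.shape)
  # A trivially simple "model": predict the noise as w * (data + z).
  pred = params['w'] * (data + z)
  loss = jnp.mean(jnp.square(pred - z))
  # Mutable state is returned updated (a step counter here, standing in for e.g. batch-norm stats).
  new_states = {'steps': states['steps'] + 1}
  return loss, new_states


# argnums=1 differentiates with respect to params; has_aux=True marks the second output as auxiliary.
grad_fn = jax.value_and_grad(toy_loss_fn, argnums=1, has_aux=True)

rng = random.PRNGKey(0)
params = {'w': jnp.array(0.5)}
states = {'steps': jnp.array(0)}
batch = {'image': jnp.ones((4, 8, 8, 3))}

(loss, new_states), grads = grad_fn(rng, params, states, batch)
```

`argnums=1` selects `params`, the second positional argument, so `grads` has the same dictionary structure as `params`, while `new_states` comes back untouched by differentiation.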


```python
  # Excerpt: loss_fn is a closure. sde, model, train, continuous, eps,
  # likelihood_weighting and reduce_mean come from the enclosing factory function;
  # jnp (jax.numpy), random (jax.random), mutils and batch_mul are module-level names.

  # Per-example reduction over the non-batch dimensions: mean, or half the sum.
  reduce_op = jnp.mean if reduce_mean else lambda *args, **kwargs: 0.5 * jnp.sum(*args, **kwargs)

  def loss_fn(rng, params, states, batch):
    """Compute the loss function.

    Args:
      rng: A JAX random state.
      params: A dictionary that contains trainable parameters of the score-based model.
      states: A dictionary that contains mutable states of the score-based model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
      new_model_state: A dictionary that contains the mutated states of the score-based model.
    """

    score_fn = mutils.get_score_fn(sde, model, params, states, train=train, continuous=continuous, return_state=True)
    data = batch['image']

    # Sample one diffusion time per example, t ~ U(eps, T).
    rng, step_rng = random.split(rng)
    t = random.uniform(step_rng, (data.shape[0],), minval=eps, maxval=sde.T)
    # Perturb the data with the SDE's Gaussian transition kernel: x_t = mean + std * z.
    rng, step_rng = random.split(rng)
    z = random.normal(step_rng, data.shape)
    mean, std = sde.marginal_prob(data, t)
    perturbed_data = mean + batch_mul(std, z)
    rng, step_rng = random.split(rng)
    score, new_model_state = score_fn(perturbed_data, t, rng=step_rng)

    if not likelihood_weighting:
      # ||std * score + z||^2: the residual against the target score -z / std, weighted by std^2.
      losses = jnp.square(batch_mul(score, std) + z)
      losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
    else:
      # g(t)^2 * ||score + z / std||^2, with g(t) read from sde.sde evaluated at x = 0.
      g2 = sde.sde(jnp.zeros_like(data), t)[1] ** 2
      losses = jnp.square(score + batch_mul(z, 1. / std))
      losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) * g2

    # Average the per-example losses over the mini-batch.
    loss = jnp.mean(losses)
    return loss, new_model_state

  return loss_fn
```
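The two branches apply different weightings to the same denoising score-matching residual. Writing `mean` as $\mu_t(x_0)$, `std` as $\sigma_t$, the second output of `sde.sde` as $g(t)$, the model output as $s_\theta$, and $\operatorname{reduce}$ for `reduce_op` (the mean, or half the sum, over the non-batch dimensions), the per-example losses are:

```latex
\begin{aligned}
x_t &= \mu_t(x_0) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I),\quad t \sim \mathcal{U}(\epsilon, T),\\
\nabla_{x_t} \log p_{0t}(x_t \mid x_0) &= -\frac{x_t - \mu_t(x_0)}{\sigma_t^2} = -\frac{z}{\sigma_t},\\
\ell_{\text{unweighted}} &= \operatorname{reduce}\!\big[(\sigma_t\, s_\theta(x_t, t) + z)^2\big]
  = \sigma_t^2 \operatorname{reduce}\!\Big[\Big(s_\theta(x_t, t) + \frac{z}{\sigma_t}\Big)^2\Big],\\
\ell_{\text{likelihood}} &= g(t)^2 \operatorname{reduce}\!\Big[\Big(s_\theta(x_t, t) + \frac{z}{\sigma_t}\Big)^2\Big].
\end{aligned}
```

Because $\sigma_t$ is a per-example scalar and $\operatorname{reduce}$ is linear, the default branch amounts to the weighting $\lambda(t) = \sigma_t^2$, while `likelihood_weighting` uses $\lambda(t) = g(t)^2$. In both cases the returned `loss` is the mean of these per-example values over the mini-batch.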

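A self-contained sketch of the same two branches follows, runnable without the repository. It assumes a VP SDE with a linear β schedule (β_min = 0.1, β_max = 20, T = 1), a `batch_mul` that scales each example by its own scalar, and a hypothetical one-parameter `toy_score` in place of the trained model; none of these are taken from this page. Calling it twice with the same key reuses the same t and z, which makes the relation between the two weightings checkable.

```python
import jax
import jax.numpy as jnp
from jax import random

# Assumed VP SDE with a linear beta schedule (illustrative values, not read from this page).
BETA_MIN, BETA_MAX, T, EPS = 0.1, 20.0, 1.0, 1e-5


def batch_mul(a, b):
  # Multiply example i of one argument by example i of the other (one side is a per-example scalar).
  return jax.vmap(lambda x, y: x * y)(a, b)


def marginal_prob(x, t):
  # Mean and std of the Gaussian kernel p_0t(x_t | x_0) for the assumed VP SDE.
  log_mean_coeff = -0.25 * t ** 2 * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN
  mean = batch_mul(jnp.exp(log_mean_coeff), x)
  std = jnp.sqrt(1.0 - jnp.exp(2.0 * log_mean_coeff))
  return mean, std


def diffusion_sq(t):
  # g(t)^2 = beta(t) for the assumed VP SDE.
  return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def toy_score(params, x, t):
  # Hypothetical score model: a single learned scale applied to the input (t is unused).
  return -params['w'] * x


def dsm_losses(rng, params, data, likelihood_weighting, reduce_mean=True):
  """Per-example denoising score-matching losses, following the structure of loss_fn."""
  reduce_op = jnp.mean if reduce_mean else (lambda *a, **k: 0.5 * jnp.sum(*a, **k))
  rng, step_rng = random.split(rng)
  t = random.uniform(step_rng, (data.shape[0],), minval=EPS, maxval=T)
  rng, step_rng = random.split(rng)
  z = random.normal(step_rng, data.shape)
  mean, std = marginal_prob(data, t)
  perturbed = mean + batch_mul(std, z)
  score = toy_score(params, perturbed, t)
  if not likelihood_weighting:
    losses = jnp.square(batch_mul(score, std) + z)
    losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
  else:
    losses = jnp.square(score + batch_mul(z, 1.0 / std))
    losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) * diffusion_sq(t)
  return losses, t, std


rng = random.PRNGKey(0)
params = {'w': jnp.array(1.0)}
data = random.normal(random.PRNGKey(1), (8, 4, 4, 1))

unweighted, t, std = dsm_losses(rng, params, data, likelihood_weighting=False)
weighted, _, _ = dsm_losses(rng, params, data, likelihood_weighting=True)
loss = jnp.mean(unweighted)
```

With identical t and z, each likelihood-weighted loss should equal the unweighted one times g(t)²/σ_t², up to floating-point rounding, so `jnp.allclose(weighted, unweighted * diffusion_sq(t) / std ** 2, rtol=1e-4)` is a quick consistency check.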
Callers (1)

step_fn · Function · 0.70

Calls (5)

batch_mul · Function · 0.90
model_fn · Function · 0.70
score_fn · Function · 0.50
marginal_prob · Method · 0.45
sde · Method · 0.45

Tested by

no test coverage detected