Repository: github.com/LuChengTHU/dpm-solver

Function step_fn

examples/score_sde_jax/losses.py:207–248

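Runs one step of training or evaluation. The function is meant to be wrapped in `jax.lax.scan`, so that multiple steps can be pmapped and JIT-compiled together for faster execution. It closes over `loss_fn`, `optimize_fn`, and the `train` flag from its enclosing function. Args: carry_state, a tuple (JAX random state, `flax.struct.dataclass` containing the training state); batch, a mini-batch of training/evaluation data. Returns: the updated carry_state tuple and the loss, averaged across devices with `jax.lax.pmean`.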

Signature: step_fn(carry_state, batch)
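To illustrate the scan contract, here is a minimal sketch of a toy step function with the same shape as step_fn: it takes a carry of (PRNG key, state) plus one batch and returns the updated carry and a per-step loss, so `jax.lax.scan` can run several steps inside one compiled program. The quadratic loss, the plain SGD update, and the names `toy_step_fn` and `run_steps` are illustrative assumptions, not code from the repository.

```python
import jax
import jax.numpy as jnp


def toy_step_fn(carry_state, batch):
  """Hypothetical stand-in with the same (carry, batch) -> (carry, loss) shape as step_fn."""
  rng, params = carry_state
  # Split the key every step, as step_fn does, so each step gets fresh randomness.
  rng, step_rng = jax.random.split(rng)
  noise = 0.01 * jax.random.normal(step_rng, batch.shape)

  def loss_fn(p):
    # Toy quadratic loss: mean squared distance between p and the noisy batch.
    return jnp.mean((batch + noise - p) ** 2)

  loss, grad = jax.value_and_grad(loss_fn)(params)
  params = params - 0.1 * grad  # plain SGD, assumed for illustration
  return (rng, params), loss


@jax.jit
def run_steps(carry_state, batches):
  # lax.scan threads the carry through every leading-axis slice of `batches`
  # and stacks the per-step losses, so the whole loop compiles as one program.
  return jax.lax.scan(toy_step_fn, carry_state, batches)


batches = jnp.ones((5, 8))  # 5 steps, each with an 8-element batch
(final_rng, final_params), losses = run_steps((jax.random.PRNGKey(0), jnp.zeros(())), batches)
print(losses.shape)  # (5,)
```

The carry must keep the same structure, shapes, and dtypes across steps, which is why step_fn returns `(rng, new_state)` in the same layout it receives.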

Source

```python
      raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

  def step_fn(carry_state, batch):
    """Running one step of training or evaluation.

    This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
    for faster execution.

    Args:
      carry_state: A tuple (JAX random state, `flax.struct.dataclass` containing the training state).
      batch: A mini-batch of training/evaluation data.

    Returns:
      new_carry_state: The updated tuple of `carry_state`.
      loss: The average loss value of this state.
    """

    (rng, state) = carry_state
    rng, step_rng = jax.random.split(rng)
    # Differentiate loss_fn w.r.t. its second argument (params); has_aux=True because
    # loss_fn returns (loss, new_model_state).
    grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)
    if train:
      # flax.optim-style optimizer: `.target` holds the current parameters.
      params = state.optimizer.target
      states = state.model_state
      (loss, new_model_state), grad = grad_fn(step_rng, params, states, batch)
      # Average gradients over all devices on the pmapped axis named 'batch'.
      grad = jax.lax.pmean(grad, axis_name='batch')
      new_optimizer = optimize_fn(state, grad)
      # EMA of parameters: p_ema <- ema_rate * p_ema + (1 - ema_rate) * p.
      # Note: jax.tree_multimap was removed in newer JAX releases;
      # jax.tree_util.tree_map(f, tree1, tree2) is the equivalent call there.
      new_params_ema = jax.tree_multimap(
        lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate),
        state.params_ema, new_optimizer.target
      )
      step = state.step + 1
      new_state = state.replace(
        step=step,
        optimizer=new_optimizer,
        model_state=new_model_state,
        params_ema=new_params_ema
      )
    else:
      # Evaluation: compute the loss with the EMA parameters; the state is unchanged.
      loss, _ = loss_fn(step_rng, state.params_ema, state.model_state, batch)
      new_state = state

    # Average the loss over devices as well.
    loss = jax.lax.pmean(loss, axis_name='batch')
    new_carry_state = (rng, new_state)
    return new_carry_state, loss

  return step_fn
```
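The training branch combines `jax.value_and_grad` with `argnums=1, has_aux=True`, an optimizer step, and an exponential moving average (EMA) of the parameters, p_ema <- ema_rate * p_ema + (1 - ema_rate) * p. This sketch reproduces the gradient and EMA mechanics on a toy linear model. The loss, the parameter dict, the `calls` counter standing in for model state, and the plain SGD step standing in for `optimize_fn` are all assumptions; the cross-device `pmean` is omitted so it runs on a single device without `pmap`.

```python
import jax
import jax.numpy as jnp


def loss_fn(rng, params, states, batch):
  """Hypothetical loss with the same (rng, params, states, batch) signature.

  Returns (loss, new_states); the second element is the auxiliary output.
  `rng` is unused by this toy model.
  """
  pred = batch @ params["w"] + params["b"]
  loss = jnp.mean(pred ** 2)
  new_states = {"calls": states["calls"] + 1}  # stand-in for e.g. batch-norm statistics
  return loss, new_states


# argnums=1 differentiates w.r.t. `params`; has_aux=True tells JAX the function
# returns (value, aux) and that only `value` is differentiated.
grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)

params = {"w": jnp.ones((3,)), "b": jnp.array(0.5)}
params_ema = jax.tree_util.tree_map(jnp.zeros_like, params)
states = {"calls": jnp.array(0)}
batch = jnp.ones((4, 3))
ema_rate = 0.999

(loss, new_states), grad = grad_fn(jax.random.PRNGKey(0), params, states, batch)

# A plain SGD step stands in for optimize_fn (assumption).
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)

# EMA update mapped leaf-wise over two pytrees; jax.tree_util.tree_map accepts
# several trees, which is what the removed jax.tree_multimap used to do.
new_params_ema = jax.tree_util.tree_map(
    lambda p_ema, p: p_ema * ema_rate + p * (1. - ema_rate),
    params_ema, new_params)
```

With `ema_rate` close to 1, the EMA parameters move slowly toward the trained ones, which is why the evaluation branch scores `state.params_ema` rather than the raw optimizer target.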

Callers

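Nothing calls this directly: step_fn is returned by its enclosing function and is meant to be run through `jax.lax.scan` and pmapped, as its docstring describes.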
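Because its `jax.lax.pmean` calls name the axis `'batch'`, step_fn only runs inside a `jax.pmap` that declares `axis_name='batch'`. A minimal sketch of that composition, with `lax.scan` nested inside `pmap`, follows. The toy step function, the shapes, and the name `p_train_steps` are assumptions for illustration, not code taken from the repository.

```python
import functools

import jax
import jax.numpy as jnp


def toy_step_fn(carry_state, batch):
  """Toy step with step_fn's contract, including the cross-device pmean."""
  rng, params = carry_state
  rng, _ = jax.random.split(rng)  # fresh key per step (unused by this toy loss)
  loss, grad = jax.value_and_grad(lambda p: jnp.mean((batch - p) ** 2))(params)
  # These collectives require an enclosing pmap with axis_name='batch'.
  grad = jax.lax.pmean(grad, axis_name='batch')
  loss = jax.lax.pmean(loss, axis_name='batch')
  return (rng, params - 0.1 * grad), loss


# scan runs several steps per call; pmap replicates that loop across devices.
p_train_steps = jax.pmap(
    functools.partial(jax.lax.scan, toy_step_fn), axis_name='batch')

n_devices = jax.local_device_count()
n_steps, per_device_batch = 3, 4
# Leading axes: [device, step, per-device batch].
batches = jnp.ones((n_devices, n_steps, per_device_batch))
rngs = jax.random.split(jax.random.PRNGKey(0), n_devices)
params = jnp.zeros((n_devices,))  # one replicated scalar per device
(new_rngs, new_params), losses = p_train_steps((rngs, params), batches)
print(losses.shape)  # (n_devices, n_steps), e.g. (1, 3) on a single-device CPU host
```

Each device gets its own PRNG key and slice of batches, while `pmean` keeps the parameters identical across devices after every step.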

Calls (3)

grad_fn · Function · 0.85
optimize_fn · Function · 0.70
loss_fn · Function · 0.70

Tested by

no test coverage detected