Runs one step of training or evaluation. This function is used as the body of `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together for faster execution. Arguments: `carry_state` is a tuple of (JAX random state, `flax.struct.dataclass` containing the training state); `batch` is a mini-batch of training/evaluation data.
(carry_state, batch)
| 205 | raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") |
| 206 | |
def step_fn(carry_state, batch):
    """Run one step of training or evaluation.

    This function is meant to be the body of `jax.lax.scan`, so that multiple
    steps can be pmapped and jit-compiled together for faster execution.

    Args:
      carry_state: A tuple (JAX random state, `flax.struct.dataclass`
        containing the training state).
      batch: A mini-batch of training/evaluation data.

    Returns:
      new_carry_state: The updated tuple of `carry_state`.
      loss: The loss of this step, averaged across the 'batch' axis with
        `jax.lax.pmean`.
    """
    rng, state = carry_state
    rng, step_rng = jax.random.split(rng)

    if train:
        # Differentiate w.r.t. the parameters only (argnums=1). `loss_fn` also
        # returns the updated model state as an auxiliary output (has_aux=True).
        grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)
        params = state.optimizer.target
        (loss, new_model_state), grad = grad_fn(
            step_rng, params, state.model_state, batch)
        # Average gradients across devices before applying the update.
        grad = jax.lax.pmean(grad, axis_name='batch')
        new_optimizer = optimize_fn(state, grad)

        # Exponential moving average of the parameters:
        #   ema <- ema * rate + new * (1 - rate)
        # Written with flatten/unflatten rather than `jax.tree_multimap`, which
        # was removed from newer JAX releases; this form works on old and new
        # JAX alike.
        ema_leaves, ema_treedef = jax.tree_util.tree_flatten(state.params_ema)
        new_leaves, new_treedef = jax.tree_util.tree_flatten(new_optimizer.target)
        if ema_treedef != new_treedef:
            raise ValueError(
                f"EMA parameter tree structure {ema_treedef} does not match "
                f"optimizer parameter tree structure {new_treedef}.")
        rate = state.ema_rate
        new_params_ema = jax.tree_util.tree_unflatten(
            ema_treedef,
            [p_ema * rate + p * (1. - rate)
             for p_ema, p in zip(ema_leaves, new_leaves)],
        )

        new_state = state.replace(
            step=state.step + 1,
            optimizer=new_optimizer,
            model_state=new_model_state,
            params_ema=new_params_ema,
        )
    else:
        # Evaluation scores the EMA parameters and leaves the state unchanged.
        loss, _ = loss_fn(step_rng, state.params_ema, state.model_state, batch)
        new_state = state

    loss = jax.lax.pmean(loss, axis_name='batch')
    new_carry_state = (rng, new_state)
    return new_carry_state, loss
| 249 | |
| 250 | return step_fn |
nothing calls this directly
no test coverage detected