Optimizes with warmup and gradient clipping (disabled if negative).
`optimize_fn(state, grad, warmup=config.optim.warmup, grad_clip=config.optim.grad_clip)`
| 41 | """Returns an optimize_fn based on `config`.""" |
| 42 | |
def optimize_fn(state,
                grad,
                warmup=config.optim.warmup,
                grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative).

    Args:
      state: Training state. Reads `state.lr` (base learning rate),
        `state.step` (current step, used for warmup) and `state.optimizer`,
        whose `apply_gradient` method performs the update.
      grad: Pytree of gradient arrays.
      warmup: If > 0, the learning rate is multiplied by
        `min(state.step / warmup, 1)`, i.e. ramped linearly from 0 to its
        full value over `warmup` steps.
      grad_clip: Maximum global L2 norm of `grad`. If the global norm exceeds
        it, every leaf is rescaled so the global norm equals `grad_clip`.
        A negative value disables clipping.

    Returns:
      The result of `state.optimizer.apply_gradient(...)` called with the
      (possibly clipped) gradients and the (possibly warmed-up) learning rate.
    """
    lr = state.lr
    if warmup > 0:
        lr = lr * jnp.minimum(state.step / warmup, 1.0)
    if grad_clip >= 0:
        # Global L2 norm across all leaves of the gradient pytree.
        # `jax.tree_util.*` is used instead of the removed `jax.tree_*` aliases.
        grad_norm = jnp.sqrt(
            sum(jnp.sum(jnp.square(x)) for x in jax.tree_util.tree_leaves(grad)))
        # Rescale only when the norm exceeds the threshold. Choosing the factor
        # with `where` avoids the 0/0 = NaN that
        # `grad_clip / max(grad_norm, grad_clip)` produces when both are zero.
        scale = jnp.where(grad_norm > grad_clip, grad_clip / grad_norm, 1.0)
        clipped_grad = jax.tree_util.tree_map(lambda x: x * scale, grad)
    else:  # grad_clip < 0 disables gradient clipping.
        clipped_grad = grad
    return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)
| 61 | |
| 62 | return optimize_fn |
| 63 |