Repository: github.com/LuChengTHU/dpm-solver

Function optimize_fn

examples/score_sde_jax/losses.py:43–60

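Applies one optimizer update using a linearly warmed-up learning rate and global-norm gradient clipping. Warmup is skipped when warmup is not positive, and clipping is disabled when grad_clip is negative.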
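Written as formulas, transcribed directly from the code (the symbols are introduced here for readability): let η be the base learning rate state.lr, t the counter state.step, W the warmup length, c the clip threshold grad_clip, and g the gradient pytree whose leaves are indexed by ℓ. The function computes an effective learning rate η_t and a clipped gradient g̃, then hands both to the optimizer.

```latex
\eta_t =
\begin{cases}
  \eta \, \min\!\left(t / W,\; 1\right) & \text{if } W > 0,\\
  \eta & \text{otherwise,}
\end{cases}
\qquad
\tilde{g} =
\begin{cases}
  g \cdot \dfrac{c}{\max\!\left(\lVert g \rVert_2,\; c\right)} & \text{if } c \ge 0,\\
  g & \text{if } c < 0,
\end{cases}
\qquad
\lVert g \rVert_2 = \sqrt{\textstyle\sum_{\ell} \sum_{i} g_{\ell,i}^{2}}
```

When the global norm is at most c the factor is 1 and the gradient passes through unchanged; otherwise every leaf is scaled by the same factor, so the update direction is preserved and the global norm becomes exactly c.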

optimize_fn(state, grad, warmup=config.optim.warmup, grad_clip=config.optim.grad_clip)
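The defaults in this signature are ordinary Python default arguments: they are read from config once, when the enclosing factory defines optimize_fn, and a caller can still override them per call. The sketch below demonstrates that binding using a hypothetical factory make_optimize_fn and a SimpleNamespace standing in for the real config object (both are illustrative assumptions, not part of the repository). Its inner function only reports its effective hyperparameters instead of stepping an optimizer.

```python
from types import SimpleNamespace


def make_optimize_fn(config):
    # Hypothetical factory mirroring the real one's structure: the default
    # values below are evaluated once, when this `def` statement executes.
    def optimize_fn(state, grad,
                    warmup=config.optim.warmup,
                    grad_clip=config.optim.grad_clip):
        # Instead of stepping an optimizer, report the effective settings.
        return warmup, grad_clip

    return optimize_fn


config = SimpleNamespace(optim=SimpleNamespace(warmup=5000, grad_clip=1.0))
fn = make_optimize_fn(config)

config.optim.warmup = 0                # mutating config after the factory ran...
print(fn(None, None))                  # (5000, 1.0): ...does not change the defaults
print(fn(None, None, grad_clip=-1.0))  # (5000, -1.0): per-call override
```

In practice this means configuration changes only take effect when the factory is called again.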


import jax
import jax.numpy as jnp


def optimization_manager(config):
  """Returns an optimize_fn based on `config`."""

  def optimize_fn(state,
                  grad,
                  warmup=config.optim.warmup,
                  grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative)."""
    lr = state.lr
    if warmup > 0:
      # Linear warmup: scale lr by step / warmup until step reaches warmup
      lr = lr * jnp.minimum(state.step / warmup, 1.0)
    if grad_clip >= 0:
      # Compute global gradient norm over all leaves of the grad pytree
      grad_norm = jnp.sqrt(
        sum([jnp.sum(jnp.square(x)) for x in jax.tree_util.tree_leaves(grad)]))
      # Clip gradient: rescale every leaf so the global norm is at most grad_clip
      clipped_grad = jax.tree_util.tree_map(
        lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad)
    else:  # disabling gradient clipping if grad_clip < 0
      clipped_grad = grad
    # state.optimizer is assumed to be a legacy flax.optim Optimizer, whose
    # apply_gradient accepts hyperparameter overrides such as learning_rate
    return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)

  return optimize_fn
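To check the arithmetic without JAX or Flax, the same warmup and clipping logic can be re-expressed in NumPy. This is a sketch under stated assumptions: warmup_and_clip is a hypothetical helper, the gradient pytree is modeled as a flat dict of arrays, and the final optimizer update (state.optimizer.apply_gradient) is left out; the helper returns the effective learning rate and the clipped gradients instead.

```python
import numpy as np


def warmup_and_clip(grads, base_lr, step, warmup, grad_clip):
    """Hypothetical NumPy re-expression of optimize_fn's arithmetic.

    grads is a dict of np.ndarray standing in for a JAX pytree. Returns the
    effective learning rate and the (possibly) clipped gradients; no
    optimizer update is applied.
    """
    lr = base_lr
    if warmup > 0:
        # Linear warmup: the factor grows from 0 to 1 over `warmup` steps.
        lr = lr * min(step / warmup, 1.0)
    if grad_clip >= 0:
        # Global L2 norm over every leaf, as in the JAX version.
        grad_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads.values()))
        # One shared factor grad_clip / max(norm, grad_clip) for all leaves.
        scale = grad_clip / np.maximum(grad_norm, grad_clip)
        clipped = {name: g * scale for name, g in grads.items()}
    else:
        clipped = grads  # negative grad_clip disables clipping
    return lr, clipped


grads = {"w": np.array([3.0, 0.0]), "b": np.array([4.0])}  # global norm 5.0
lr, clipped = warmup_and_clip(grads, base_lr=2e-4, step=2500,
                              warmup=5000, grad_clip=1.0)
print(lr)  # 0.0001 (halfway through warmup)
print(np.sqrt(sum(np.sum(g ** 2) for g in clipped.values())))  # approximately 1.0
```

One consequence visible here: with warmup > 0, the learning rate passed to the optimizer at step 0 is exactly 0.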

Callers (1)

step_fn · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected