(cls, *, apply_fn, params, tx, **kwargs)
| 617 | |
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
    """Construct an instance at step 0 with newly initialised optimizer state.

    ``params`` is partitioned by ``split_params``; for every resulting key the
    matching entry of ``tx`` supplies an ``init`` callable that is applied to
    that partition. The per-key results are collected in a dict, passed
    through ``freeze`` and stored as ``opt_state``.

    Args:
        apply_fn: Forwarded unchanged to ``cls``.
        params: Parameters to partition; also forwarded unchanged to ``cls``.
        tx: Mapping indexed by the keys ``split_params`` produces; each value
            must expose ``.init`` (presumably optax-style transformations --
            TODO confirm against the caller).
        **kwargs: Extra keyword arguments forwarded to ``cls``.

    Returns:
        The object produced by calling ``cls`` with ``step=0``.
    """

    def _init_group(name, group):
        # Keys containing "scanned" get their init function wrapped in
        # jax.vmap before it is called; all other keys call init directly.
        # NOTE(review): presumably scanned groups hold stacked per-layer
        # parameters -- confirm against split_params.
        initializer = tx[name].init
        if "scanned" in name:
            initializer = jax.vmap(initializer)
        return initializer(group)

    opt_state = {
        name: _init_group(name, group)
        for name, group in split_params(params).items()
    }
    return cls(
        step=0,
        apply_fn=apply_fn,
        params=params,
        tx=tx,
        opt_state=freeze(opt_state),
        **kwargs,
    )
| 634 | |
| 635 | |
| 636 | def main(): |
no test coverage detected