Wrap the params in a `Skip` object.
(tree)
| 444 | |
| 445 | |
def _wrap_skip(tree):
  """Return `tree` with each of its leaves wrapped in a `_Skip` marker.

  The marker is inert for now. It is meant to let the restore of these
  params be skipped once orbax supports partial restore.
  """
  wrap_leaf = _Skip  # Callable applied to every leaf of `tree`.
  return jax.tree.map(wrap_leaf, tree)
| 451 | |
| 452 | |
| 453 | def _unwrap_skip(tree): |
no test coverage detected