Release the memory of the skipped params.
`_release_skip(tree, tree_with_skip) -> None`
| 464 | |
| 465 | |
def _release_skip(tree, tree_with_skip) -> None:
  """Release the memory of the skipped params.

  Walks `tree` and `tree_with_skip` in lockstep. Every leaf of `tree` whose
  counterpart in `tree_with_skip` is a `_Skip` marker gets `.delete()`
  called on it. All other leaves are left alone.

  Args:
    tree: Pytree whose leaves expose a `.delete()` method (presumably
      `jax.Array`s -- NOTE(review): confirm against callers).
    tree_with_skip: Pytree with the same structure as `tree` (or with `tree`
      as a prefix, per `jax.tree.map` rules) in which skipped positions hold
      a `_Skip` instance.
  """

  def _delete_if_skipped(leaf, marker):
    # Guard clause: only positions explicitly marked as skipped are freed.
    if not isinstance(marker, _Skip):
      return None
    return leaf.delete()

  # `jax.tree.map` is used only for its side effect; the resulting tree is
  # discarded. It is preferred over flattening both trees separately because
  # it aligns each leaf of `tree` with the matching subtree of
  # `tree_with_skip`.
  jax.tree.map(_delete_if_skipped, tree, tree_with_skip)
| 473 | |
| 474 | |
| 475 | # ======================== Other utils ======================== |