Selects all PyTree leaves of each selected subtree. This selects all of the leaves of the PyTree according to `jax.tree_util`, giving the most-specific selection expressible with a `Selection` object. (Note that, if any objects in the tree are not registered as JAX PyTree nodes, they will be selected in their entirety even if they contain children when printed out by treescope.)
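As a rough illustration of what "leaves according to `jax.tree_util`" means here, the sketch below uses only public `jax.tree_util` calls. The `Unregistered` and `Marked` classes are hypothetical stand-ins invented for this example; `Marked` only plays the role of the private boundary wrapper and is not the library's actual type.

```python
import dataclasses

import jax


@dataclasses.dataclass
class Unregistered:
  """Hypothetical container that is NOT registered as a JAX PyTree node."""
  a: int
  b: int


@dataclasses.dataclass
class Marked:
  """Hypothetical stand-in for a wrapper placed around each leaf."""
  value: object


tree = {"x": [1, 2], "y": (3, Unregistered(a=4, b=5))}

# Dicts, lists, and tuples are traversed, but the unregistered dataclass
# instance is treated as one opaque leaf, even though it has fields.
leaves = jax.tree_util.tree_leaves(tree)

# Mapping a wrapper over the tree touches exactly those leaves and keeps
# the surrounding container structure unchanged.
marked = jax.tree_util.tree_map(Marked, tree)

print(leaves)
print(marked)
```

Under these assumptions, the `Unregistered` instance shows up as a single leaf and is wrapped whole, which matches the docstring's note that unregistered objects are selected in their entirety.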
    return _build_selection_from_boundary(with_boundary)

  def at_pytree_leaves(self) -> "Selection":
    """Selects all PyTree leaves of each selected subtree.

    This selects all of the leaves of the PyTree according to `jax.tree_util`,
    giving the most-specific selection expressible with a `Selection` object.
    (Note that, if any objects in the tree are not registered as JAX PyTree
    nodes, they will be selected in their entirety even if they contain children
    when printed out by treescope.)

    Returns:
      A new selection that selects every leaf of each selected subtree.
    """
    add_boundary = functools.partial(
        jax.tree_util.tree_map, _InProgressSelectionBoundary
    )
    return _build_selection_from_boundary(self.apply(add_boundary))

  def at_children(self) -> "Selection":
    """Selects all direct children of each selected subtree.