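A small stand-alone sketch can make concrete the select, partition, and combine round trip that the `partition` docstring describes. This is a hypothetical, pure-Python toy and not Penzai's implementation. Nested `dict`s stand in for PyTrees, and a predicate on key paths stands in for a Penzai selection. The `NotInThisPartition`, `partition`, and `combine` defined here only imitate the documented behavior of `pz.NotInThisPartition`, the `partition` method, and `pz.combine`. The example covers a two-way split, a three-way split made by partitioning the remainder again and then recombining everything in one call, and the effect of `at_leaves=True`.

```python
from typing import Any, Callable

Path = tuple[str, ...]


class NotInThisPartition:
  """Toy sentinel marking a removed node; it has no children."""

  def __repr__(self) -> str:
    return "NotInThisPartition()"


MISSING = NotInThisPartition()


def _map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
  """Applies `fn` to every non-dict leaf of a nested-dict tree."""
  if isinstance(tree, dict):
    return {k: _map_leaves(fn, v) for k, v in tree.items()}
  return fn(tree)


def _any_selected(tree: Any, select: Callable[[Path], bool], path: Path) -> bool:
  """Returns True if this node or any of its descendants is selected."""
  if select(path):
    return True
  if isinstance(tree, dict):
    return any(_any_selected(v, select, path + (k,)) for k, v in tree.items())
  return False


def partition(
    tree: Any,
    select: Callable[[Path], bool],
    at_leaves: bool = False,
    path: Path = (),
) -> tuple[Any, Any]:
  """Hypothetical toy: splits `tree` into disjoint (selected, remainder)."""

  def removed(subtree: Any) -> Any:
    # Replace the whole subtree, or each of its leaves if at_leaves=True.
    return _map_leaves(lambda _: MISSING, subtree) if at_leaves else MISSING

  if select(path):  # The entire subtree is selected.
    return tree, removed(tree)
  if not _any_selected(tree, select, path):  # The entire subtree is unselected.
    return removed(tree), tree
  # Mixed node (necessarily a dict): recurse into each child.
  selected, remainder = {}, {}
  for key, child in tree.items():
    selected[key], remainder[key] = partition(
        child, select, at_leaves, path + (key,)
    )
  return selected, remainder


def combine(*parts: Any) -> Any:
  """Toy merge: at every node exactly one part may hold real content."""
  present = [p for p in parts if not isinstance(p, NotInThisPartition)]
  if len(present) == 1:
    return present[0]
  if present and all(isinstance(p, dict) for p in present):
    return {k: combine(*(p[k] for p in present)) for k in present[0]}
  raise ValueError("parts overlap, or every part is missing at this node")


params = {
    "encoder": {"w": 1.0, "b": 0.5},
    "decoder": {"w": 2.0, "b": -1.0},
}

# Two-way split: the whole "encoder" subtree versus everything else.
enc, rest = partition(params, lambda path: path == ("encoder",))
# enc  == {"encoder": {"w": 1.0, "b": 0.5}, "decoder": MISSING}
# rest == {"encoder": MISSING, "decoder": {"w": 2.0, "b": -1.0}}

# Three-way split: partition the remainder again, then recombine in one call.
dec_w, other = partition(rest, lambda path: path == ("decoder", "w"))
assert combine(enc, dec_w, other) == params

# Leaf-level split: both outputs keep the full dict structure.
sel_leaves, rem_leaves = partition(
    params, lambda path: path == ("encoder",), at_leaves=True
)
# rem_leaves == {"encoder": {"w": MISSING, "b": MISSING},
#                "decoder": {"w": 2.0, "b": -1.0}}
```

In an actual JAX program the pieces would be passed as separate function arguments, so that `argnums` or `donate_argnums` can point at one partition and the function body can rebuild the full tree. The `ValueError` raised by the toy `combine` on overlapping parts is a check chosen for this sketch only.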
    return self.at_subtrees_where(_check_equal)

  def partition(self, at_leaves: bool = False) -> tuple[Any, Any]:
    """Partitions the tree into ``(selected_tree, remainder_tree)`` parts.

    This function can be used to separate out the selected components of a tree
    into their own tree, so that JAX functions and other JAX libraries can
    process them like ordinary PyTrees. It splits its input into two disjoint
    trees ``(selected_tree, remainder_tree)``, where ``selected_tree`` only
    contains the leaves that were selected, and ``remainder_tree`` only contains
    the remainder. The parts that were removed are marked with a sentinel
    `pz.NotInThisPartition` object, which has no PyTree children.

    The main use case for ``partition`` is to identify subsets of a model that
    should be treated differently by JAX API functions. For instance, to take a
    gradient with respect to a specific subset of parameters, you can select
    those parameters, call ``partition`` to separate them from the rest, then
    call `jax.grad` and use ``argnums`` to identify the partition of interest.
    Similarly, to donate only a subset of the state to `jax.jit`, you can
    partition it and then use the ``donate_argnums`` argument of `jax.jit` to
    identify the parts you want to donate. Inside the function, you can then
    use `pz.combine` to rebuild the original tree.

    It is possible to call ``partition`` repeatedly to split a tree into more
    than two parts. In particular, you can select the ``remainder_tree``,
    target some additional nodes, and call ``.partition()`` again, repeating
    this process as needed. All of the partitioned trees can then be
    recombined using a single call to `pz.combine`.

    Note that `NotInThisPartition` is a PyTree node with no children, which
    means that partitioned trees are safe to pass through JAX transformations,
    and that the leaves of the two partitioned trees, taken together, are
    exactly the leaves of the original tree.

    This function is inspired by Equinox's `equinox.partition`, but is designed
    to work with Penzai's selector system. Unlike `equinox.partition`, missing
    nodes are identified with the `pz.NotInThisPartition` sentinel, which can
    replace arbitrary PyTree subtrees instead of just leaves. (Partitioning is
    also somewhat less important in Penzai than in Equinox because all PyTree
    leaves are arraylike by convention; partitioning is only necessary when
    different parts of the tree need special treatment, e.g. for ``argnums``
    or ``donate_argnums`` parameters.)

    Args:
      at_leaves: Whether to do the partitioning at the leaf level, so that the
        returned trees have exactly the same structure. (`pz.combine` accepts
        partitions in which entire subtrees are missing, so this is not
        required, but it can make the partitions easier to manipulate manually
        if desired.) If False, each selected subtree is replaced as a whole by
        `NotInThisPartition` in the remainder tree.

    Returns:
      A tuple ``(selected_tree, remainder_tree)``, where both trees have the
      same structure (if ``at_leaves=True``) or a common prefix (if
      ``at_leaves=False``), except that `NotInThisPartition` replaces the
      parts that belong to the other partition.
    """
    if at_leaves:
      selected_tree = (
          self.invert()