hub / github.com/google-deepmind/penzai / partition

Method partition

penzai/core/selectors.py:964–1032

Partitions the tree into ``(selected_tree, remainder_tree)`` parts. This function can be used to separate out the selected components of a tree into their own separate tree, so that JAX functions and other JAX libraries can process them like ordinary PyTrees.

(self, at_leaves: bool = False)
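The partition-then-differentiate pattern can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not Penzai's implementation: `Missing`, `partition_by_key`, and `combine` are hypothetical stand-ins for `pz.NotInThisPartition`, `Selection.partition(at_leaves=True)`, and `pz.combine`. The key idea shown is that a sentinel registered as a PyTree node with no children leaves only the selected arrays visible to `jax.grad`.

```python
import jax
import jax.numpy as jnp


class Missing:
  """Hypothetical stand-in for pz.NotInThisPartition (not Penzai's class)."""


# Register the sentinel as a PyTree node with zero children, so JAX sees no
# array leaves where a value has been moved to the other partition.
jax.tree_util.register_pytree_node(
    Missing,
    lambda node: ((), None),          # flatten: no children, no aux data
    lambda aux, children: Missing(),  # unflatten: rebuild the sentinel
)


def is_missing(x):
  return isinstance(x, Missing)


def partition_by_key(tree, key):
  """Hypothetical leaf-level split: leaves under dict key `key` vs. the rest."""
  def in_selection(path):
    return any(
        isinstance(k, jax.tree_util.DictKey) and k.key == key for k in path
    )
  selected = jax.tree_util.tree_map_with_path(
      lambda path, leaf: leaf if in_selection(path) else Missing(), tree)
  remainder = jax.tree_util.tree_map_with_path(
      lambda path, leaf: Missing() if in_selection(path) else leaf, tree)
  return selected, remainder


def combine(*parts):
  """Hypothetical merge: at each position, take the first non-sentinel value."""
  def pick(*candidates):
    return next((c for c in candidates if not is_missing(c)), Missing())
  return jax.tree_util.tree_map(pick, *parts, is_leaf=is_missing)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
trainable, frozen = partition_by_key(params, "w")


def loss(trainable, frozen, x):
  p = combine(trainable, frozen)  # rebuild the full tree inside the function
  return jnp.sum(p["w"] * x) + p["b"]


# argnums=0 differentiates only with respect to the selected partition.
grads = jax.grad(loss, argnums=0)(trainable, frozen, jnp.array([3.0, 4.0]))
```

The gradient tree has the same shape as `trainable`: an array under `"w"` and the sentinel under `"b"`. With Penzai itself, the corresponding calls would presumably be `.partition()` on a selection and `pz.combine` inside the loss function.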

Source from the content-addressed store, hash-verified

    return self.at_subtrees_where(_check_equal)

  def partition(self, at_leaves: bool = False) -> tuple[Any, Any]:
    """Partitions the tree into ``(selected_tree, remainder_tree)`` parts.

    This function can be used to separate out the selected components of a tree
    into their own separate tree, so that JAX functions and other JAX libraries
    can process them like ordinary PyTrees. It splits its input into two
    disjoint trees ``(selected_tree, remainder_tree)``, where ``selected_tree``
    only contains the leaves that were selected, and `remainder_tree` only
    contains the remainder. The parts that were removed are identified using a
    sentinel `pz.NotInThisPartition` object, which has no PyTree children.

    The main use case for ``partition`` is to identify subsets of models that
    should be treated in different ways by JAX API functions. For instance, if
    you want to take a gradient with respect to a specific subset of parameters,
    you can select those parameters, call ``partition`` to separate them from
    the rest, then call `jax.grad` and use `argnums` to identify the partition
    of interest. Similarly, if you want to donate only a subset of the state to
    `jax.jit`, you can partition it and then use JAX's ``donate_argnums``
    argument to `jax.jit` to identify the parts you want to donate. Inside the
    function, you can then use `pz.combine` to rebuild the original tree.

    It is possible to repeatedly call ``partition`` to split a tree into
    more than two parts. In particular, you can select the ``remainder_tree``,
    target some additional nodes, and call ``.partition()`` again, repeating
    this process as needed. All of the partitioned trees can then be re-combined
    using a single call to `pz.combine`.

    Note that `NotInThisPartition` is a PyTree node with no children, which
    means that partitioned trees are safe to pass through JAX transformations,
    and the set of leaves in the two partitioned trees together are the same as
    the set of leaves in the original selected tree.

    This function is inspired by Equinox's `equinox.partition`, but is designed
    to work with Penzai's selector system. Unlike `equinox.partition`, missing
    nodes are identified with the `pz.NotInThisPartition` sentinel, and can
    replace arbitrary PyTree subtrees instead of just leaves. (Partitioning is
    also somewhat less important in Penzai than in Equinox because all PyTree
    leaves are arraylike by convention; partitioning is only necessary when
    different parts of the tree need special treatment e.g. for ``argnums`` or
    ``donate_argnums`` parameters.)

    Args:
      at_leaves: Whether to do the partitioning at the leaf level, so that the
        returned trees have exactly the same structure. (Note that `pz.combine`
        is OK with entire subtrees missing, so this is not necessary, but can
        make the partitions easier to manipulate manually if desired.) If False,
        the entire selected subtrees will be replaced by `NotInThisPartition` in
        the remainder tree.

    Returns:
      A tuple ``(selected_tree, remainder_tree)``, where both trees have the
      same structure (if ``at_leaves=True``) or the same prefix (if
      ``at_leaves=False``) except that `NotInThisPartition` is used to replace
      parts that are in the other partition.
    """
    if at_leaves:
      selected_tree = (
          self.invert()
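The docstring's note on repeated partitioning can be illustrated the same way. The sketch below is plain JAX under stated assumptions, not Penzai's code: `Missing`, `split`, and `combine` are hypothetical stand-ins for `pz.NotInThisPartition`, a selection followed by `.partition(at_leaves=True)`, and `pz.combine`. It splits a tree into three parts by partitioning the remainder a second time, then rebuilds the original with one combine call inside a jitted function.

```python
import jax
import jax.numpy as jnp


class Missing:
  """Hypothetical stand-in for pz.NotInThisPartition (not Penzai's class)."""


# A PyTree node with no children: it contributes no leaves to any partition.
jax.tree_util.register_pytree_node(
    Missing, lambda node: ((), None), lambda aux, children: Missing())


def is_missing(x):
  return isinstance(x, Missing)


def split(tree, keep):
  """Hypothetical leaf-level split by a predicate on each leaf's dict keys."""
  def names(path):
    return tuple(k.key for k in path)  # assumes a tree built only from dicts
  taken = jax.tree_util.tree_map_with_path(
      lambda p, leaf: leaf if keep(names(p)) else Missing(), tree)
  rest = jax.tree_util.tree_map_with_path(
      lambda p, leaf: Missing() if keep(names(p)) else leaf, tree)
  return taken, rest


def combine(*parts):
  """Hypothetical merge: at each position, take the first non-sentinel value."""
  def pick(*candidates):
    return next((c for c in candidates if not is_missing(c)), Missing())
  return jax.tree_util.tree_map(pick, *parts, is_leaf=is_missing)


state = {
    "encoder": {"w": jnp.ones(2), "b": jnp.zeros(2)},
    "decoder": {"w": jnp.full(2, 3.0), "b": jnp.full(2, 0.5)},
}

# First split off the encoder, then partition the remainder a second time.
part_a, rest = split(state, lambda names: names[0] == "encoder")
part_b, part_c = split(rest, lambda names: names[-1] == "w")

# A single combine call merges all three partitions back together.
restored = combine(part_a, part_b, part_c)


@jax.jit
def total(a, b, c):
  # Each partition is an ordinary PyTree argument; rebuild the tree inside.
  full = combine(a, b, c)
  return sum(jnp.sum(x) for x in jax.tree_util.tree_leaves(full))
```

Because the sentinel has no leaves, the second `split` only touches the arrays still present in `rest`, and the three partitions together hold exactly the leaves of `state`.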

Callers 1

Calls 3

set · Method · 0.95
at_pytree_leaves · Method · 0.95
invert · Method · 0.95

Tested by 1