Selects a specific child of each selected node. ``Selection.at(...)`` allows you to modify a tree with an almost-imperative style while maintaining a functional interface, similar to the ``Array.at[...]`` syntax for ordinary NDArrays. It takes a callable that picks out a subtree of the tree and returns a new selection that selects the part that was picked out.
at(self, accessor_fn: Callable[[SelectedSubtree], Any | Collection[Any]], multiple: bool | None = None) -> Selection
```python
def at(
    self,
    accessor_fn: Callable[[SelectedSubtree], Any | Collection[Any]],
    multiple: bool | None = None,
) -> "Selection":
  """Selects a specific child of each selected node.

  ``Selection.at(...)`` allows you to modify a tree with an almost-imperative
  style while maintaining a functional interface, similar to the
  ``Array.at[...]`` syntax for ordinary NDArrays. It takes a callable that
  picks out a subtree of the tree, and returns a new selection that selects
  the part that was picked out.

  For instance, if you have an object

  ::

    obj = Foo(bar=[1, 2, {"baz": 5}])

  you could select the 5 using

  ::

    pz.select(obj).at(lambda x: x.bar[2]["baz"])

  Args:
    accessor_fn: A function which takes each element of the current selection
      and returns a single node within that selection (if ``multiple`` is
      False) or a collection of nodes (if ``multiple`` is True). This function
      must be structural; it must depend only on the PyTree structure of its
      input and not on the actual values or Python IDs of the leaves. It will
      be called with a copy of the object where every PyTree leaf and every
      empty PyTree node (e.g. an empty tuple or the None singleton) are
      wrapped with an internal wrapper object.
    multiple: Whether ``accessor_fn`` returns a collection of nodes to select,
      rather than a single node. If ``None``, first tries to find it as a single
      node, and if that fails, tries to find it as a collection of nodes but
      emits a warning.

  Returns:
    A modified selection that selects the specific child of each node in the
    original selection (or the set of nodes if ``multiple`` was True).
  """

  def _is_leaf_or_childless(node):
    result = penzai_tree_util.tree_flatten_exactly_one_level(node)
    if result is None:
      return True
    else:
      children, _ = result
      return not children

  def _unwrap(l: _LeafWrapper):
    assert isinstance(l, _LeafWrapper)
    return l.wrapped_leaf

  # This logic is based on equinox.tree_at, but kept separate to avoid
  # depending on equinox.
  def _process_one(node, multiple=multiple):
```
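The docstring notes that every leaf and every empty node is wrapped before `accessor_fn` runs, and `_is_leaf_or_childless` decides which nodes are leaf-like in that sense. A short plain-Python illustration of why this matters, assuming nested dicts/lists/tuples stand in for PyTrees and that `LeafWrapper`, `is_leaf_or_childless` and `positions_of` are hypothetical analogs rather than penzai's private helpers: the accessor's result is located by object identity, and singletons such as `None` are shared, so without wrapping one result can match several positions.

```python
from typing import Any


class LeafWrapper:
    """Hypothetical analog of the private `_LeafWrapper`: a unique box per leaf."""

    def __init__(self, wrapped_leaf: Any):
        self.wrapped_leaf = wrapped_leaf


def is_leaf_or_childless(node: Any) -> bool:
    """Hypothetical analog of `_is_leaf_or_childless` for plain containers."""
    if type(node) in (list, tuple, dict):
        return len(node) == 0  # a container with no children, e.g. () or {}
    return True  # anything else is treated as a leaf


def positions_of(tree: dict, picked: Any) -> list:
    """Keys whose value is the *same object* as `picked`."""
    return [key for key, value in tree.items() if value is picked]


tree = {"x": None, "y": None, "z": [1, 2]}

# Unwrapped: both None entries are the same singleton object, so identity
# matching cannot tell which one the accessor meant.
print(positions_of(tree, tree["y"]))  # ['x', 'y']

# Wrapped (one level only, for brevity): each leaf-like value gets its own box,
# so the match is unique.
wrapped = {
    key: LeafWrapper(value) if is_leaf_or_childless(value) else value
    for key, value in tree.items()
}
picked = wrapped["y"]
print(positions_of(wrapped, picked))  # ['y']
print(picked.wrapped_leaf)            # None  (the equivalent of `_unwrap`)
```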