MCPcopy
hub / github.com/google-deepmind/penzai / _wrap_selection_errors

Function _wrap_selection_errors

penzai/core/selectors.py:1415–1453  ·  view source on GitHub ↗

Context manager to add informative details to selection errors.

(selection: "Selection")

Source from the content-addressed store, hash-verified

1413
@contextlib.contextmanager
def _wrap_selection_errors(selection: "Selection"):
  """Context manager to add informative details to selection errors.

  Any exception raised inside the ``with`` body is re-raised as a
  ``ValueError`` whose message explains common causes of selection failures.
  The original exception is chained as the ``__cause__``. If the selection's
  deselected tree contains ``jax.tree_util.Partial`` instances, an extra hint
  about that known failure mode is appended to the message.

  Args:
    selection: The selection being operated on. Its ``deselect()`` method is
      only called after an error has occurred, to look for known problematic
      nodes.

  Yields:
    None.

  Raises:
    ValueError: Wrapping any exception raised inside the ``with`` body.
  """
  try:
    yield
  except Exception as exc:
    new_message = (
        "An error occurred while building a Selection. This can happen when"
        " PyTree nodes make assumptions about the types of their children,"
        " since Selections replace subtrees with sentinel elements. It can also"
        " happen if the value being selected isn't a valid JAX PyTree (in which"
        " case `jax.tree_util.tree_flatten` will also fail)."
    )

    # Check for known failure cases. This is best-effort only. The check calls
    # `selection.deselect()` and walks the result, which can fail for the same
    # reasons the original operation did. Such a failure must never replace
    # the informative error raised below.
    # `tree_leaves` is used rather than `tree_map` so that no tree is rebuilt.
    # Rebuilding could trip PyTree nodes that validate their children.
    try:
      any_partial = any(
          isinstance(leaf, jax.tree_util.Partial)
          for leaf in jax.tree_util.tree_leaves(
              selection.deselect(),
              is_leaf=lambda x: isinstance(x, jax.tree_util.Partial),
          )
      )
    except Exception:  # pylint: disable=broad-exception-caught
      any_partial = False

    if any_partial:
      new_message += (
          "\n\nIn this case, the error may have been caused by a"
          " jax.tree_util.Partial instance, which requires its PyTree children"
          " to be a list and a dictionary. This means the direct children of a"
          " Partial cannot be selected. If you want to avoid this edge case,"
          " consider replacing your jax.tree_util.Partial instances with"
          " penzai.experimental.safe_partial.Partial instances, using something"
          " like"
          " `select(your_tree).at_instances_of(jax.tree_util.Partial)"
          ".apply(penzai.experimental.safe_partial.Partial.from_jax)`"
      )

    raise ValueError(new_message) from exc
1454
1455
1456def _build_selection_from_boundary(tree_with_boundary: Any) -> Selection:

Callers 7

deselect (method) · 0.85
at (method) · 0.85
at_children (method) · 0.85
where (method) · 0.85
at_subtrees_where (method) · 0.85
pick_nth_selected (method) · 0.85

Calls 1

deselect (method) · 0.80

Tested by

no test coverage detected