Context manager to add informative details to selection errors.
(selection: "Selection")
| 1413 | |
@contextlib.contextmanager
def _wrap_selection_errors(selection: "Selection"):
    """Context manager to add informative details to selection errors.

    Any ``Exception`` raised inside the ``with`` body is re-raised as a
    ``ValueError`` with an explanatory message, chained to the original
    exception (``raise ... from exc``) so the root cause stays visible. If the
    tree returned by ``selection.deselect()`` contains
    ``jax.tree_util.Partial`` instances, an extra hint about that known
    failure case is appended to the message.

    Args:
        selection: The selection associated with the guarded operation. Only
            its ``deselect()`` method is used, and only after an error occurs.

    Yields:
        None.

    Raises:
        ValueError: If the ``with`` body raises any ``Exception``; the original
            exception is attached as ``__cause__``.
    """
    try:
        yield
    except Exception as exc:
        new_message = (
            "An error occurred while building a Selection. This can happen when"
            " PyTree nodes make assumptions about the types of their children,"
            " since Selections replace subtrees with sentinel elements. It can also"
            " happen if the value being selected isn't a valid JAX PyTree (in which"
            " case `jax.tree_util.tree_flatten` will also fail)."
        )

        def _is_partial(node) -> bool:
            return isinstance(node, jax.tree_util.Partial)

        # Check for known failure cases. This is best-effort diagnostics that
        # runs while handling another error: `tree_leaves` only flattens (it
        # never rebuilds nodes with substitute children, unlike `tree_map`),
        # and any failure here falls back to the generic message so the
        # original exception is not masked by a secondary one.
        try:
            leaves = jax.tree_util.tree_leaves(
                selection.deselect(), is_leaf=_is_partial
            )
            any_partial = any(_is_partial(leaf) for leaf in leaves)
        except Exception:  # pylint: disable=broad-exception-caught
            any_partial = False

        if any_partial:
            new_message += (
                "\n\nIn this case, the error may have been caused by a"
                " jax.tree_util.Partial instance, which requires its PyTree children"
                " to be a list and a dictionary. This means the direct children of a"
                " Partial cannot be selected. If you want to avoid this edge case,"
                " consider replacing your jax.tree_util.Partial instances with"
                " penzai.experimental.safe_partial.Partial instances, using something"
                " like"
                " `select(your_tree).at_instances_of(jax.tree_util.Partial)"
                ".apply(penzai.experimental.safe_partial.Partial.from_jax)`"
            )

        raise ValueError(new_message) from exc
| 1454 | |
| 1455 | |
| 1456 | def _build_selection_from_boundary(tree_with_boundary: Any) -> Selection: |
no test coverage detected