
Function freeze_variables

Defined in penzai/core/variables.py, lines 344–371 (github.com/google-deepmind/penzai).

Replaces each variable in a pytree with a frozen copy. The resulting tree contains frozen variable instances instead of mutable ones. Frozen variables are themselves pytree nodes, so the resulting tree is safe to pass through JAX transformations if all variables are frozen.
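To make the mutable-versus-frozen distinction concrete, here is a minimal pure-Python sketch. The classes `MutableVariable` and `FrozenVariable` are hypothetical stand-ins, not penzai's actual variable types: the mutable one holds a value that can be reassigned in place, while its `freeze()` method returns an immutable snapshot (a frozen dataclass) that later reassignments cannot affect.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FrozenVariable:
  """Hypothetical immutable snapshot of a variable (not penzai's class)."""
  label: str
  value: Any


class MutableVariable:
  """Hypothetical mutable variable whose value can be reassigned in place."""

  def __init__(self, label: str, value: Any):
    self.label = label
    self.value = value

  def freeze(self) -> FrozenVariable:
    # Capture the current value. Rebinding self.value later does not affect
    # the snapshot (the value object itself is not copied).
    return FrozenVariable(label=self.label, value=self.value)


counter = MutableVariable("counter", 0)
snapshot = counter.freeze()
counter.value = 5  # Mutate the original after freezing.
print(snapshot.value)  # 0
print(counter.value)   # 5
```

Because the snapshot is a frozen dataclass, assigning to `snapshot.value` raises an error, and two snapshots with equal fields compare (and hash) equal, which is the kind of value semantics that makes frozen copies easy to pass around.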

```python
freeze_variables(
    tree: Any,
    predicate: Callable[[AbstractVariable], bool] | None = None,
) -> Any
```

Source:

```python
def freeze_variables(
    tree: Any,
    predicate: Callable[[AbstractVariable], bool] | None = None,
) -> Any:
  """Replaces each variable in a pytree with a frozen copy.

  The resulting tree will contain frozen variable instances instead of mutable
  variable instances. Frozen variables are themselves pytree nodes, so the
  resulting tree will be safe to pass through JAX transformations if all
  variables are frozen.

  Args:
    tree: A tree containing variables.
    predicate: A function that returns True for variables that should be frozen.
      If None, all variables will be frozen.

  Returns:
    A copy of `tree` but with all variables (or those selected by `predicate`)
    replaced by equivalent frozen instances.
  """
  if predicate is None:
    predicate = lambda _: True
  return (
      selectors.select(tree)
      .at_instances_of(AbstractVariable)
      .where(predicate)
      .apply(lambda var: var.freeze())
  )
```
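The selector chain in `freeze_variables` (`select`, then `at_instances_of`, `where`, and `apply`) amounts to: walk the tree, stop at every node of a given type, keep those the predicate accepts, and replace them with the result of a function. The sketch below reproduces that traversal with the standard library over nested dicts, lists, and tuples. It is not penzai's selector implementation, which works over arbitrary JAX pytrees; the names `freeze_tree`, `Var`, and `FrozenVar` are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class FrozenVar:
  """Hypothetical immutable snapshot of a variable."""
  label: str
  value: Any


@dataclass
class Var:
  """Hypothetical mutable variable; `freeze` returns a FrozenVar snapshot."""
  label: str
  value: Any

  def freeze(self) -> FrozenVar:
    return FrozenVar(self.label, self.value)


def freeze_tree(
    tree: Any, predicate: Optional[Callable[[Var], bool]] = None
) -> Any:
  """Returns a copy of `tree` with selected `Var` nodes replaced by frozen copies."""
  if predicate is None:
    predicate = lambda _: True  # Default: freeze every variable.

  def visit(node: Any) -> Any:
    # Analogue of .at_instances_of(AbstractVariable): stop at variables.
    if isinstance(node, Var):
      # Analogue of .where(predicate).apply(lambda var: var.freeze()).
      return node.freeze() if predicate(node) else node
    # Rebuild containers so the input tree is never mutated.
    if isinstance(node, dict):
      return {k: visit(v) for k, v in node.items()}
    if isinstance(node, list):
      return [visit(v) for v in node]
    if isinstance(node, tuple):
      return tuple(visit(v) for v in node)
    return node  # Other leaves are kept as-is.

  return visit(tree)


w = Var("w", [1.0, 2.0])
step = Var("step", 0)
model = {"layer": {"w": w, "scale": 3.0}, "state": [step]}

all_frozen = freeze_tree(model)
only_w = freeze_tree(model, predicate=lambda v: v.label == "w")
print(all_frozen["state"][0])      # FrozenVar(label='step', value=0)
print(only_w["state"][0] is step)  # True
```

Because the traversal rebuilds containers rather than mutating them, `model` still holds the original mutable variables; only the returned copies contain frozen ones, matching the "A copy of `tree`" wording in the docstring. Variables rejected by the predicate are passed through unchanged (the same object).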

Callers: freeze_params, freeze_state_vars

Calls: select, at_instances_of, where, apply, freeze
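The two callers, `freeze_params` and `freeze_state_vars`, suggest that `freeze_variables` is meant to be specialized by variable type. One plausible way to build such helpers is to combine an `isinstance` check with an optional user predicate, as in the sketch below. This is an assumption, not taken from penzai's source; the classes `Param`, `StateVar`, and `Frozen` and all `*_sketch` functions are hypothetical, and the tree walker handles only dicts and lists.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

Predicate = Callable[[Any], bool]


@dataclass(frozen=True)
class Frozen:
  """Hypothetical frozen snapshot that records which kind of variable it was."""
  kind: str
  value: Any


@dataclass
class Param:
  """Hypothetical mutable parameter."""
  value: Any

  def freeze(self) -> Frozen:
    return Frozen("param", self.value)


@dataclass
class StateVar:
  """Hypothetical mutable state variable."""
  value: Any

  def freeze(self) -> Frozen:
    return Frozen("state", self.value)


def freeze_variables_sketch(tree: Any, predicate: Predicate) -> Any:
  """Freezes every Param/StateVar node accepted by `predicate` (dicts and lists only)."""
  if isinstance(tree, (Param, StateVar)):
    return tree.freeze() if predicate(tree) else tree
  if isinstance(tree, dict):
    return {k: freeze_variables_sketch(v, predicate) for k, v in tree.items()}
  if isinstance(tree, list):
    return [freeze_variables_sketch(v, predicate) for v in tree]
  return tree


def _restrict(kind: type, predicate: Optional[Predicate]) -> Predicate:
  # Accept only variables of `kind`, then defer to the caller's predicate if any.
  return lambda var: isinstance(var, kind) and (predicate is None or predicate(var))


def freeze_params_sketch(tree: Any, predicate: Optional[Predicate] = None) -> Any:
  return freeze_variables_sketch(tree, _restrict(Param, predicate))


def freeze_state_sketch(tree: Any, predicate: Optional[Predicate] = None) -> Any:
  return freeze_variables_sketch(tree, _restrict(StateVar, predicate))


tree = {"w": Param(1.0), "count": StateVar(0)}
print(freeze_params_sketch(tree))
# {'w': Frozen(kind='param', value=1.0), 'count': StateVar(value=0)}
```

Composing the type check with the optional predicate keeps a single traversal function while letting each wrapper expose the same `predicate` parameter as the general version.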

Tested by: no test coverage detected.