Repository: github.com/google-deepmind/penzai

Method from_sublayer_builder

penzai/nn/layer_stack.py, lines 267–371

Builds a layer stack of layers with non-shared parameters. This function assumes that all variables returned by this builder are defined inside the builder. Returning variables that were already defined outside the builder is not supported.
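To make "non-shared parameters" concrete, here is a minimal plain-JAX sketch, not penzai's implementation. It assumes a hypothetical `toy_builder` that, like the builder this method expects, takes an `init_base_rng` keyword, and a hypothetical `build_stack` helper that splits one RNG per layer and vmaps the builder over the keys. Plain dicts of arrays stand in for penzai layers and NamedArrays.

```python
import jax
import jax.numpy as jnp


def toy_builder(*, init_base_rng, in_dim, out_dim):
  # Hypothetical builder: creates one dense layer's parameters from its own RNG.
  return {
      "w": jax.random.normal(init_base_rng, (in_dim, out_dim)) / jnp.sqrt(in_dim),
      "b": jnp.zeros((out_dim,)),
  }


def build_stack(builder, stack_axis_size, init_base_rng, builder_kwargs):
  # Hypothetical helper (not penzai's API): one independent key per layer,
  # then vmap the builder so every output leaf gains a leading stack axis.
  keys = jax.random.split(init_base_rng, stack_axis_size)
  return jax.vmap(lambda k: builder(init_base_rng=k, **builder_kwargs))(keys)


params = build_stack(
    toy_builder,
    stack_axis_size=4,
    init_base_rng=jax.random.PRNGKey(0),
    builder_kwargs={"in_dim": 8, "out_dim": 8},
)
print(params["w"].shape)  # (4, 8, 8)
print(params["b"].shape)  # (4, 8): RNG-independent leaves are broadcast
```

Because each slice of `w` is drawn from a different key, the four layers do not share weights: `params["w"][0]` and `params["w"][1]` differ.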

from_sublayer_builder(
    cls,
    builder: Callable[..., layer_base.Layer],
    stack_axis: named_axes.AxisName,
    stack_axis_size: int,
    init_base_rng: jax.Array | None,
    builder_kwargs: dict[str, Any],
) -> LayerStack


  @classmethod
  def from_sublayer_builder(
      cls,
      builder: Callable[..., layer_base.Layer],
      stack_axis: named_axes.AxisName,
      stack_axis_size: int,
      init_base_rng: jax.Array | None,
      builder_kwargs: dict[str, Any],
  ) -> LayerStack:
    """Builds a layer stack of layers with non-shared parameters.

    This function assumes that all variables returned by this builder are
    defined inside the builder. Returning variables that were already defined
    outside the builder is not supported.

    Args:
      builder: A function that builds a single layer, which must take a keyword
        argument ``init_base_rng``. All variables, as well as all other leaf
        values that depend on this RNG, must be NamedArrays.
      stack_axis: The axis name that layer data is stacked along.
      stack_axis_size: The size of the stack axis.
      init_base_rng: The base RNG for initializing the parameters.
      builder_kwargs: Keyword arguments to pass to the builder.

    Returns:
      A new layer stack. All arrays and variables will be split across the
      stack axis.
    """

    def go(rng):
      sublayer = builder(
          init_base_rng=rng,
          **builder_kwargs,
      )
      unbound_sublayer, var_values = variables.unbind_variables(
          sublayer, freeze=True
      )
      if any(
          not isinstance(leaf, named_axes.NamedArrayBase)
          for leaf in jax.tree_util.tree_leaves(
              var_values, is_leaf=named_axes.is_namedarray
          )
      ):
        raise ValueError(
            "Variables returned by the LayerStack builder must only contain"
            " NamedArrays, not ordinary array data."
        )
      namedarray_selection = selectors.select(
          (unbound_sublayer, var_values)
      ).at_instances_of(named_axes.NamedArrayBase)
      adjusted_namedarrays = collections.OrderedDict({
          k: v.with_positional_prefix()
          for k, v in namedarray_selection.selected_by_path.items()
      })
      return namedarray_selection.remainder, adjusted_namedarrays

    remainder, namedarrays = jax.vmap(
        go, out_axes=(None, 0), axis_size=stack_axis_size
    )(
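The `jax.vmap` call above uses `out_axes=(None, 0)`: the first output of `go` (what remains after the NamedArrays are selected out) is returned once, unbatched, while the second (the NamedArrays) gains a leading stacked axis. JAX only accepts `None` for an output that does not vary across the mapped axis, which is one way to read the docstring's requirement that every RNG-dependent leaf be a NamedArray. The sketch below is plain JAX with hypothetical `go_sketch` and `bad_go` functions and dicts standing in for penzai's selection remainder; it is not the rest of this method, whose source is truncated here.

```python
import jax
import jax.numpy as jnp

STACK_SIZE = 3


def go_sketch(rng):
  # Hypothetical stand-in for `go`: returns (shared structure, per-layer arrays).
  shared = {"scale": jnp.ones(())}                 # does not depend on rng
  per_layer = {"w": jax.random.normal(rng, (4,))}  # depends on rng
  return shared, per_layer


keys = jax.random.split(jax.random.PRNGKey(0), STACK_SIZE)
shared, stacked = jax.vmap(
    go_sketch, out_axes=(None, 0), axis_size=STACK_SIZE
)(keys)
print(shared["scale"].shape)  # (): returned once, not stacked
print(stacked["w"].shape)     # (3, 4): one row per layer


def bad_go(rng):
  # Hypothetical failure case: an RNG-dependent leaf in the unmapped output.
  return {"w": jax.random.normal(rng, (4,))}, {}


try:
  jax.vmap(bad_go, out_axes=(None, 0), axis_size=STACK_SIZE)(keys)
  raised = False
except ValueError:
  # JAX rejects a batched value for an output whose out_axes is None.
  raised = True
print(raised)  # True
```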


Calls (5)

LayerStack (class) · 0.85
tag_prefix (method) · 0.80
deselect (method) · 0.80
bind_variables (method) · 0.80
unfreeze_as_copy (method) · 0.45
