Builds a layer stack of layers with non-shared parameters. This function assumes that all variables returned by this builder are defined inside the builder. Returning variables that were already defined outside the builder is not supported. Args: builder: A function that builds a single layer, which must take a keyword argument ``init_base_rng``.
(
cls,
builder: Callable[..., layer_base.Layer],
stack_axis: named_axes.AxisName,
stack_axis_size: int,
init_base_rng: jax.Array | None,
builder_kwargs: dict[str, Any],
)
| 265 | |
| 266 | @classmethod |
| 267 | def from_sublayer_builder( |
| 268 | cls, |
| 269 | builder: Callable[..., layer_base.Layer], |
| 270 | stack_axis: named_axes.AxisName, |
| 271 | stack_axis_size: int, |
| 272 | init_base_rng: jax.Array | None, |
| 273 | builder_kwargs: dict[str, Any], |
| 274 | ) -> LayerStack: |
| 275 | """Builds a layer stack of layers with non-shared parameters. |
| 276 | |
| 277 | This function assumes that all variables returned by this builder are |
| 278 | defined inside the builder. Returning variables that were already defined |
| 279 | outside the builder is not supported. |
| 280 | |
| 281 | Args: |
| 282 | builder: A function that builds a single layer, which must take a keyword |
| 283 | argument ``init_base_rng``. All variables, as well as all other leaf |
| 284 | values that depend on this RNG, must be NamedArrays. |
| 285 | stack_axis: The axis name that layer data is stacked along. |
| 286 | stack_axis_size: The size of the stack axis. |
| 287 | init_base_rng: The base RNG for initializing the parameters. |
| 288 | builder_kwargs: Keyword arguments to pass to the builder. |
| 289 | |
| 290 | Returns: |
| 291 | A new layer stack. All arrays and variables will be split across the |
| 292 | stack axis. |
| 293 | """ |
| 294 | |
| 295 | def go(rng): |
| 296 | sublayer = builder( |
| 297 | init_base_rng=rng, |
| 298 | **builder_kwargs, |
| 299 | ) |
| 300 | unbound_sublayer, var_values = variables.unbind_variables( |
| 301 | sublayer, freeze=True |
| 302 | ) |
| 303 | if any( |
| 304 | not isinstance(leaf, named_axes.NamedArrayBase) |
| 305 | for leaf in jax.tree_util.tree_leaves( |
| 306 | var_values, is_leaf=named_axes.is_namedarray |
| 307 | ) |
| 308 | ): |
| 309 | raise ValueError( |
| 310 | "Variables returned by the LayerStack builder must only contain" |
| 311 | " NamedArrays, not ordinary array data." |
| 312 | ) |
| 313 | namedarray_selection = selectors.select( |
| 314 | (unbound_sublayer, var_values) |
| 315 | ).at_instances_of(named_axes.NamedArrayBase) |
| 316 | adjusted_namedarrays = collections.OrderedDict({ |
| 317 | k: v.with_positional_prefix() |
| 318 | for k, v in namedarray_selection.selected_by_path.items() |
| 319 | }) |
| 320 | return namedarray_selection.remainder, adjusted_namedarrays |
| 321 | |
| 322 | remainder, namedarrays = jax.vmap( |
| 323 | go, out_axes=(None, 0), axis_size=stack_axis_size |
| 324 | )( |