(
cls,
init_base_rng: jax.Array | None,
feature_sizes: list[int],
name: str = "mlp",
activation_fn: Callable[[jax.Array], jax.Array] = jax.nn.relu,
feature_axis: str = "features",
)
| 28 | |
@classmethod
def from_config(
    cls,
    init_base_rng: jax.Array | None,
    feature_sizes: list[int],
    name: str = "mlp",
    activation_fn: Callable[[jax.Array], jax.Array] = jax.nn.relu,
    feature_axis: str = "features",
) -> MLP:
    """Constructs an MLP from a list of feature sizes.

    One ``pz.nn.Affine`` layer is built for each consecutive pair of entries
    in ``feature_sizes``. A ``pz.nn.Elementwise(activation_fn)`` layer is
    inserted between every two Affine layers. No activation is placed before
    the first Affine layer or after the last one.

    Args:
      init_base_rng: Base RNG forwarded unchanged to every Affine layer's
        ``from_config``. May be None; how None is interpreted is decided by
        ``pz.nn.Affine.from_config`` (not visible here).
      feature_sizes: Size of ``feature_axis`` at each layer boundary, i.e.
        ``[input_size, hidden_size_1, ..., output_size]``. Must contain at
        least two entries.
      name: Prefix for sublayer names. The i-th Affine layer is named
        ``f"{name}/Affine_{i}"``.
      activation_fn: Elementwise function applied between Affine layers.
      feature_axis: Name of the axis each Affine layer maps from and to. The
        same name is used for both its input and output axes.

    Returns:
      A new ``cls`` instance whose ``sublayers`` are the Affine and
      activation layers, in order.

    Raises:
      ValueError: If ``feature_sizes`` has fewer than two entries.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    # Without the check, short inputs would silently produce an empty MLP.
    if len(feature_sizes) < 2:
        raise ValueError(
            "feature_sizes must contain at least two entries (input and"
            f" output sizes), got {feature_sizes!r}"
        )
    children = []
    for i, (feats_in, feats_out) in enumerate(
        zip(feature_sizes[:-1], feature_sizes[1:])
    ):
        if i:
            # Activations go only *between* Affine layers, so none is added
            # before the first one.
            children.append(pz.nn.Elementwise(activation_fn))
        children.append(
            pz.nn.Affine.from_config(
                name=f"{name}/Affine_{i}",
                init_base_rng=init_base_rng,
                input_axes={feature_axis: feats_in},
                output_axes={feature_axis: feats_out},
            )
        )
    return cls(sublayers=children)
| 54 | |
| 55 | |
| 56 | @pz.pytree_dataclass(has_implicitly_inherited_fields=True) # pytype: disable=wrong-keyword-args # pylint: disable=line-too-long |
no outgoing calls