hub / github.com/google-deepmind/penzai / from_config

Method from_config

penzai/models/simple_mlp.py:30–53
(
      cls,
      init_base_rng: jax.Array | None,
      feature_sizes: list[int],
      name: str = "mlp",
      activation_fn: Callable[[jax.Array], jax.Array] = jax.nn.relu,
      feature_axis: str = "features",
  )

Source from the content-addressed store, hash-verified

  @classmethod
  def from_config(
      cls,
      init_base_rng: jax.Array | None,
      feature_sizes: list[int],
      name: str = "mlp",
      activation_fn: Callable[[jax.Array], jax.Array] = jax.nn.relu,
      feature_axis: str = "features",
  ) -> MLP:
    assert len(feature_sizes) >= 2
    children = []
    for i, (feats_in, feats_out) in enumerate(
        zip(feature_sizes[:-1], feature_sizes[1:])
    ):
      if i:
        children.append(pz.nn.Elementwise(activation_fn))
      children.append(
          pz.nn.Affine.from_config(
              name=f"{name}/Affine_{i}",
              init_base_rng=init_base_rng,
              input_axes={feature_axis: feats_in},
              output_axes={feature_axis: feats_out},
          )
      )
    return cls(sublayers=children)
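The loop in the source pairs consecutive entries of feature_sizes and places an activation between affine layers, but not before the first one or after the last. The sketch below reproduces only that layout logic in plain Python so it can be run without penzai or JAX. LayerSpec and plan_mlp_layers are hypothetical names introduced for illustration; they are not part of penzai, and no parameters are initialized.

```python
from typing import NamedTuple, Optional


class LayerSpec(NamedTuple):
  """Hypothetical stand-in for a sublayer; not a penzai type."""
  kind: str                 # "affine" or "activation"
  name: Optional[str]       # parameter name prefix (affine layers only)
  feats_in: Optional[int]   # input feature size (affine layers only)
  feats_out: Optional[int]  # output feature size (affine layers only)


def plan_mlp_layers(feature_sizes: list[int], name: str = "mlp") -> list[LayerSpec]:
  """Follows the from_config loop, recording specs instead of building layers."""
  if len(feature_sizes) < 2:
    raise ValueError("need at least an input size and an output size")
  layers = []
  # Consecutive pairs: (sizes[0], sizes[1]), (sizes[1], sizes[2]), ...
  pairs = zip(feature_sizes[:-1], feature_sizes[1:])
  for i, (feats_in, feats_out) in enumerate(pairs):
    if i:
      # An activation separates affine layers; none precedes the first.
      layers.append(LayerSpec("activation", None, None, None))
    layers.append(LayerSpec("affine", f"{name}/Affine_{i}", feats_in, feats_out))
  return layers


plan = plan_mlp_layers([8, 32, 32, 4])
for spec in plan:
  print(spec)
```

With four feature sizes this yields three affine specs named mlp/Affine_0 through mlp/Affine_2, separated by two activation specs, so the final output is not passed through the activation.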

Callers 15

from_linear · Method · 0.45
from_config · Method · 0.45
build_llamalike_block · Function · 0.45
build_gpt_neox_attention · Function · 0.45
build_gpt_neox_block · Function · 0.45
test_simple_capture · Method · 0.45

Calls

no outgoing calls resolved by the index (the source above does call pz.nn.Elementwise and pz.nn.Affine.from_config)

Tested by 15

test_simple_capture · Method · 0.36
test_build_lora · Method · 0.36
test_jit_wrapper · Method · 0.36
test_embeddings · Method · 0.36
test_linear_in_place · Method · 0.36
test_add_bias · Method · 0.36
test_affine · Method · 0.36