MCPcopy
hub / github.com/google-deepmind/penzai / MyTestLayer

Class MyTestLayer

tests/nn/layer_stack_test.py:26–41  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

24
@pz.pytree_dataclass
class MyTestLayer(pz.nn.Layer):
  """Layer used by the layer-stack tests to observe how it gets called.

  Calling the layer optionally checks the named shape of every parameter,
  bumps each of its own state variables and the shared counter by one, and
  returns its argument plus one.

  Attributes:
    params: Parameters whose named shapes may be checked on each call.
    expected_param_shapes: One expected named shape per parameter, or None
      to skip the shape check entirely.
    states: State variables that are each incremented on every call.
    stuff: Arbitrary extra payload; `__call__` never reads it.
  """

  params: tuple[pz.Parameter[pz.nx.NamedArray], ...]
  expected_param_shapes: tuple[dict[str, int], ...] | None
  states: tuple[pz.StateVariable[pz.nx.NamedArray], ...]
  stuff: Any

  def __call__(self, arg, /, shared_counter, **side_inputs):
    """Runs the shape checks, increments the counters, and returns `arg + 1`.

    Args:
      arg: Positional-only input value.
      shared_counter: Object whose `.value` is incremented by one.
      **side_inputs: Additional keyword inputs; accepted but unused here.

    Returns:
      `arg + 1`.
    """
    wanted_shapes = self.expected_param_shapes
    if wanted_shapes is not None:
      # Parameters are paired with expectations positionally.
      for parameter, wanted in zip(self.params, wanted_shapes):
        assert parameter.value.named_shape == wanted

    # Every per-layer state variable advances once per call.
    for variable in self.states:
      variable.value += 1

    shared_counter.value += 1
    return arg + 1
42
43
44class LayerStackTest(absltest.TestCase):

Callers 2

test_stack_call · Method · 0.85
builder · Method · 0.85

Calls

no outgoing calls

Tested by 2

test_stack_call · Method · 0.68
builder · Method · 0.68