MCPcopy
hub / github.com/google-deepmind/penzai / test_stack_call

Method test_stack_call

tests/nn/layer_stack_test.py:46–109  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

44class LayerStackTest(absltest.TestCase):
45
46 def test_stack_call(self):
47 the_layer = pz.nn.LayerStack(
48 stacked_sublayers=MyTestLayer(
49 params=(
50 pz.Parameter(
51 value=pz.nx.zeros({"stack": 3, "foo": 4}), label="param1"
52 ),
53 pz.Parameter(value=pz.nx.zeros({"bar": 5}), label="param2"),
54 ),
55 expected_param_shapes=({"foo": 4}, {"bar": 5}),
56 states=(
57 pz.StateVariable(
58 value=0,
59 label="state1",
60 metadata={
61 "layerstack_axes": {
62 "stack": pz.nn.LayerStackVarBehavior.SHARED
63 }
64 },
65 ),
66 pz.StateVariable(
67 value=pz.nx.zeros({"foo": 3}),
68 label="state2",
69 metadata={
70 "layerstack_axes": {
71 "stack": pz.nn.LayerStackVarBehavior.SHARED
72 }
73 },
74 ),
75 pz.StateVariable(
76 value=pz.nx.arange("stack", 3),
77 label="state3",
78 metadata={
79 "layerstack_axes": {
80 "stack": pz.nn.LayerStackVarBehavior.PER_LAYER
81 }
82 },
83 ),
84 ),
85 stuff={
86 "value": 100.0,
87 "named": pz.nx.arange("bar", 3),
88 "named_stacked": pz.nx.arange("stack", 3),
89 },
90 ),
91 stack_axis="stack",
92 stack_axis_size=3,
93 )
94 counter = pz.StateVariable(0)
95 result = the_layer(10, shared_counter=counter)
96
97 self.assertEqual(result, 13.0)
98 self.assertEqual(counter.value, 3)
99 chex.assert_trees_all_equal(
100 the_layer.stacked_sublayers.states[0].value, 3.0
101 )
102 chex.assert_trees_all_equal(
103 the_layer.stacked_sublayers.states[1].value.canonicalize(),

Callers

nothing calls this directly

Calls 2

MyTestLayer (class) · 0.85
canonicalize (method) · 0.80

Tested by

no test coverage detected