MCPcopy
hub / github.com/google-deepmind/penzai / test_layer_stack_build

Method test_layer_stack_build

tests/nn/layer_stack_test.py:111–179  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

109 )
110
111 def test_layer_stack_build(self):
112 def builder(init_base_rng, some_value):
113 return MyTestLayer(
114 params=(
115 pz.nn.make_parameter(
116 "foo",
117 init_base_rng,
118 lambda k: pz.nx.wrap(
119 jax.random.uniform(k, shape=(4,)), "vals"
120 ),
121 ),
122 ),
123 expected_param_shapes=None,
124 states=(pz.StateVariable(value=pz.nx.zeros({}), label="varstate"),),
125 stuff={"value": some_value, "named_value": pz.nx.arange("bar", 4)},
126 )
127
128 layer = pz.nn.LayerStack.from_sublayer_builder(
129 builder=builder,
130 stack_axis="stack",
131 stack_axis_size=3,
132 init_base_rng=jax.random.key(1),
133 builder_kwargs={"some_value": 100.0},
134 )
135 self.assertEqual(
136 layer.stacked_sublayers.params[0].value.named_shape,
137 {"stack": 3, "vals": 4},
138 )
139 self.assertEqual(
140 layer.stacked_sublayers.states[0].value.named_shape, {"stack": 3}
141 )
142 self.assertEqual(layer.stacked_sublayers.stuff["value"], 100.0)
143 chex.assert_trees_all_equal(
144 layer.stacked_sublayers.stuff["named_value"].canonicalize(),
145 (pz.nx.arange("bar", 4) * pz.nx.ones({"stack": 3})).canonicalize(),
146 )
147
148 slot_layer = pz.nn.LayerStack.from_sublayer_builder(
149 builder=builder,
150 stack_axis="stack",
151 stack_axis_size=3,
152 init_base_rng=None,
153 builder_kwargs={"some_value": 100.0},
154 )
155
156 unbound_layer, layer_vars = pz.unbind_variables(layer)
157 unbound_slot_layer, slot_layer_vars = pz.unbind_variables(slot_layer)
158
159 # Check as dictionaries to avoid limitations of chex:
160 unbound_layer_leaves, unbound_layer_treedef = (
161 jax.tree_util.tree_flatten_with_path(unbound_layer)
162 )
163 unbound_slot_layer_leaves, unbound_slot_layer_treedef = (
164 jax.tree_util.tree_flatten_with_path(unbound_slot_layer)
165 )
166 self.assertEqual(unbound_layer_treedef, unbound_slot_layer_treedef)
167 chex.assert_trees_all_equal(
168 collections.OrderedDict(unbound_layer_leaves),

Callers

nothing calls this directly

Calls 2

from_sublayer_builder (Method) · 0.80
canonicalize (Method) · 0.80

Tested by

no test coverage detected