(self)
| 109 | ) |
| 110 | |
| 111 | def test_layer_stack_build(self): |
def builder(init_base_rng, some_value):
    """Build a single `MyTestLayer` used as the sublayer of the stack.

    Args:
        init_base_rng: Forwarded unchanged to `pz.nn.make_parameter` as the
            base RNG argument for the "foo" parameter.
        some_value: Stored as-is under the "value" key of the layer's `stuff`
            dictionary.

    Returns:
        A `MyTestLayer` with one parameter ("foo"), no expected parameter
        shapes, one state variable labelled "varstate", and a `stuff` dict
        holding `some_value` plus a named arange over the "bar" axis.
    """
    # The initializer draws 4 uniform samples and wraps them with the named
    # axis "vals".
    foo_param = pz.nn.make_parameter(
        "foo",
        init_base_rng,
        lambda key: pz.nx.wrap(jax.random.uniform(key, shape=(4,)), "vals"),
    )
    # State starts as a named array of zeros built from an empty axis mapping.
    var_state = pz.StateVariable(value=pz.nx.zeros({}), label="varstate")
    extras = {"value": some_value, "named_value": pz.nx.arange("bar", 4)}
    return MyTestLayer(
        params=(foo_param,),
        expected_param_shapes=None,
        states=(var_state,),
        stuff=extras,
    )
| 127 | |
| 128 | layer = pz.nn.LayerStack.from_sublayer_builder( |
| 129 | builder=builder, |
| 130 | stack_axis="stack", |
| 131 | stack_axis_size=3, |
| 132 | init_base_rng=jax.random.key(1), |
| 133 | builder_kwargs={"some_value": 100.0}, |
| 134 | ) |
| 135 | self.assertEqual( |
| 136 | layer.stacked_sublayers.params[0].value.named_shape, |
| 137 | {"stack": 3, "vals": 4}, |
| 138 | ) |
| 139 | self.assertEqual( |
| 140 | layer.stacked_sublayers.states[0].value.named_shape, {"stack": 3} |
| 141 | ) |
| 142 | self.assertEqual(layer.stacked_sublayers.stuff["value"], 100.0) |
| 143 | chex.assert_trees_all_equal( |
| 144 | layer.stacked_sublayers.stuff["named_value"].canonicalize(), |
| 145 | (pz.nx.arange("bar", 4) * pz.nx.ones({"stack": 3})).canonicalize(), |
| 146 | ) |
| 147 | |
| 148 | slot_layer = pz.nn.LayerStack.from_sublayer_builder( |
| 149 | builder=builder, |
| 150 | stack_axis="stack", |
| 151 | stack_axis_size=3, |
| 152 | init_base_rng=None, |
| 153 | builder_kwargs={"some_value": 100.0}, |
| 154 | ) |
| 155 | |
| 156 | unbound_layer, layer_vars = pz.unbind_variables(layer) |
| 157 | unbound_slot_layer, slot_layer_vars = pz.unbind_variables(slot_layer) |
| 158 | |
| 159 | # Check as dictionaries to avoid limitations of chex: |
| 160 | unbound_layer_leaves, unbound_layer_treedef = ( |
| 161 | jax.tree_util.tree_flatten_with_path(unbound_layer) |
| 162 | ) |
| 163 | unbound_slot_layer_leaves, unbound_slot_layer_treedef = ( |
| 164 | jax.tree_util.tree_flatten_with_path(unbound_slot_layer) |
| 165 | ) |
| 166 | self.assertEqual(unbound_layer_treedef, unbound_slot_layer_treedef) |
| 167 | chex.assert_trees_all_equal( |
| 168 | collections.OrderedDict(unbound_layer_leaves), |
nothing calls this directly
no test coverage detected