| 24 | |
@pz.pytree_dataclass
class MyTestLayer(pz.nn.Layer):
  """Test layer that checks parameter shapes and counts its invocations.

  Attributes:
    params: Parameters held by this layer. When ``expected_param_shapes`` is
      not ``None``, each parameter's ``value.named_shape`` is checked on every
      call.
    expected_param_shapes: Optional expected ``named_shape`` for each entry of
      ``params``, in the same order and with the same length. ``None``
      disables the check.
    states: State variables whose ``value`` is incremented by 1 on every call.
    stuff: Arbitrary extra data (typed ``Any``); not read by ``__call__``.
  """

  params: tuple[pz.Parameter[pz.nx.NamedArray], ...]
  expected_param_shapes: tuple[dict[str, int], ...] | None
  states: tuple[pz.StateVariable[pz.nx.NamedArray], ...]
  stuff: Any

  def __call__(self, arg, /, shared_counter, **side_inputs):
    """Checks parameter shapes, bumps counters, and returns ``arg + 1``.

    Args:
      arg: Input value; must support ``+ 1``.
      shared_counter: Object with a mutable ``value`` attribute, incremented
        by 1 per call (presumably a state variable shared across layers --
        confirm at call sites).
      **side_inputs: Accepted and ignored.

    Returns:
      ``arg + 1``.

    Raises:
      AssertionError: If a parameter's named shape differs from the expected
        one.
      ValueError: If ``params`` and ``expected_param_shapes`` have different
        lengths.
    """
    if self.expected_param_shapes is not None:
      # strict=True: a length mismatch must fail loudly instead of silently
      # skipping the shape check for the unmatched parameters.
      for param, expected in zip(
          self.params, self.expected_param_shapes, strict=True
      ):
        actual = param.value.named_shape
        assert actual == expected, (
            f"parameter named_shape {actual} != expected {expected}"
        )

    for state in self.states:
      state.value += 1

    shared_counter.value += 1
    return arg + 1
| 42 | |
| 43 | |
| 44 | class LayerStackTest(absltest.TestCase): |
no outgoing calls