(init_base_rng, some_value)
| 110 | |
| 111 | def test_layer_stack_build(self): |
| 112 | def builder(init_base_rng, some_value): |
| 113 | return MyTestLayer( |
| 114 | params=( |
| 115 | pz.nn.make_parameter( |
| 116 | "foo", |
| 117 | init_base_rng, |
| 118 | lambda k: pz.nx.wrap( |
| 119 | jax.random.uniform(k, shape=(4,)), "vals" |
| 120 | ), |
| 121 | ), |
| 122 | ), |
| 123 | expected_param_shapes=None, |
| 124 | states=(pz.StateVariable(value=pz.nx.zeros({}), label="varstate"),), |
| 125 | stuff={"value": some_value, "named_value": pz.nx.arange("bar", 4)}, |
| 126 | ) |
| 127 | |
| 128 | layer = pz.nn.LayerStack.from_sublayer_builder( |
| 129 | builder=builder, |
nothing calls this directly
no test coverage detected