(self)
| 26 | class JitWrapperTest(absltest.TestCase): |
| 27 | |
def test_jit_wrapper(self):
  """Checks that `jit_wrapper.Jitted` reproduces the eager model's result.

  Also checks that the jitted call still advances the random stream and
  writes its StateVariable update back, just like the eager call does.
  """

  @pz.pytree_dataclass
  class StateIncrementLayer(pz.nn.Layer):
    """Identity layer that increments its counter state on every call."""

    state: pz.StateVariable[int]

    def __call__(self, x, **_unused_side_inputs):
      self.state.value = self.state.value + 1
      return x

  mlp = simple_mlp.DropoutMLP.from_config(
      name="mlp",
      init_base_rng=jax.random.key(42),
      feature_sizes=[8, 16, 8],
      drop_rate=0.2,
  )
  state_inc = StateIncrementLayer(pz.StateVariable(0, label="counter"))
  model = pz.nn.Sequential([mlp, state_inc])
  # Build the input once so both runs are visibly fed identical data.
  inputs = pz.nx.arange("features", 8).astype(jnp.float32)

  # Non-jitted reference run.
  rstream = pz.RandomStream.from_base_key(jax.random.key(123))
  unjit_result = model(inputs, random_stream=rstream)
  self.assertEqual(rstream.offset.value, 1)
  # Pin the counter after the eager call. The final check below can then
  # only pass if the jitted call itself wrote its increment back.
  self.assertEqual(state_inc.state.value, 1)

  # Jitted run with a fresh stream from the same base key, so the dropout
  # randomness (and therefore the output) should match the eager run.
  jit_model = jit_wrapper.Jitted(model)
  rstream = pz.RandomStream.from_base_key(jax.random.key(123))
  jit_result = jit_model(inputs, random_stream=rstream)
  chex.assert_trees_all_equal(unjit_result, jit_result)
  self.assertEqual(rstream.offset.value, 1)
  self.assertEqual(state_inc.state.value, 2)
nothing calls this directly
no test coverage detected