MCPcopy
hub / github.com/google-deepmind/penzai / test_jit_wrapper

Method test_jit_wrapper

tests/toolshed/jit_wrapper_test.py:28–63  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

class JitWrapperTest(absltest.TestCase):

  def test_jit_wrapper(self):
    """Checks that `jit_wrapper.Jitted` matches eager execution of a model.

    Runs the same model (a `simple_mlp.DropoutMLP` followed by a layer that
    increments a `StateVariable` counter) once eagerly and once through
    `jit_wrapper.Jitted`. Checks that both calls:
      * produce identical outputs,
      * advance a freshly created `RandomStream` by the same offset, and
      * write their state update back to the shared `StateVariable`.
    """

    @pz.pytree_dataclass
    class StateIncrementLayer(pz.nn.Layer):
      """Layer that returns its input and increments its `state` counter."""

      state: pz.StateVariable[int]

      def __call__(self, x, **_unused_side_inputs):
        self.state.value = self.state.value + 1
        return x

    mlp = simple_mlp.DropoutMLP.from_config(
        name="mlp",
        init_base_rng=jax.random.key(42),
        feature_sizes=[8, 16, 8],
        drop_rate=0.2,
    )
    state_inc = StateIncrementLayer(pz.StateVariable(0, label="counter"))
    model = pz.nn.Sequential([mlp, state_inc])
    # Same input for both runs so the outputs are directly comparable.
    inputs = pz.nx.arange("features", 8).astype(jnp.float32)

    # Non-jitted
    rstream = pz.RandomStream.from_base_key(jax.random.key(123))
    unjit_result = model(inputs, random_stream=rstream)
    self.assertEqual(rstream.offset.value, 1)
    # Pin the eager state update so the final check below isolates the
    # contribution of the jitted call.
    self.assertEqual(state_inc.state.value, 1)

    # Jitted: use a fresh stream with the same base key so both runs draw
    # the same randomness.
    jit_model = jit_wrapper.Jitted(model)
    rstream = pz.RandomStream.from_base_key(jax.random.key(123))
    jit_result = jit_model(inputs, random_stream=rstream)
    chex.assert_trees_all_equal(unjit_result, jit_result)
    self.assertEqual(rstream.offset.value, 1)
    # The jitted call must propagate its state update back to the same
    # variable held by `state_inc`.
    self.assertEqual(state_inc.state.value, 2)

Callers

nothing calls this directly

Calls 3

from_base_key (method) · 0.80
StateIncrementLayer (class) · 0.70
from_config (method) · 0.45

Tested by

no test coverage detected