Random numbers should be handled correctly.
(self)
| 202 | _slice_layers_params(layer_stack_grad)) |
| 203 | |
def test_random(self):
    """Random numbers should be handled correctly.

    Each of the ``n`` stacked layers adds one standard-normal draw taken from
    ``hk.next_rng_key()``. If every layer receives its own key, the output is
    a sum of ``n`` independent N(0, 1) samples: mean 0, standard deviation
    ``sqrt(n)``. If the layers were instead given the same key, every draw
    would be identical, the standard deviation would be ``n``, and the
    Kolmogorov-Smirnov test below would reject the fit.
    """
    n = 100
    num_samples = 1024

    @hk.transform
    @layer_stack.layer_stack(n)
    def add_random(x):
        x = x + jax.random.normal(hk.next_rng_key())
        return x

    # One key initialises the parameters; the remaining keys drive
    # independent evaluations. Distinct names keep the init key from being
    # shadowed by the comprehension variable.
    init_key, *apply_keys = jax.random.split(
        jax.random.PRNGKey(7), num_samples + 1)
    params = add_random.init(init_key, 0.)
    apply_fn = jax.jit(add_random.apply)
    values = [apply_fn(params, apply_key, 0.) for apply_key in apply_keys]

    # Expected distribution: normal with mean 0 and standard deviation
    # sqrt(n) (i.e. variance n). `scale` in scipy is the standard deviation.
    cdf = scipy.stats.norm(scale=np.sqrt(n)).cdf
    _, p = scipy.stats.kstest(values, cdf)
    # Fail when the samples are unlikely to come from that distribution.
    self.assertGreater(p, 0.3)
| 224 | |
| 225 | def test_threading(self): |
| 226 | """Test @layer_stack when the function gets per-layer state.""" |