(self, num_keys=2)
| 42 | return self._key |
| 43 | |
| 44 | def split(self, num_keys=2): |
| 45 | self._assert_not_used() |
| 46 | self._used = True |
| 47 | new_keys = jax.random.split(self._key, num_keys) |
| 48 | return jax.tree_map(SafeKey, tuple(new_keys)) |
| 49 | |
| 50 | def duplicate(self, num_keys=2): |
| 51 | self._assert_not_used() |