github.com/jax-ml/jax

Function split

jax/_src/random/core.py, lines 319–331

Splits a PRNG key into `num` new keys by adding a leading axis.

split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array
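A short usage sketch of the two forms of `num` described in the docstring. It assumes this internal function is the one exposed publicly as `jax.random.split`, and it uses `jax.random.key` to create a typed (scalar) PRNG key. An integer `num` adds one leading axis; a tuple `num` adds leading axes with that shape.

```python
import jax

# Typed PRNG key: a scalar array (shape ()) whose dtype is a PRNG key dtype.
key = jax.random.key(0)

# Integer num (default 2): one new leading axis of that length.
keys = jax.random.split(key)          # shape (2,)

# Tuple num: the tuple becomes the shape of the new leading axes.
grid = jax.random.split(key, (2, 3))  # shape (2, 3)

# Unpacking the default split yields a key to carry forward and a subkey to consume.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))
```

Consuming `subkey` once and carrying `key` forward for later splits is the usual way to keep successive random draws independent.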

Source

```python
  return prng.random_split(key, shape=shape)

def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array:
  """Splits a PRNG key into `num` new keys by adding a leading axis.

  Args:
    key: a PRNG key (from ``key``, ``split``, ``fold_in``).
    num: optional, a positive integer (or tuple of integers) indicating
      the number (or shape) of keys to produce. Defaults to 2.

  Returns:
    An array-like object of `num` new PRNG keys.
  """
  typed_key, wrapped = _check_prng_key("split", key)
  return _return_prng_keys(wrapped, _split(typed_key, num))


def _key_impl(keys: Array) -> PRNGImpl:
```
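The source passes the input through `_check_prng_key` and the result through `_return_prng_keys`, threading a `wrapped` flag between them. Judging from the names, this appears to let `split` accept both typed keys and legacy raw `uint32` keys, and to return results in the same style as the input. The sketch below checks that behaviour through the public API (`jax.random.PRNGKey`, `jax.random.key`, `jax.random.key_data`). It assumes the default threefry PRNG implementation, where a raw key is a `uint32` array of shape `(2,)`.

```python
import jax
import jax.numpy as jnp

# Legacy raw key: a uint32 array of shape (2,) under the default threefry PRNG.
raw_key = jax.random.PRNGKey(0)
# Splitting a raw key returns raw keys: a new leading axis of length 3 plus
# the trailing axis of 2 uint32 words per key, i.e. shape (3, 2).
raw_split = jax.random.split(raw_key, 3)

# Typed key: a scalar array with a PRNG key dtype.
typed_key = jax.random.key(0)
# Splitting a typed key returns a typed key array of shape (3,).
typed_split = jax.random.split(typed_key, 3)

# Unwrapping the typed result exposes its uint32 words. They should match the
# raw result, since both start from seed 0 with the same PRNG implementation.
same_bits = bool((jax.random.key_data(typed_split) == raw_split).all())
```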

Callers (5)

_generalized_normal (function) · 0.70
_ball (function) · 0.70
body_fn (function) · 0.70
multinomial (function) · 0.70
random_split_impl_base (function) · 0.70
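A common way sampling code consumes `split` is to thread a key through a loop, splitting off a fresh subkey on every iteration so that no key is used twice. The sketch below is not taken from any of the callers listed above. `sum_of_normals` is a hypothetical helper, and the code assumes JAX's default 32-bit mode (with x64 enabled, `jax.random.normal` would return float64 and the loop carry would need a matching initial dtype).

```python
import jax
import jax.numpy as jnp


def sum_of_normals(key, n_steps: int):
  """Hypothetical helper: sums one standard-normal draw per loop step."""

  def body(i, carry):
    key, total = carry
    # Split off a fresh subkey for this step; keep `key` for later steps.
    key, subkey = jax.random.split(key)
    total = total + jax.random.normal(subkey)
    return key, total

  # Carry (key, running total); float32 matches normal()'s default dtype.
  _, total = jax.lax.fori_loop(0, n_steps, body, (key, jnp.float32(0.0)))
  return total


result = sum_of_normals(jax.random.key(0), 5)
```

Because every draw comes from a distinct subkey, the same starting key always reproduces the same `result`.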

Calls (3)

_check_prng_key (function) · 0.85
_return_prng_keys (function) · 0.85
_split (function) · 0.70

Tested by

no test coverage detected
