github.com/patrick-kidger/jaxtyping @v0.3.11 sqlite

453 symbols · 1,396 edges · 31 files · 31 symbols documented (7%)
README

jaxtyping

A library providing type annotations and runtime type-checking for the shape and dtype of JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.

The name 'jax'typing is now historical: we support all of the above and have no JAX dependency!

from jaxtyping import Float
from torch import Tensor

# Accepts floating-point 2D arrays with matching axes
def matrix_multiply(x: Float[Tensor, "dim1 dim2"],
                    y: Float[Tensor, "dim2 dim3"]
                  ) -> Float[Tensor, "dim1 dim3"]:
    return x @ y
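
To make the shape check concrete, here is a toy, pure-Python sketch of what checking `matrix_multiply`'s arguments involves. Each name in a shape string is bound to a size the first time it appears, and every later use of that name within the same call must agree. `FakeArray`, `check_float` and `check_matrix_multiply` are hypothetical names introduced for illustration only. This is not how jaxtyping is implemented internally, and real arrays expose a dtype object rather than a string.

```python
# Hypothetical sketch of named-axis checking; NOT jaxtyping's implementation.
from dataclasses import dataclass

# Assumed set of floating-point dtype names for this toy example.
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


@dataclass
class FakeArray:
    """Stand-in for an array: only the shape and dtype name matter here."""
    shape: tuple[int, ...]
    dtype: str


def check_float(arr: FakeArray, spec: str, memo: dict[str, int]) -> None:
    """Check that `arr` is floating-point and its shape matches `spec`.

    Each axis name in `spec` (e.g. "dim1 dim2") is bound to a size in `memo`
    the first time it is seen; later occurrences must have the same size.
    """
    if arr.dtype not in FLOAT_DTYPES:
        raise TypeError(f"expected a floating-point dtype, got {arr.dtype}")
    names = spec.split()
    if len(names) != len(arr.shape):
        raise TypeError(f"expected {len(names)} axes, got shape {arr.shape}")
    for name, size in zip(names, arr.shape):
        bound = memo.setdefault(name, size)  # bind on first sight
        if bound != size:
            raise TypeError(f"axis {name!r}: expected size {bound}, got {size}")


def check_matrix_multiply(x: FakeArray, y: FakeArray) -> dict[str, int]:
    # One memo per call, so "dim2" must agree between x and y.
    memo: dict[str, int] = {}
    check_float(x, "dim1 dim2", memo)
    check_float(y, "dim2 dim3", memo)
    return memo
```

In the same way, the return annotation `"dim1 dim3"` could be checked against the same `memo`, so a function returning the wrong shape would also be caught.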

Installation

pip install jaxtyping

Requires Python 3.10+.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are typeguard (which exhaustively checks every argument) and beartype (which checks random pieces of arguments).
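
The difference between the two strategies can be illustrated with a toy sketch, assuming a plain Python list of floats stands in for an argument. `check_exhaustive` and `check_sampled` are hypothetical helpers, not typeguard's or beartype's actual code. The first inspects every element; the second inspects one randomly chosen element.

```python
# Hypothetical sketch contrasting two runtime-checking strategies.
# NOT the actual implementation of typeguard or beartype.
import random


def check_exhaustive(values: list, expected: type) -> bool:
    """Inspect every element: cost grows with len(values), never misses."""
    return all(isinstance(v, expected) for v in values)


def check_sampled(values: list, expected: type, rng: random.Random) -> bool:
    """Inspect one randomly chosen element: constant cost, may miss."""
    if not values:
        return True
    return isinstance(rng.choice(values), expected)


if __name__ == "__main__":
    values = [1.0] * 999 + ["oops"]  # one bad element in a thousand
    rng = random.Random(0)
    print("exhaustive:", check_exhaustive(values, float))
    passes = sum(check_sampled(values, float, rng) for _ in range(1000))
    print("sampled passes:", passes, "of 1000 calls")
```

With one bad element in a thousand, the exhaustive check always fails, while the sampled check passes on most calls. It trades certainty for a constant cost per call.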

Documentation

Available at https://docs.kidger.site/jaxtyping.

See also: other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
paramax: parameterizations and constraints for PyTrees.

Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

Core symbols most depended-on inside this repo

_make_dtype · jaxtyping/_array_types.py · called by 36
qualified_name · jaxtyping/_typeguard/__init__.py · called by 23
check_type · jaxtyping/_typeguard/__init__.py · called by 19
get_args · jaxtyping/_typeguard/__init__.py · called by 13
make_mlp · test/helpers.py · called by 11
function_name · jaxtyping/_typeguard/__init__.py · called by 7
fn · jaxtyping/_decorator.py · called by 6
_check_scalar · jaxtyping/_array_types.py · called by 6

Shape

Function 255
Method 115
Class 83

Languages

Python 100%

Modules by API surface

test/test_array.py · 74 symbols
jaxtyping/_typeguard/__init__.py · 66 symbols
test/test_decorator.py · 51 symbols
jaxtyping/_array_types.py · 41 symbols
test/import_hook_tester.py · 40 symbols
test/test_pytree.py · 27 symbols
jaxtyping/_import_hook.py · 26 symbols
jaxtyping/_decorator.py · 21 symbols
jaxtyping/_pytree_type.py · 14 symbols
jaxtyping/_storage.py · 13 symbols
test/test_generators.py · 10 symbols
test/test_ipython_extension.py · 9 symbols

Dependencies from manifests, versioned

wadler_lindig 0.1.3 · 1×

For agents

$ claude mcp add jaxtyping \
  -- python -m otcore.mcp_server <graph>