Wrapper around jaxtyping.jaxtyped that uses typeguard iff typeguard < 3.
(fn: _T)
| 24 | |
| 25 | |
def jaxtyped(fn: _T) -> _T:
  """Applies `jaxtyping.jaxtyped` with typeguard checking, but only for typeguard < 3.

  The installed typeguard major version is read from its distribution
  metadata each time this is called:

  * major version < 3: `fn` is wrapped with
    `jaxtyping.jaxtyped(fn, typechecker=typeguard.typechecked)`.
  * major version >= 3: `fn` is returned unchanged, so no jaxtyping /
    typeguard runtime checking is applied.
  * no typeguard metadata found: treated like an old (< 3) release, so
    `fn` is still wrapped.

  Args:
    fn: The callable to (possibly) wrap.

  Returns:
    The jaxtyping-wrapped `fn`, or `fn` itself when typeguard >= 3.
  """
  try:
    typeguard_version = importlib.metadata.version('typeguard')
  except importlib.metadata.PackageNotFoundError:
    # No distribution metadata for typeguard: behave as if it were < 3.
    # NOTE(review): this still references `typeguard.typechecked` below,
    # so it presumes the module itself is importable — confirm intent.
    major_version = -1
  else:
    # Leading dotted component, e.g. '2.13.3' -> 2. A non-numeric leading
    # component raises ValueError here (not caught).
    major_version = int(typeguard_version.split('.')[0])

  # Only use jaxtyping if typeguard is < 3. See
  # https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#runtime-type-checking
  # for more details.
  if major_version >= 3:
    return fn
  return jaxtyping.jaxtyped(fn, typechecker=typeguard.typechecked)