Skip to content

Commit

Permalink
Investigating if we can drop the typeguard dependency.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 29, 2025
1 parent 583cd6d commit 4272270
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 19 deletions.
2 changes: 1 addition & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
try:
with jax.numpy_dtype_promotion("standard"):
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
except Exception as e:
except ValueError as e:
# ValueError may also arise from mismatched tree structures
pretty_term = wl.pformat(terms)
pretty_expected = wl.pformat(term_structure)
Expand Down
43 changes: 25 additions & 18 deletions diffrax/_typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import sys
import types
from typing import (
Annotated,
Expand All @@ -14,33 +13,41 @@
)
from typing_extensions import TypeAlias

import typeguard


# We don't actually care what people have subscripted with.
# In practice this should be thought of as TypeLike = Union[type, types.UnionType]. Plus
# maybe type(Literal) and so on?
TypeLike: TypeAlias = Any


def better_isinstance(x, annotation) -> bool:
"""As `isinstance`, but supports general type hints."""
_T = TypeVar("_T")

@typeguard.typechecked
def f(y: annotation):
pass

try:
f(x)
except TypeError:
return False
else:
return True
class _Foo(Generic[_T]):
pass


_generic_alias_types = (types.GenericAlias, type(_Foo[int]))
_union_origins = (Union, types.UnionType)
del _Foo, _T

_union_types: list = [Union]
if sys.version_info >= (3, 10):
_union_types.append(types.UnionType)

def better_isinstance(x, annotation) -> bool:
"""As `isinstance`, but supports a few other types that are useful to us."""
origin = get_origin(annotation)
if origin in _union_origins:
return any(better_isinstance(x, arg) for arg in get_args(annotation))
elif isinstance(annotation, _generic_alias_types):
assert origin is not None
return better_isinstance(x, origin)
elif annotation is Any:
return True
elif isinstance(annotation, type):
return isinstance(x, annotation)
else:
raise NotImplementedError(
f"Do not know how to check whether `{x}` is an instance of `{annotation}`."
)


def get_origin_no_specials(x, error_msg: str) -> Optional[type]:
Expand All @@ -59,7 +66,7 @@ def get_origin_no_specials(x, error_msg: str) -> Optional[type]:
As `get_origin`, specifically either `None` or a class.
"""
origin = get_origin(x)
if origin in _union_types:
if origin in _union_origins:
raise NotImplementedError(f"Cannot use unions in `{error_msg}`.")
elif origin is Annotated:
# We do allow Annotated, just because it's easy to handle.
Expand Down

0 comments on commit 4272270

Please sign in to comment.