From 427227039807186b6a852f9d37dd5ddc37effe5c Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 28 Jan 2025 20:05:04 +0100 Subject: [PATCH] Investigating if we can drop the typeguard dependency. --- diffrax/_integrate.py | 2 +- diffrax/_typing.py | 43 +++++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 5f6d05d5..88c014aa 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -198,7 +198,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): try: with jax.numpy_dtype_promotion("standard"): jtu.tree_map(_check, term_structure, terms, contr_kwargs, y) - except Exception as e: + except ValueError as e: # ValueError may also arise from mismatched tree structures pretty_term = wl.pformat(terms) pretty_expected = wl.pformat(term_structure) diff --git a/diffrax/_typing.py b/diffrax/_typing.py index e0bfff6c..694357ed 100644 --- a/diffrax/_typing.py +++ b/diffrax/_typing.py @@ -1,5 +1,4 @@ import inspect -import sys import types from typing import ( Annotated, @@ -14,8 +13,6 @@ ) from typing_extensions import TypeAlias -import typeguard - # We don't actually care what people have subscripted with. # In practice this should be thought of as TypeLike = Union[type, types.UnionType]. Plus @@ -23,24 +20,34 @@ TypeLike: TypeAlias = Any -def better_isinstance(x, annotation) -> bool: - """As `isinstance`, but supports general type hints.""" +_T = TypeVar("_T") - @typeguard.typechecked - def f(y: annotation): - pass - try: - f(x) - except TypeError: - return False - else: - return True +class _Foo(Generic[_T]): + pass + +_generic_alias_types = (types.GenericAlias, type(_Foo[int])) +_union_origins = (Union, types.UnionType) +del _Foo, _T -_union_types: list = [Union] -if sys.version_info >= (3, 10): - _union_types.append(types.UnionType) + +def better_isinstance(x, annotation) -> bool: + """As `isinstance`, but supports a few other types that are useful to us.""" + origin = get_origin(annotation) + if origin in _union_origins: + return any(better_isinstance(x, arg) for arg in get_args(annotation)) + elif isinstance(annotation, _generic_alias_types): + assert origin is not None + return better_isinstance(x, origin) + elif annotation is Any: + return True + elif isinstance(annotation, type): + return isinstance(x, annotation) + else: + raise NotImplementedError( + f"Do not know how to check whether `{x}` is an instance of `{annotation}`." + ) def get_origin_no_specials(x, error_msg: str) -> Optional[type]: @@ -59,7 +66,7 @@ def get_origin_no_specials(x, error_msg: str) -> Optional[type]: As `get_origin`, specifically either `None` or a class. """ origin = get_origin(x) - if origin in _union_types: + if origin in _union_origins: raise NotImplementedError(f"Cannot use unions in `{error_msg}`.") elif origin is Annotated: # We do allow Annotated, just because it's easy to handle.