Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/handle generics #209

Closed
wants to merge 10 commits into from
2 changes: 1 addition & 1 deletion .github/workflows/code_check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ jobs:
comment-always: false
alert-threshold: '130%'
comment-on-alert: false
fail-on-alert: true
fail-on-alert: ${{ !inputs.publish_performance }}
12 changes: 11 additions & 1 deletion dacite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from itertools import zip_longest
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping

try:
from typing import get_origin # type: ignore
except ImportError:
from typing_extensions import get_origin # type: ignore

from dacite.cache import cache
from dacite.config import Config
from dacite.data import Data
Expand Down Expand Up @@ -31,6 +36,7 @@
is_init_var,
extract_init_var,
is_subclass,
is_generic_subclass,
)

T = TypeVar("T")
Expand Down Expand Up @@ -59,8 +65,8 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None)
for field in data_class_fields:
field_type = data_class_hints[field.name]
if field.name in data:
field_data = data[field.name]
try:
field_data = data[field.name]
value = _build_value(type_=field_type, data=field_data, config=config)
except DaciteFieldError as error:
error.update_path(field.name)
Expand Down Expand Up @@ -97,6 +103,10 @@ def _build_value(type_: Type, data: Any, config: Config) -> Any:
data = _build_value_for_collection(collection=type_, data=data, config=config)
elif cache(is_dataclass)(type_) and isinstance(data, Mapping):
data = from_dict(data_class=type_, data=data, config=config)
elif is_generic_subclass(type_) and cache(is_dataclass)(get_origin(type_)):
origin = get_origin(type_)
assert origin is not None
data = from_dict(data_class=origin, data=data, config=config)
for cast_type in config.cast:
if is_subclass(type_, cast_type):
if is_generic_collection(type_):
Expand Down
92 changes: 77 additions & 15 deletions dacite/types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from dataclasses import InitVar
from typing import (
Type,
Any,
Optional,
Union,
Collection,
TypeVar,
Mapping,
Tuple,
cast as typing_cast,
)
from typing import Type, Any, Optional, Union, Collection, TypeVar, Mapping, Tuple, get_type_hints, cast as typing_cast

try:
from typing import get_origin, get_args # type: ignore
except ImportError:
from typing_extensions import get_origin, get_args # type: ignore
from inspect import isclass

from dacite.cache import cache

Expand Down Expand Up @@ -43,9 +39,14 @@ def is_generic(type_: Type) -> bool:
return hasattr(type_, "__origin__")


@cache
def is_generic_subclass(type_: Type) -> bool:
return is_generic(type_) and hasattr(type_, "__args__")


@cache
def is_union(type_: Type) -> bool:
if is_generic(type_) and type_.__origin__ == Union:
if is_generic(type_) and get_origin(type_) == Union:
return True

try:
Expand All @@ -66,7 +67,7 @@ def is_literal(type_: Type) -> bool:
try:
from typing import Literal # type: ignore

return is_generic(type_) and type_.__origin__ == Literal
return is_generic(type_) and get_origin(type_) == Literal
except ImportError:
return False

Expand All @@ -86,6 +87,31 @@ def is_init_var(type_: Type) -> bool:
return isinstance(type_, InitVar) or type_ is InitVar


@cache
def is_generic_alias(type_: Type) -> bool:
"""Since `typing._GenericAlias` is not explicitly exported, we instead rely on this check."""
return str(type_) == "<class 'typing._GenericAlias'>"


@cache
def has_generic_alias_in_args(type_: Type) -> bool:
return is_generic_alias(type(get_args(type_)))


def is_valid_generic_class(value: Any, type_: Type) -> bool:
origin = get_origin(type_)
if not (origin and isinstance(value, origin)):
return False
type_hints = cache(get_type_hints)(type(value))
for field_name, field_type in type_hints.items():
if isinstance(field_type, TypeVar):
args = get_args(type_)
return True if not args else any(isinstance(getattr(value, field_name, None), arg) for arg in args)
else:
return isinstance(value, type_)
Comment on lines +105 to +111

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using this branch for a project and found this. This block doesn't make sense and also doesn't work.

get_type_hints gets a mapping from field name to field type for the dataclass value. You're iterating over that, but you only ever check the first item because both branches end with return. I'm pretty sure both of those returns in the loop should only be returning if the value is False.

On the first branch, if field_type is a TypeVar, you check the field value against all the arguments for the subscripted generic type_. Why is get_args in the loop? Why is getattr in the inner loop? Neither one is changing at that point. I believe this would also fail to detect an incorrect type if, say, a field is hinted AnyStr, the value is None, and the dataclass has two type arguments, one str and one None.

I assume the second branch is trying to check if the field value matches the field type? But what you're actually doing is checking if the dataclass matches the subscripted dataclass type. This crashes with the exception TypeError: Subscripted generics cannot be used with class and instance checks for obvious reasons.

For the false negative issue with multiple type arguments, possibly something could be done with __orig_bases__? Not sure.

Here's my suggestion for a fixed version. I haven't written test cases, but I did try it with my fairly complicated use case and it seems to work.

Suggested change
type_hints = cache(get_type_hints)(type(value))
for field_name, field_type in type_hints.items():
if isinstance(field_type, TypeVar):
args = get_args(type_)
return True if not args else any(isinstance(getattr(value, field_name, None), arg) for arg in args)
else:
return isinstance(value, type_)
type_args = get_args(type_)
type_hints = cache(get_type_hints)(type(value))
for field_name, field_type in type_hints.items():
field_value = getattr(value, field_name, None)
if isinstance(field_type, TypeVar):
# TODO: this will fail to detect incorrect type in some cases
# see comments on https://github.com/konradhalas/dacite/pull/209
if not any(is_instance(field_value, arg) for arg in type_args):
return False
elif get_origin(field_type) is not ClassVar:
if not is_instance(field_value, field_type):
return False

return True


@cache
def extract_init_var(type_: Type) -> Union[Type, Any]:
try:
Expand All @@ -94,6 +120,31 @@ def extract_init_var(type_: Type) -> Union[Type, Any]:
return Any


@cache
def get_constraints(type_: TypeVar) -> Optional[Any]:
return type_.__constraints__


@cache
def is_constrained(type_: TypeVar) -> bool:
return hasattr(type_, "__constraints__") and get_constraints(type_)


@cache
def get_bound(type_: TypeVar) -> Optional[Any]:
return type_.__bound__


@cache
def is_bound(type_: TypeVar) -> bool:
return hasattr(type_, "__bound__") and get_bound(type_)


@cache
def is_generic_bound(type_: TypeVar) -> bool:
return is_bound(type_) and get_bound(type_) is not None and is_generic(get_bound(type_))


def is_instance(value: Any, type_: Type) -> bool:
try:
# As described in PEP 484 - section: "The numeric tower"
Expand Down Expand Up @@ -134,6 +185,17 @@ def is_instance(value: Any, type_: Type) -> bool:
return value in extract_generic(type_)
elif is_init_var(type_):
return is_instance(value, extract_init_var(type_))
elif isclass(type(type_)) and is_generic_alias(type(type_)):
return is_valid_generic_class(value, type_)
elif isinstance(type_, TypeVar):
if is_constrained(type_):
return any(is_instance(value, t) for t in type_.__constraints__)
if is_bound(type_):
if isinstance(get_bound(type_), tuple):
return any(isinstance(value, t) for t in get_bound(type_))
if is_generic_bound(type_):
return isinstance(value, extract_generic(get_bound(type_)))
return True
elif is_type_generic(type_):
return is_subclass(value, extract_generic(type_)[0])
else:
Expand Down Expand Up @@ -168,14 +230,14 @@ def is_subclass(sub_type: Type, base_type: Type) -> bool:
if is_generic_collection(sub_type):
sub_type = extract_origin_collection(sub_type)
try:
return issubclass(sub_type, base_type)
return cache(issubclass)(sub_type, base_type)
except TypeError:
return False


@cache
def is_type_generic(type_: Type) -> bool:
try:
return type_.__origin__ in (type, Type)
return get_origin(type_) in (type, Type)
except AttributeError:
return False
70 changes: 68 additions & 2 deletions tests/core/test_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, NewType, Optional, List
from typing import Any, NewType, Optional, List, TypeVar, Generic, Union

import pytest

from dacite import from_dict, MissingValueError, WrongTypeError
from dacite import from_dict, MissingValueError, WrongTypeError, Config


def test_from_dict_with_correct_data():
Expand Down Expand Up @@ -193,6 +193,70 @@ class X:
assert result == X(s=MyStr("test"))


def test_from_dict_generic_valid():
T = TypeVar("T", bound=Union[str, int])

@dataclass
class A(Generic[T]):
a: T

@dataclass
class B:
a_str: A[str]
a_int: A[int]

assert from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": 1}}) == B(a_str=A[str](a="test"), a_int=A[int](a=1))


def test_from_dict_generic_invalid():
T = TypeVar("T")

@dataclass
class A(Generic[T]):
a: T

@dataclass
class B:
a_str: A[str]
a_int: A[int]

with pytest.raises(WrongTypeError):
from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": "not int"}})


def test_from_dict_generic_common_invalid():
T = TypeVar("T", str, List[str])

@dataclass
class Common(Generic[T]):
foo: T
bar: T

@dataclass
class A:
elements: List[Common[int]]

with pytest.raises(WrongTypeError):
from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})


def test_from_dict_generic_common():
T = TypeVar("T", bound=int)

@dataclass
class Common(Generic[T]):
foo: T
bar: T

@dataclass
class A:
elements: List[Common[int]]

result = from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})

assert result == A(elements=[Common[int](1, 2), Common[int](3, 4)])


def test_dataclass_default_factory_identity():
# https://github.com/konradhalas/dacite/issues/215
@dataclass
Expand All @@ -203,4 +267,6 @@ class A:
a1 = from_dict(A, {"name": "a1"})
a2 = from_dict(A, {"name": "a2"})

assert a1 is not a2
assert a1.name is not a2.name
assert a1.items is not a2.items
10 changes: 7 additions & 3 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import InitVar
from dataclasses import InitVar, dataclass
from typing import Optional, Union, List, Any, Dict, NewType, TypeVar, Generic, Collection, Tuple, Type
from unittest.mock import patch, Mock

Expand Down Expand Up @@ -268,13 +268,13 @@ def test_is_instance_with_with_type_and_not_matching_value_type():
assert not is_instance(1, Type[str])


def test_is_instance_with_not_supported_generic_types():
def test_is_instance_with_generic_types():
T = TypeVar("T")

class X(Generic[T]):
pass

assert not is_instance(X[str](), X[str])
assert is_instance(X[str](), X[str])


def test_is_instance_with_generic_mapping_and_matching_value_type():
Expand Down Expand Up @@ -364,6 +364,10 @@ def test_is_instance_with_empty_tuple_and_not_matching_type():
assert not is_instance((1, 2), Tuple[()])


def test_is_instance_list_type():
assert is_instance([{}], List)


def test_extract_generic():
assert extract_generic(List[int]) == (int,)

Expand Down