Skip to content

Commit 71a093d

Browse files
committed
Add basic handling of typing.Generic
1 parent c831d57 commit 71a093d

File tree

4 files changed

+122
-9
lines changed

4 files changed

+122
-9
lines changed

dacite/core.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import is_dataclass
22
from itertools import zip_longest
3-
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping
3+
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping, get_origin
44

55
from dacite.cache import cache
66
from dacite.config import Config
@@ -31,6 +31,7 @@
3131
is_init_var,
3232
extract_init_var,
3333
is_subclass,
34+
is_generic_subclass,
3435
)
3536

3637
T = TypeVar("T")
@@ -58,9 +59,9 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None)
5859
raise UnexpectedDataError(keys=extra_fields)
5960
for field in data_class_fields:
6061
field_type = data_class_hints[field.name]
61-
if field.name in data:
62+
if hasattr(data, field.name) or (isinstance(data, Mapping) and field.name in data):
63+
field_data = getattr(data, field.name, None) or data[field.name]
6264
try:
63-
field_data = data[field.name]
6465
value = _build_value(type_=field_type, data=field_data, config=config)
6566
except DaciteFieldError as error:
6667
error.update_path(field.name)
@@ -97,6 +98,8 @@ def _build_value(type_: Type, data: Any, config: Config) -> Any:
9798
data = _build_value_for_collection(collection=type_, data=data, config=config)
9899
elif cache(is_dataclass)(type_) and isinstance(data, Mapping):
99100
data = from_dict(data_class=type_, data=data, config=config)
101+
elif is_generic_subclass(type_) and is_dataclass(get_origin(type_)):
102+
data = from_dict(data_class=get_origin(type_), data=data, config=config)
100103
for cast_type in config.cast:
101104
if is_subclass(type_, cast_type):
102105
if is_generic_collection(type_):

dacite/types.py

+42
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
TypeVar,
99
Mapping,
1010
Tuple,
11+
get_origin,
12+
get_type_hints,
13+
get_args,
1114
cast as typing_cast,
15+
_GenericAlias, # Remove import and check for Generic in a different way
1216
)
17+
from inspect import isclass
1318

1419
from dacite.cache import cache
1520

@@ -43,6 +48,16 @@ def is_generic(type_: Type) -> bool:
4348
return hasattr(type_, "__origin__")
4449

4550

51+
@cache
52+
def is_generic_subclass(type_: Type) -> bool:
53+
return is_generic(type_) and hasattr(type_, "__args__")
54+
55+
56+
@cache
57+
def is_generic_alias(type_: Type) -> bool:
58+
return type(type_.__args__) == _GenericAlias
59+
60+
4661
@cache
4762
def is_union(type_: Type) -> bool:
4863
if is_generic(type_) and type_.__origin__ == Union:
@@ -86,6 +101,22 @@ def is_init_var(type_: Type) -> bool:
86101
return isinstance(type_, InitVar) or type_ is InitVar
87102

88103

104+
def is_valid_generic_class(value: Any, type_: Type) -> bool:
105+
if not isinstance(value, get_origin(type_)):
106+
return False
107+
type_hints = get_type_hints(value)
108+
for field_name, field_type in type_hints.items():
109+
if isinstance(field_type, TypeVar):
110+
return (
111+
any([isinstance(getattr(value, field_name), arg) for arg in get_args(type_)])
112+
if get_args(type_)
113+
else True
114+
)
115+
else:
116+
return is_instance(value, type_)
117+
return True
118+
119+
89120
@cache
90121
def extract_init_var(type_: Type) -> Union[Type, Any]:
91122
try:
@@ -134,6 +165,17 @@ def is_instance(value: Any, type_: Type) -> bool:
134165
return value in extract_generic(type_)
135166
elif is_init_var(type_):
136167
return is_instance(value, extract_init_var(type_))
168+
elif isclass(type(type_)) and type(type_) == _GenericAlias:
169+
return is_valid_generic_class(value, type_)
170+
elif isinstance(type_, TypeVar):
171+
if hasattr(type_, "__constraints__") and type_.__constraints__:
172+
return any(is_instance(value, t) for t in type_.__constraints__)
173+
if hasattr(type_, "__bound__") and type_.__bound__:
174+
if isinstance(type_.__bound__, tuple):
175+
return any(is_instance(value, t) for t in type_.__bound__)
176+
if type_.__bound__ is not None and is_generic(type_.__bound__):
177+
return isinstance(value, type_.__bound__)
178+
return True
137179
elif is_type_generic(type_):
138180
return is_subclass(value, extract_generic(type_)[0])
139181
else:

tests/core/test_base.py

+67-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from dataclasses import dataclass, field
2-
from typing import Any, NewType, Optional
1+
from dataclasses import dataclass, field, asdict
2+
from typing import Any, NewType, Optional, TypeVar, Generic, List, Union
33

44
import pytest
55

6-
from dacite import from_dict, MissingValueError, WrongTypeError
6+
from dacite import from_dict, MissingValueError, WrongTypeError, Config
77

88

99
def test_from_dict_with_correct_data():
@@ -191,3 +191,67 @@ class X:
191191
result = from_dict(X, {"s": "test"})
192192

193193
assert result == X(s=MyStr("test"))
194+
195+
196+
def test_from_dict_generic():
197+
T = TypeVar("T", bound=Union[str, int])
198+
199+
@dataclass
200+
class A(Generic[T]):
201+
a: T
202+
203+
@dataclass
204+
class B:
205+
a_str: A[str]
206+
a_int: A[int]
207+
208+
assert from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": 1}}) == B(a_str=A[str](a="test"), a_int=A[int](a=1))
209+
210+
211+
def test_from_dict_generic_invalid():
212+
T = TypeVar("T")
213+
214+
@dataclass
215+
class A(Generic[T]):
216+
a: T
217+
218+
@dataclass
219+
class B:
220+
a_str: A[str]
221+
a_int: A[int]
222+
223+
with pytest.raises(WrongTypeError):
224+
from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": "not int"}})
225+
226+
227+
def test_from_dict_generic_common_invalid():
228+
T = TypeVar("T", str, List[str])
229+
230+
@dataclass
231+
class Common(Generic[T]):
232+
foo: T
233+
bar: T
234+
235+
@dataclass
236+
class A:
237+
elements: List[Common[int]]
238+
239+
with pytest.raises(WrongTypeError):
240+
from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})
241+
242+
243+
def test_from_dict_generic_common():
244+
T = TypeVar("T", bound=int)
245+
246+
@dataclass
247+
class Common(Generic[T]):
248+
foo: T
249+
bar: T
250+
251+
@dataclass
252+
class A:
253+
elements: List[Common[int]]
254+
255+
result = from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})
256+
257+
assert result == A(elements=[Common[int](1, 2), Common[int](3, 4)])

tests/test_types.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import InitVar
1+
from dataclasses import InitVar, dataclass
22
from typing import Optional, Union, List, Any, Dict, NewType, TypeVar, Generic, Collection, Tuple, Type
33
from unittest.mock import patch, Mock
44

@@ -268,13 +268,13 @@ def test_is_instance_with_with_type_and_not_matching_value_type():
268268
assert not is_instance(1, Type[str])
269269

270270

271-
def test_is_instance_with_not_supported_generic_types():
271+
def test_is_instance_with_generic_types():
272272
T = TypeVar("T")
273273

274274
class X(Generic[T]):
275275
pass
276276

277-
assert not is_instance(X[str](), X[str])
277+
assert is_instance(X[str](), X[str])
278278

279279

280280
def test_is_instance_with_generic_mapping_and_matching_value_type():
@@ -364,6 +364,10 @@ def test_is_instance_with_empty_tuple_and_not_matching_type():
364364
assert not is_instance((1, 2), Tuple[()])
365365

366366

367+
def test_is_instance_list_type():
368+
assert is_instance([{}], List)
369+
370+
367371
def test_extract_generic():
368372
assert extract_generic(List[int]) == (int,)
369373

0 commit comments

Comments
 (0)