Skip to content

Commit c1aa6da

Browse files
committed
Add basic handling of typing.Generic
1 parent 9bf8fea commit c1aa6da

File tree

4 files changed

+121
-10
lines changed

4 files changed

+121
-10
lines changed

dacite/core.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
2-
from dataclasses import is_dataclass
2+
from dataclasses import is_dataclass, dataclass
33
from itertools import zip_longest
4-
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping
4+
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping, get_origin
55

66
from dacite.cache import cache
77
from dacite.config import Config
@@ -33,6 +33,7 @@
3333
is_init_var,
3434
extract_init_var,
3535
is_subclass,
36+
is_generic_subclass,
3637
)
3738

3839
T = TypeVar("T")
@@ -61,9 +62,10 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None)
6162
for field in data_class_fields:
6263
field = copy.copy(field)
6364
field.type = data_class_hints[field.name]
64-
if field.name in data:
65+
66+
if hasattr(data, field.name) or (isinstance(data, Mapping) and field.name in data):
67+
field_data = getattr(data, field.name, None) or data[field.name]
6568
try:
66-
field_data = data[field.name]
6769
value = _build_value(type_=field.type, data=field_data, config=config)
6870
except DaciteFieldError as error:
6971
error.update_path(field.name)
@@ -98,6 +100,8 @@ def _build_value(type_: Type, data: Any, config: Config) -> Any:
98100
data = _build_value_for_collection(collection=type_, data=data, config=config)
99101
elif is_dataclass(type_) and isinstance(data, Mapping):
100102
data = from_dict(data_class=type_, data=data, config=config)
103+
elif is_generic_subclass(type_) and is_dataclass(get_origin(type_)):
104+
data = from_dict(data_class=get_origin(type_), data=data, config=config)
101105
for cast_type in config.cast:
102106
if is_subclass(type_, cast_type):
103107
if is_generic_collection(type_):

dacite/types.py

+39
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
TypeVar,
99
Mapping,
1010
Tuple,
11+
get_origin,
12+
get_type_hints,
13+
get_args,
1114
cast as typing_cast,
15+
_GenericAlias, # Remove import and check for Generic in a different way
1216
)
17+
from inspect import isclass
1318

1419
from dacite.cache import cache
1520

@@ -43,6 +48,16 @@ def is_generic(type_: Type) -> bool:
4348
return hasattr(type_, "__origin__")
4449

4550

51+
@cache
52+
def is_generic_subclass(type_: Type) -> bool:
53+
return is_generic(type_) and hasattr(type_, "__args__")
54+
55+
56+
@cache
57+
def is_generic_alias(type_: Type) -> bool:
58+
return type(type_.__args__) == _GenericAlias
59+
60+
4661
@cache
4762
def is_union(type_: Type) -> bool:
4863
if is_generic(type_) and type_.__origin__ == Union:
@@ -86,6 +101,22 @@ def is_init_var(type_: Type) -> bool:
86101
return isinstance(type_, InitVar) or type_ is InitVar
87102

88103

104+
def is_valid_generic_class(value: Any, type_: Type) -> bool:
105+
if not isinstance(value, get_origin(type_)):
106+
return False
107+
type_hints = get_type_hints(value)
108+
for field_name, field_type in type_hints.items():
109+
if isinstance(field_type, TypeVar):
110+
return (
111+
any([isinstance(getattr(value, field_name), arg) for arg in get_args(type_)])
112+
if get_args(type_)
113+
else True
114+
)
115+
else:
116+
return is_instance(value, type_)
117+
return True
118+
119+
89120
@cache
90121
def extract_init_var(type_: Type) -> Union[Type, Any]:
91122
try:
@@ -128,6 +159,14 @@ def is_instance(value: Any, type_: Type) -> bool:
128159
return value in extract_generic(type_)
129160
elif is_init_var(type_):
130161
return is_instance(value, extract_init_var(type_))
162+
elif isclass(type(type_)) and type(type_) == _GenericAlias:
163+
return is_valid_generic_class(value, type_)
164+
elif isinstance(type_, TypeVar):
165+
if hasattr(type_, "__constraints__") and type_.__constraints__:
166+
return isinstance(value, type_.__constraints__)
167+
if hasattr(type_, "__bound__") and type_.__bound__:
168+
return isinstance(value, type_.__bound__)
169+
return True
131170
elif is_type_generic(type_):
132171
return is_subclass(value, extract_generic(type_)[0])
133172
else:

tests/core/test_base.py

+67-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from dataclasses import dataclass, field
2-
from typing import Any, NewType, Optional
1+
from dataclasses import dataclass, field, asdict
2+
from typing import Any, NewType, Optional, TypeVar, Generic, List, Union
33

44
import pytest
55

6-
from dacite import from_dict, MissingValueError, WrongTypeError
6+
from dacite import from_dict, MissingValueError, WrongTypeError, Config
77

88

99
def test_from_dict_with_correct_data():
@@ -191,3 +191,67 @@ class X:
191191
result = from_dict(X, {"s": "test"})
192192

193193
assert result == X(s=MyStr("test"))
194+
195+
196+
def test_from_dict_generic():
197+
T = TypeVar("T", bound=Union[str, int])
198+
199+
@dataclass
200+
class A(Generic[T]):
201+
a: T
202+
203+
@dataclass
204+
class B:
205+
a_str: A[str]
206+
a_int: A[int]
207+
208+
assert from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": 1}}) == B(a_str=A[str](a="test"), a_int=A[int](a=1))
209+
210+
211+
def test_from_dict_generic_invalid():
212+
T = TypeVar("T")
213+
214+
@dataclass
215+
class A(Generic[T]):
216+
a: T
217+
218+
@dataclass
219+
class B:
220+
a_str: A[str]
221+
a_int: A[int]
222+
223+
with pytest.raises(WrongTypeError):
224+
from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": "not int"}})
225+
226+
227+
def test_from_dict_generic_common_invalid():
228+
T = TypeVar("T", bound=str)
229+
230+
@dataclass
231+
class Common(Generic[T]):
232+
foo: T
233+
bar: T
234+
235+
@dataclass
236+
class A:
237+
elements: List[Common[int]]
238+
239+
with pytest.raises(WrongTypeError):
240+
from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})
241+
242+
243+
def test_from_dict_generic_common():
244+
T = TypeVar("T", bound=int)
245+
246+
@dataclass
247+
class Common(Generic[T]):
248+
foo: T
249+
bar: T
250+
251+
@dataclass
252+
class A:
253+
elements: List[Common[int]]
254+
255+
result = from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})
256+
257+
assert result == A(elements=[Common[int](1, 2), Common[int](3, 4)])

tests/test_types.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import InitVar
1+
from dataclasses import InitVar, dataclass
22
from typing import Optional, Union, List, Any, Dict, NewType, TypeVar, Generic, Collection, Tuple, Type
33
from unittest.mock import patch, Mock
44

@@ -268,13 +268,13 @@ def test_is_instance_with_with_type_and_not_matching_value_type():
268268
assert not is_instance(1, Type[str])
269269

270270

271-
def test_is_instance_with_not_supported_generic_types():
271+
def test_is_instance_with_generic_types():
272272
T = TypeVar("T")
273273

274274
class X(Generic[T]):
275275
pass
276276

277-
assert not is_instance(X[str](), X[str])
277+
assert is_instance(X[str](), X[str])
278278

279279

280280
def test_is_instance_with_generic_mapping_and_matching_value_type():
@@ -364,6 +364,10 @@ def test_is_instance_with_empty_tuple_and_not_matching_type():
364364
assert not is_instance((1, 2), Tuple[()])
365365

366366

367+
def test_is_instance_list_type():
368+
assert is_instance([{}], List)
369+
370+
367371
def test_extract_generic():
368372
assert extract_generic(List[int]) == (int,)
369373

0 commit comments

Comments
 (0)