14
14
15
15
from typing import Any , Dict , Optional , Type , TypeVar , cast
16
16
17
+ from attr import fields , has
17
18
from omegaconf import OmegaConf
18
19
from omegaconf .errors import ConfigKeyError , ConfigTypeError , MissingMandatoryValue
19
20
@@ -32,15 +33,32 @@ def deserialize_dictionary(
32
33
:return: Object of type `output_type`, holding the data passed in
33
34
`input_dictionary`.
34
35
"""
35
- model_dict_config = OmegaConf .create (input_dictionary )
36
+
37
+ def prune_dict (data : dict , cls : Type [Any ]) -> dict :
38
+ """Recursively prune a dictionary to match the structure of an attr class."""
39
+ pruned_data = {}
40
+ for attribute in fields (cls ):
41
+ key = attribute .name
42
+ if key in data :
43
+ value = data [key ]
44
+ # Check if the field is itself a structured class
45
+ if has (attribute .type ) and isinstance (value , dict ):
46
+ # Recursively prune the nested dictionary
47
+ pruned_data [key ] = prune_dict (value , attribute .type )
48
+ else :
49
+ pruned_data [key ] = value
50
+ return pruned_data
51
+
52
+ filtered_input_dictionary = prune_dict (input_dictionary , output_type )
53
+ model_dict_config = OmegaConf .create (filtered_input_dictionary )
36
54
schema = OmegaConf .structured (output_type )
37
55
schema_error : Optional [DataModelMismatchException ] = None
38
56
try :
39
57
values = OmegaConf .merge (schema , model_dict_config )
40
58
output = cast (output_type , OmegaConf .to_object (values ))
41
59
except (ConfigKeyError , MissingMandatoryValue , ConfigTypeError ) as error :
42
60
schema_error = DataModelMismatchException (
43
- input_dictionary = input_dictionary ,
61
+ input_dictionary = filtered_input_dictionary ,
44
62
output_data_model = output_type ,
45
63
message = error .args [0 ],
46
64
error_type = type (error ),
0 commit comments