Skip to content

Commit cbd4345

Browse files
authored
Improve HTTP response parsing to tolerate extra keys (#526)
the extra keys often originate from non-breaking API changes and mismatching SDK versions
1 parent 3bc15af commit cbd4345

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

geti_sdk/utils/serialization_helpers.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Any, Dict, Optional, Type, TypeVar, cast
1616

17+
from attr import fields, has
1718
from omegaconf import OmegaConf
1819
from omegaconf.errors import ConfigKeyError, ConfigTypeError, MissingMandatoryValue
1920

@@ -32,15 +33,32 @@ def deserialize_dictionary(
3233
:return: Object of type `output_type`, holding the data passed in
3334
`input_dictionary`.
3435
"""
35-
model_dict_config = OmegaConf.create(input_dictionary)
36+
37+
def prune_dict(data: dict, cls: Type[Any]) -> dict:
38+
"""Recursively prune a dictionary to match the structure of an attr class."""
39+
pruned_data = {}
40+
for attribute in fields(cls):
41+
key = attribute.name
42+
if key in data:
43+
value = data[key]
44+
# Check if the field is itself a structured class
45+
if has(attribute.type) and isinstance(value, dict):
46+
# Recursively prune the nested dictionary
47+
pruned_data[key] = prune_dict(value, attribute.type)
48+
else:
49+
pruned_data[key] = value
50+
return pruned_data
51+
52+
filtered_input_dictionary = prune_dict(input_dictionary, output_type)
53+
model_dict_config = OmegaConf.create(filtered_input_dictionary)
3654
schema = OmegaConf.structured(output_type)
3755
schema_error: Optional[DataModelMismatchException] = None
3856
try:
3957
values = OmegaConf.merge(schema, model_dict_config)
4058
output = cast(output_type, OmegaConf.to_object(values))
4159
except (ConfigKeyError, MissingMandatoryValue, ConfigTypeError) as error:
4260
schema_error = DataModelMismatchException(
43-
input_dictionary=input_dictionary,
61+
input_dictionary=filtered_input_dictionary,
4462
output_data_model=output_type,
4563
message=error.args[0],
4664
error_type=type(error),

tests/pre-merge/unit/utils/test_utils_unit.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ def test_deserialize_dictionary(self, fxt_project_dictionary: dict):
3030
"""
3131
Verifies that deserializing a dictionary to a python object works.
3232
33-
Also tests that a DataModelMismatchException is raised in case:
34-
1. the input dictionary contains an invalid key
35-
2. the input dictionary misses a required key
33+
The test checks that a DataModelMismatchException is raised in case of a missing key.
34+
It also verifies that the presence of additional keys in the input dictionary is not a problem.
3635
"""
3736

3837
# Arrange
@@ -41,6 +40,11 @@ def test_deserialize_dictionary(self, fxt_project_dictionary: dict):
4140
dictionary_with_extra_key = copy.deepcopy(fxt_project_dictionary)
4241
dictionary_with_extra_key.update({"invalid_key": "invalidness"})
4342

43+
dictionary_with_nested_extra_key = copy.deepcopy(fxt_project_dictionary)
44+
dictionary_with_nested_extra_key["pipeline"].update(
45+
{"invalid_key": "invalidness"}
46+
)
47+
4448
dictionary_with_missing_key = copy.deepcopy(fxt_project_dictionary)
4549
dictionary_with_missing_key.pop("pipeline")
4650

@@ -54,15 +58,19 @@ def test_deserialize_dictionary(self, fxt_project_dictionary: dict):
5458
assert project.get_trainable_tasks()[0].type == TaskType.DETECTION
5559

5660
# Act and assert
57-
with pytest.raises(DataModelMismatchException):
58-
deserialize_dictionary(
59-
input_dictionary=dictionary_with_extra_key, output_type=object_type
60-
)
61+
deserialize_dictionary(
62+
input_dictionary=dictionary_with_extra_key, output_type=object_type
63+
)
64+
65+
# Act and assert
66+
deserialize_dictionary(
67+
input_dictionary=dictionary_with_nested_extra_key, output_type=object_type
68+
)
6169

6270
# Act and assert
6371
with pytest.raises(DataModelMismatchException):
6472
deserialize_dictionary(
65-
input_dictionary=dictionary_with_extra_key, output_type=object_type
73+
input_dictionary=dictionary_with_missing_key, output_type=object_type
6674
)
6775

6876
def test_generate_segmentation_labels(self):

0 commit comments

Comments
 (0)