Skip to content

Commit

Permalink
fix[union-convertor]: get_queried_union_types can now handle union fr…
Browse files Browse the repository at this point in the history
…agments
  • Loading branch information
mak626 committed Feb 20, 2025
1 parent b12870f commit f841454
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
24 changes: 14 additions & 10 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import asyncio
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import singledispatch

import graphene
import mongoengine

from graphene.types.json import JSONString
from graphene.utils.str_converters import to_snake_case, to_camel_case
from mongoengine.base import get_document, LazyReference
from graphene.utils.str_converters import to_camel_case, to_snake_case
from mongoengine.base import LazyReference, get_document

from . import advanced_types
from .utils import (
ExecutorEnum,
get_field_description,
get_query_fields,
get_queried_union_types,
get_field_is_required,
ExecutorEnum,
get_queried_union_types,
get_query_fields,
sync_to_async,
)
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import singledispatch


class MongoEngineConversionError(Exception):
Expand Down Expand Up @@ -186,7 +186,9 @@ def reference_resolver(root, *args, **kwargs):
return None

choice_to_resolve = dict()
querying_union_types = get_queried_union_types(args[0])
querying_union_types = get_queried_union_types(
args[0], registry._registry_string_map.keys()
)
to_resolve_models = dict()
for each, queried_fields in querying_union_types.items():
to_resolve_models[registry._registry_string_map[each]] = queried_fields
Expand Down Expand Up @@ -263,7 +265,9 @@ async def reference_resolver_async(root, *args, **kwargs):
return None

choice_to_resolve = dict()
querying_union_types = get_queried_union_types(args[0])
querying_union_types = get_queried_union_types(
args[0], registry._registry_async_string_map.keys()
)
to_resolve_models = dict()
for each, queried_fields in querying_union_types.items():
to_resolve_models[registry._registry_async_string_map[each]] = queried_fields
Expand Down
11 changes: 9 additions & 2 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,12 @@ def get_query_fields(info):
return query


def get_queried_union_types(info):
def get_queried_union_types(info, valid_types):
"""A convenience function to get queried union types with its fields
Args:
info (ResolveInfo)
valid_types (dict_keys)
Returns:
dict[union_type_name, queried_fields(dict)]
Expand All @@ -227,9 +228,15 @@ def get_queried_union_types(info):
for leaf in selection_set.selections:
if leaf.kind == "fragment_spread":
fragment_name = fragments[leaf.name.value].type_condition.name.value
fragments_queries[fragment_name] = collect_query_fields(
sub_query_fields = collect_query_fields(
fragments[leaf.name.value], fragments, variables
)
if fragment_name not in valid_types:
# This is done to avoid UnionFragments coming in fragments_queries as
# we actually need its children types and not the UnionFragments itself
fragments_queries.update(sub_query_fields)
else:
fragments_queries[fragment_name] = sub_query_fields
elif leaf.kind == "inline_fragment":
fragment_name = leaf.type_condition.name.value
fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)
Expand Down

0 comments on commit f841454

Please sign in to comment.