From f841454e1c2c25ee711e5c9a17518e296850d644 Mon Sep 17 00:00:00 2001 From: M Aswin Kishore <60577077+mak626@users.noreply.github.com> Date: Fri, 21 Feb 2025 00:53:43 +0530 Subject: [PATCH] fix[union-convertor]: get_queried_union_types can now handle union fragments --- graphene_mongo/converter.py | 24 ++++++++++++++---------- graphene_mongo/utils.py | 11 +++++++++-- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/graphene_mongo/converter.py b/graphene_mongo/converter.py index b07bc71..b2a2fb0 100644 --- a/graphene_mongo/converter.py +++ b/graphene_mongo/converter.py @@ -1,23 +1,23 @@ import asyncio import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import singledispatch import graphene import mongoengine - from graphene.types.json import JSONString -from graphene.utils.str_converters import to_snake_case, to_camel_case -from mongoengine.base import get_document, LazyReference +from graphene.utils.str_converters import to_camel_case, to_snake_case +from mongoengine.base import LazyReference, get_document + from . import advanced_types from .utils import ( + ExecutorEnum, get_field_description, - get_query_fields, - get_queried_union_types, get_field_is_required, - ExecutorEnum, + get_queried_union_types, + get_query_fields, sync_to_async, ) -from concurrent.futures import ThreadPoolExecutor, as_completed -from functools import singledispatch class MongoEngineConversionError(Exception): @@ -186,7 +186,9 @@ def reference_resolver(root, *args, **kwargs): return None choice_to_resolve = dict() - querying_union_types = get_queried_union_types(args[0]) + querying_union_types = get_queried_union_types( + args[0], registry._registry_string_map.keys() + ) to_resolve_models = dict() for each, queried_fields in querying_union_types.items(): to_resolve_models[registry._registry_string_map[each]] = queried_fields @@ -263,7 +265,9 @@ async def reference_resolver_async(root, *args, **kwargs): return None choice_to_resolve = dict() - querying_union_types = get_queried_union_types(args[0]) + querying_union_types = get_queried_union_types( + args[0], registry._registry_async_string_map.keys() + ) to_resolve_models = dict() for each, queried_fields in querying_union_types.items(): to_resolve_models[registry._registry_async_string_map[each]] = queried_fields diff --git a/graphene_mongo/utils.py b/graphene_mongo/utils.py index e5fd457..c0910ea 100644 --- a/graphene_mongo/utils.py +++ b/graphene_mongo/utils.py @@ -203,11 +203,12 @@ def get_query_fields(info): return query -def get_queried_union_types(info): +def get_queried_union_types(info, valid_types): """A convenience function to get queried union types with its fields Args: info (ResolveInfo) + valid_types (dict_keys) Returns: dict[union_type_name, queried_fields(dict)] @@ -227,9 +228,15 @@ def get_queried_union_types(info): for leaf in selection_set.selections: if leaf.kind == "fragment_spread": fragment_name = fragments[leaf.name.value].type_condition.name.value - fragments_queries[fragment_name] = collect_query_fields( + sub_query_fields = collect_query_fields( fragments[leaf.name.value], fragments, variables ) + if fragment_name not in valid_types: + # This is done to avoid UnionFragments coming in fragments_queries as + # we actually need its children types and not the UnionFragments itself + fragments_queries.update(sub_query_fields) + else: + fragments_queries[fragment_name] = sub_query_fields elif leaf.kind == "inline_fragment": fragment_name = leaf.type_condition.name.value fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)