Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[converter]: convert_field_to_list resolver error #20

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 95 additions & 103 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .utils import (
get_field_description,
get_query_fields,
get_queried_union_types,
get_field_is_required,
ExecutorEnum,
sync_to_async,
Expand Down Expand Up @@ -154,7 +155,7 @@ def convert_field_to_list(field, registry=None, executor: ExecutorEnum = Executo
if isinstance(field.field, mongoengine.GenericReferenceField):

def get_reference_objects(*args, **kwargs):
document = get_document(args[0][0])
document = get_document(args[0])
document_field = mongoengine.ReferenceField(document)
document_field = convert_mongoengine_field(document_field, registry)
document_field_type = document_field.get_type().type
Expand All @@ -164,75 +165,70 @@ def get_reference_objects(*args, **kwargs):
for key, values in document_field_type._meta.filter_fields.items():
for each in values:
filter_args.append(key + "__" + each)
for each in get_query_fields(args[0][3][0])[document_field_type._meta.name].keys():
for each in args[4]:
item = to_snake_case(each)
if item in document._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return (
document.objects()
.no_dereference()
.only(*set(list(document_field_type._meta.required_fields) + queried_fields))
.filter(pk__in=args[0][1])
.filter(pk__in=args[1])
)

def get_non_querying_object(*args, **kwargs):
model = get_document(args[0][0])
return [model(pk=each) for each in args[0][1]]
model = get_document(args[0])
return [model(pk=each) for each in args[1]]

def reference_resolver(root, *args, **kwargs):
to_resolve = getattr(root, field.name or field.db_name)
if to_resolve:
choice_to_resolve = dict()
querying_union_types = list(get_query_fields(args[0]).keys())
if "__typename" in querying_union_types:
querying_union_types.remove("__typename")
to_resolve_models = list()
for each in querying_union_types:
if executor == ExecutorEnum.SYNC:
to_resolve_models.append(registry._registry_string_map[each])
else:
to_resolve_models.append(registry._registry_async_string_map[each])
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
pool = ThreadPoolExecutor(5)
futures = list()
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
futures.append(
pool.submit(
get_reference_objects,
(model, object_id_list, registry, args),
)
if not to_resolve:
return None

choice_to_resolve = dict()
querying_union_types = get_queried_union_types(args[0])
to_resolve_models = dict()
for each, queried_fields in querying_union_types.items():
to_resolve_models[registry._registry_string_map[each]] = queried_fields
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
pool = ThreadPoolExecutor(5)
futures = list()
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
queried_fields = to_resolve_models[model]
futures.append(
pool.submit(
get_reference_objects,
*(model, object_id_list, registry, args, queried_fields),
)
else:
futures.append(
pool.submit(
get_non_querying_object,
(model, object_id_list, registry, args),
)
)
else:
futures.append(
pool.submit(
get_non_querying_object,
*(model, object_id_list, registry, args),
)
result = list()
for x in as_completed(futures):
result += x.result()
result_object_ids = list()
for each in result:
result_object_ids.append(each.id)
ordered_result = list()
for each in to_resolve_object_ids:
ordered_result.append(result[result_object_ids.index(each)])
return ordered_result
return None
)
result = list()
for x in as_completed(futures):
result += x.result()
result_object_ids = [each.id for each in result]
ordered_result = [
result[result_object_ids.index(each)] for each in to_resolve_object_ids
]
return ordered_result

async def get_reference_objects_async(*args, **kwargs):
document = get_document(args[0])
Expand All @@ -247,7 +243,7 @@ async def get_reference_objects_async(*args, **kwargs):
for key, values in document_field_type._meta.filter_fields.items():
for each in values:
filter_args.append(key + "__" + each)
for each in get_query_fields(args[3][0])[document_field_type._meta.name].keys():
for each in args[4]:
item = to_snake_case(each)
if item in document._fields_ordered + tuple(filter_args):
queried_fields.append(item)
Expand All @@ -259,57 +255,53 @@ async def get_reference_objects_async(*args, **kwargs):
)

async def get_non_querying_object_async(*args, **kwargs):
model = get_document(args[0])
return [model(pk=each) for each in args[1]]
return get_non_querying_object(*args, **kwargs)

async def reference_resolver_async(root, *args, **kwargs):
to_resolve = getattr(root, field.name or field.db_name)
if to_resolve:
choice_to_resolve = dict()
querying_union_types = list(get_query_fields(args[0]).keys())
if "__typename" in querying_union_types:
querying_union_types.remove("__typename")
to_resolve_models = list()
for each in querying_union_types:
if executor == ExecutorEnum.SYNC:
to_resolve_models.append(registry._registry_string_map[each])
else:
to_resolve_models.append(registry._registry_async_string_map[each])
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
loop = asyncio.get_event_loop()
tasks = []
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
task = loop.create_task(
get_reference_objects_async(model, object_id_list, registry, args)
)
else:
task = loop.create_task(
get_non_querying_object_async(model, object_id_list, registry, args)
if not to_resolve:
return None

choice_to_resolve = dict()
querying_union_types = get_queried_union_types(args[0])
to_resolve_models = dict()
for each, queried_fields in querying_union_types.items():
to_resolve_models[registry._registry_async_string_map[each]] = queried_fields
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
loop = asyncio.get_event_loop()
tasks = []
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
queried_fields = to_resolve_models[model]
task = loop.create_task(
get_reference_objects_async(
model, object_id_list, registry, args, queried_fields
)
tasks.append(task)
result = await asyncio.gather(*tasks)
result_object = {}
for items in result:
for item in items:
result_object[item.id] = item
ordered_result = list()
for each in to_resolve_object_ids:
ordered_result.append(result_object[each])
return ordered_result
return None
)
else:
task = loop.create_task(
get_non_querying_object_async(model, object_id_list, registry, args)
)
tasks.append(task)
result = await asyncio.gather(*tasks)
result_object = {}
for items in result:
for item in items:
result_object[item.id] = item
ordered_result = [result_object[each] for each in to_resolve_object_ids]
return ordered_result

return graphene.List(
base_type._type,
Expand Down
34 changes: 34 additions & 0 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,40 @@ def get_query_fields(info):
return query


def get_queried_union_types(info):
"""A convenience function to get queried union types with its fields

Args:
info (ResolveInfo)

Returns:
dict[union_type_name, queried_fields(dict)]
"""

fragments = {}
node = ast_to_dict(info.field_nodes[0])
variables = info.variable_values

for name, value in info.fragments.items():
fragments[name] = ast_to_dict(value)

fragments_queries: dict[str, dict] = {}

selection_set = node.get("selection_set") if isinstance(node, dict) else node.selection_set
if selection_set:
for leaf in selection_set.selections:
if leaf.kind == "fragment_spread":
fragment_name = fragments[leaf.name.value].type_condition.name.value
fragments_queries[fragment_name] = collect_query_fields(
fragments[leaf.name.value], fragments, variables
)
elif leaf.kind == "inline_fragment":
fragment_name = leaf.type_condition.name.value
fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)

return fragments_queries


def has_page_info(info):
"""A convenience function to call collect_query_fields with info
for retrieving if page_info details are required
Expand Down
Loading