Disable default ordering in GraphQL queries #5380

Merged 4 commits on Jan 8, 2025
Changes from 3 commits
8 changes: 8 additions & 0 deletions backend/infrahub/core/manager.py
@@ -26,6 +26,7 @@
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
from infrahub.graphql.models import OrderModel

if TYPE_CHECKING:
from infrahub.core.branch import Branch
@@ -141,6 +142,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: OrderModel | None = ...,
) -> list[Any]: ...

@overload
@@ -161,6 +163,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: OrderModel | None = ...,
) -> list[SchemaProtocol]: ...

@classmethod
@@ -180,6 +183,7 @@ async def query(
account=None,
partial_match: bool = False,
branch_agnostic: bool = False,
order: OrderModel | None = None,
) -> list[Any]:
"""Query one or multiple nodes of a given type based on filter arguments.
@@ -227,6 +231,7 @@
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=order,
)
await query.execute(db=db)
node_ids = query.get_node_ids()
@@ -295,6 +300,7 @@ async def count(
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)
return await query.count(db=db)

@@ -657,6 +663,7 @@ async def get_one_by_default_filter(
prefetch_relationships=prefetch_relationships,
account=account,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)

if len(items) > 1:
@@ -820,6 +827,7 @@ async def get_one_by_hfid(
prefetch_relationships=prefetch_relationships,
account=account,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)

if len(items) < 1:
8 changes: 7 additions & 1 deletion backend/infrahub/core/node/__init__.py
@@ -18,6 +18,7 @@
from infrahub.types import ATTRIBUTE_TYPES

from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
from ...graphql.models import OrderModel
from ..relationship import RelationshipManager
from ..utils import update_relationships_to
from .base import BaseNode, BaseNodeMeta, BaseNodeOptions
@@ -609,7 +610,12 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->

# Update the relationship to the branch itself
query = await NodeGetListQuery.init(
db=db, schema=self._schema, filters={"id": self.id}, branch=self._branch, at=delete_at
db=db,
schema=self._schema,
filters={"id": self.id},
branch=self._branch,
at=delete_at,
order=OrderModel(disable=True),
)
await query.execute(db=db)
result = query.get_result()
53 changes: 37 additions & 16 deletions backend/infrahub/core/query/node.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from copy import copy
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import Enum
@@ -18,6 +19,7 @@
from infrahub.core.schema.attribute_schema import AttributeSchema
from infrahub.core.utils import build_regex_attrs, extract_field_filters
from infrahub.exceptions import QueryError
from infrahub.graphql.models import OrderModel

if TYPE_CHECKING:
from neo4j.graph import Node as Neo4jNode
@@ -808,14 +810,29 @@ class NodeGetListQuery(Query):
type = QueryType.READ

def __init__(
self, schema: NodeSchema, filters: Optional[dict] = None, partial_match: bool = False, **kwargs: Any
self,
schema: NodeSchema,
filters: Optional[dict] = None,
partial_match: bool = False,
order: OrderModel | None = None,
**kwargs: Any,
) -> None:
self.schema = schema
self.filters = filters
self.partial_match = partial_match
self._variables_to_track = ["n", "rb"]
self._validate_filters()

# Force disabling order when `limit` is 1 as it simplifies the query a lot.
if "limit" in kwargs and kwargs["limit"] == 1:
if order is None:
order = OrderModel(disable=True)
else:
order = copy(order)
order.disable = True

self.order = order

super().__init__(**kwargs)

@property
@@ -901,8 +918,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params["uuid"] = self.filters["id"]
elif not self.has_filters and not self.schema.order_by:
use_simple = True
self.order_by = ["n.uuid"]

if self.order is None or self.order.disable is not True:
Review comment (Collaborator): this logic looks surprising to me ... (a short sketch of this condition follows this hunk)

self.order_by = ["n.uuid"]
if use_simple:
if where_clause_elements:
self.add_to_query(" AND " + " AND ".join(where_clause_elements))
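
To make the flagged condition concrete, here is a small standalone sketch, not part of the PR: the OrderModel stub mirrors the pydantic model added in backend/infrahub/graphql/models.py, and the ordering_enabled helper is hypothetical. Default ordering stays on unless the caller explicitly passes disable=True.

from pydantic import BaseModel

class OrderModel(BaseModel):  # mirrors the model added in this PR
    disable: bool | None = None

def ordering_enabled(order: OrderModel | None) -> bool:
    # Same test as `self.order is None or self.order.disable is not True`
    return order is None or order.disable is not True

assert ordering_enabled(None)                          # no order argument: keep default ordering
assert ordering_enabled(OrderModel())                  # disable unset (None): keep default ordering
assert ordering_enabled(OrderModel(disable=False))     # explicit False: keep default ordering
assert not ordering_enabled(OrderModel(disable=True))  # explicit opt-out: skip the ORDER BY work
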
@@ -917,9 +934,11 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
await self._add_node_filter_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)
await self._add_node_order_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)
should_order = self.schema.order_by and (self.order is None or self.order.disable is not True)
if should_order:
await self._add_node_order_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)

if use_profiles:
await self._add_profiles_per_node_query(db=db, branch_filter=branch_filter)
@@ -929,15 +948,15 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
await self._add_profile_rollups(field_attribute_requirements=field_attribute_requirements)

self._add_final_filter(field_attribute_requirements=field_attribute_requirements)
self.order_by = []
for far in field_attribute_requirements:
if not far.is_order:
continue
if far.supports_profile:
self.order_by.append(far.final_value_query_variable)
continue
self.order_by.append(far.node_value_query_variable)
self.order_by.append("n.uuid")
if should_order:
Review comment (Collaborator): even if the order is disabled we should also order by uuid to guarantee the pagination

Reply (Contributor, author): Ok, I did not have that in mind.

Reply (Contributor): Some additional context for the self.order_by.append("n.uuid"): #4704

(A short illustration of this pagination point follows this file's diff.)

for far in field_attribute_requirements:
if not far.is_order:
continue
if far.supports_profile:
self.order_by.append(far.final_value_query_variable)
continue
self.order_by.append(far.node_value_query_variable)
self.order_by.append("n.uuid")

async def _add_node_filter_attributes(
self,
@@ -1206,7 +1225,9 @@ def _get_field_requirements(self) -> list[FieldAttributeRequirement]:
types=[FieldAttributeRequirementType.FILTER],
)
index += 1
if not self.schema.order_by:

disable_order = self.order.disable if self.order is not None else False
if not self.schema.order_by or disable_order:
return list(field_requirements_map.values())

for order_by_path in self.schema.order_by:
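On the review point about keeping n.uuid in the ORDER BY even when ordering is disabled, here is a minimal illustration in plain Python (hypothetical data, not Infrahub code) of why offset/limit pagination needs a deterministic tie-breaker:

first_run = [{"uuid": "c3"}, {"uuid": "a1"}, {"uuid": "b2"}]   # order the database happened to return
second_run = [{"uuid": "a1"}, {"uuid": "b2"}, {"uuid": "c3"}]  # same rows, different order on the next query

def page(rows, offset, limit, by_uuid=False):
    ordered = sorted(rows, key=lambda r: r["uuid"]) if by_uuid else rows
    return [r["uuid"] for r in ordered[offset : offset + limit]]

# Without a tie-breaker, paging across two executions duplicates "c3" and drops "b2".
assert page(first_run, 0, 2) + page(second_run, 2, 2) == ["c3", "a1", "c3"]

# With the uuid tie-breaker, consecutive pages partition the result set deterministically.
assert page(first_run, 0, 2, by_uuid=True) + page(second_run, 2, 2, by_uuid=True) == ["a1", "b2", "c3"]
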
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/manager.py
@@ -79,6 +79,10 @@ class DeleteInput(graphene.InputObjectType):
]


class OrderInput(graphene.InputObjectType):
disable = graphene.Boolean(required=False)


@dataclass
class GraphqlMutations:
create: type[InfrahubMutation]
@@ -864,7 +868,7 @@ def generate_filters(
dict: A Dictionary containing all the filters with their name as the key and their Type as value
"""

filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int()}
filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int(), "order": OrderInput()}
default_filters: list[str] = list(filters.keys())

filters["ids"] = graphene.List(graphene.ID)
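With OrderInput registered alongside offset and limit in generate_filters, a client can opt out of a schema's default order_by per query. A hedged sketch of what such a request might look like; the TestCar kind and the count/edges/node selection are illustrative assumptions, not taken from this diff:

# Hypothetical GraphQL document, shown as a plain Python string purely for illustration.
DISABLE_ORDERING_QUERY = """
query {
  TestCar(limit: 50, order: { disable: true }) {
    count
    edges { node { id } }
  }
}
"""

Omitting order (or sending order: { disable: false }) keeps the ordering behaviour that existed before this change.
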
6 changes: 6 additions & 0 deletions backend/infrahub/graphql/models.py
@@ -0,0 +1,6 @@
from pydantic import BaseModel


# Corresponds to infrahub.graphql.manager.OrderInput
class OrderModel(BaseModel):
disable: bool | None = None
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/mutations/account.py
@@ -15,6 +15,7 @@
from infrahub.database import InfrahubDatabase, retry_db_transaction
from infrahub.exceptions import NodeNotFoundError, PermissionDeniedError

from ..models import OrderModel
from ..types import InfrahubObjectType

if TYPE_CHECKING:
@@ -112,7 +113,10 @@ async def delete_token(
token_id = str(data.get("id"))

results = await NodeManager.query(
schema=InternalAccountToken, filters={"account_ids": [account.id], "ids": [token_id]}, db=db
schema=InternalAccountToken,
filters={"account_ids": [account.id], "ids": [token_id]},
db=db,
order=OrderModel(disable=True),
)

if not results:
4 changes: 4 additions & 0 deletions backend/infrahub/graphql/resolvers/resolver.py
@@ -9,6 +9,7 @@
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.exceptions import NodeNotFoundError

from ..models import OrderModel
from ..parser import extract_selection
from ..permissions import get_permissions
from ..types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED
@@ -33,6 +34,7 @@ async def account_resolver(
filters={"ids": [context.account_session.account_id]},
fields=fields,
db=db,
order=OrderModel(disable=True),
)
if results:
account_profile = await results[0].to_graphql(db=db, fields=fields)
@@ -132,6 +134,7 @@ async def default_paginated_list_resolver(
info: GraphQLResolveInfo,
offset: int | None = None,
limit: int | None = None,
order: OrderModel | None = None,
partial_match: bool = False,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
Expand Down Expand Up @@ -173,6 +176,7 @@ async def default_paginated_list_resolver(
include_source=True,
include_owner=True,
partial_match=partial_match,
order=order,
)

if "count" in fields:
3 changes: 3 additions & 0 deletions backend/tests/query_benchmark/conftest.py
@@ -21,6 +21,7 @@ async def car_person_schema_root() -> SchemaRoot:
"default_filter": "name__value",
"display_labels": ["name__value", "color__value"],
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
@@ -58,6 +59,7 @@ async def car_person_schema_root() -> SchemaRoot:
"display_labels": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "height", "kind": "Number", "optional": True},
@@ -82,6 +84,7 @@
"namespace": "Test",
"default_filter": "name__value",
"display_labels": ["name__value"],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"attributes": [
2 changes: 0 additions & 2 deletions backend/tests/query_benchmark/test_diff_query.py
@@ -20,8 +20,6 @@

log = get_logger()

# pytestmark = pytest.mark.skip("Not relevant to test this currently.")


@pytest.mark.timeout(36000) # 10 hours
@pytest.mark.parametrize(
79 changes: 79 additions & 0 deletions backend/tests/query_benchmark/test_node_get_list.py
@@ -0,0 +1,79 @@
import inspect
from pathlib import Path

import pytest

from infrahub.core import registry
from infrahub.core.query.node import NodeGetListQuery
from infrahub.database.constants import Neo4jRuntime
from infrahub.graphql.models import OrderModel
from infrahub.log import get_logger
from tests.helpers.constants import NEO4J_ENTERPRISE_IMAGE
from tests.helpers.query_benchmark.benchmark_config import BenchmarkConfig
from tests.helpers.query_benchmark.car_person_generators import (
CarGenerator,
)
from tests.helpers.query_benchmark.data_generator import load_data_and_profile
from tests.query_benchmark.conftest import RESULTS_FOLDER
from tests.query_benchmark.utils import start_db_and_create_default_branch

log = get_logger()


@pytest.mark.timeout(36000) # 10 hours
@pytest.mark.parametrize(
"benchmark_config, ordering",
[
(
BenchmarkConfig(
neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False
),
False,
),
(
BenchmarkConfig(
neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False
),
True,
),
],
)
async def test_node_get_list_ordering(
benchmark_config, car_person_schema_root, graph_generator, increase_query_size_limit, ordering
):
# Initialization
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
neo4j_image=benchmark_config.neo4j_image,
load_indexes=benchmark_config.load_db_indexes,
)
registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name)

# Build function to profile
async def init_and_execute():
car_node_schema = registry.get_node_schema(name="TestCar", branch=default_branch.name)
query = await NodeGetListQuery.init(
db=db_profiling_queries,
schema=car_node_schema,
branch=default_branch,
order=None if ordering else OrderModel(disable=True),
)
res = await query.execute(db=db_profiling_queries)
print(f"{len(res.get_node_ids())=}")
return res

nb_cars = 10_000
cars_generator = CarGenerator(db=db_profiling_queries)
test_name = inspect.currentframe().f_code.co_name
module_name = Path(__file__).stem
graph_output_location = RESULTS_FOLDER / module_name / test_name

test_label = str(benchmark_config) + "_ordering_" + str(ordering)

await load_data_and_profile(
data_generator=cars_generator,
func_call=init_and_execute,
profile_frequency=1_000,
nb_elements=nb_cars,
graphs_output_location=graph_output_location,
test_label=test_label,
graph_generator=graph_generator,
)
@@ -25,8 +25,6 @@

log = get_logger()

# pytestmark = pytest.mark.skip("Not relevant to test this currently.")


async def benchmark_uniqueness_query(
query_request,