diff --git a/backend/infrahub/core/manager.py b/backend/infrahub/core/manager.py index 4758a59ba3..b3e59af753 100644 --- a/backend/infrahub/core/manager.py +++ b/backend/infrahub/core/manager.py @@ -26,6 +26,7 @@ from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema from infrahub.core.timestamp import Timestamp from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError +from infrahub.graphql.models import OrderModel if TYPE_CHECKING: from infrahub.core.branch import Branch @@ -141,6 +142,7 @@ async def query( account=..., partial_match: bool = ..., branch_agnostic: bool = ..., + order: OrderModel | None = ..., ) -> list[Any]: ... @overload @@ -161,6 +163,7 @@ async def query( account=..., partial_match: bool = ..., branch_agnostic: bool = ..., + order: OrderModel | None = ..., ) -> list[SchemaProtocol]: ... @classmethod @@ -180,6 +183,7 @@ async def query( account=None, partial_match: bool = False, branch_agnostic: bool = False, + order: OrderModel | None = None, ) -> list[Any]: """Query one or multiple nodes of a given type based on filter arguments. 
@@ -227,6 +231,7 @@ async def query( at=at, partial_match=partial_match, branch_agnostic=branch_agnostic, + order=order, ) await query.execute(db=db) node_ids = query.get_node_ids() @@ -295,6 +300,7 @@ async def count( at=at, partial_match=partial_match, branch_agnostic=branch_agnostic, + order=OrderModel(disable=True), ) return await query.count(db=db) @@ -657,6 +663,7 @@ async def get_one_by_default_filter( prefetch_relationships=prefetch_relationships, account=account, branch_agnostic=branch_agnostic, + order=OrderModel(disable=True), ) if len(items) > 1: @@ -820,6 +827,7 @@ async def get_one_by_hfid( prefetch_relationships=prefetch_relationships, account=account, branch_agnostic=branch_agnostic, + order=OrderModel(disable=True), ) if len(items) < 1: diff --git a/backend/infrahub/core/node/__init__.py b/backend/infrahub/core/node/__init__.py index 36d9b0c8bf..d1a8189c3e 100644 --- a/backend/infrahub/core/node/__init__.py +++ b/backend/infrahub/core/node/__init__.py @@ -18,6 +18,7 @@ from infrahub.types import ATTRIBUTE_TYPES from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME +from ...graphql.models import OrderModel from ..relationship import RelationshipManager from ..utils import update_relationships_to from .base import BaseNode, BaseNodeMeta, BaseNodeOptions @@ -609,7 +610,12 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) -> # Update the relationship to the branch itself query = await NodeGetListQuery.init( - db=db, schema=self._schema, filters={"id": self.id}, branch=self._branch, at=delete_at + db=db, + schema=self._schema, + filters={"id": self.id}, + branch=self._branch, + at=delete_at, + order=OrderModel(disable=True), ) await query.execute(db=db) result = query.get_result() diff --git a/backend/infrahub/core/query/node.py b/backend/infrahub/core/query/node.py index e16e64193b..84925da6bf 100644 --- a/backend/infrahub/core/query/node.py +++ b/backend/infrahub/core/query/node.py @@ -1,6 +1,7 @@ from __future__ 
import annotations from collections import defaultdict +from copy import copy from dataclasses import dataclass from dataclasses import field as dataclass_field from enum import Enum @@ -14,6 +15,7 @@ from infrahub.core.schema.attribute_schema import AttributeSchema from infrahub.core.utils import build_regex_attrs, extract_field_filters from infrahub.exceptions import QueryError +from infrahub.graphql.models import OrderModel if TYPE_CHECKING: from neo4j.graph import Node as Neo4jNode @@ -808,16 +810,25 @@ def __init__( schema: NodeSchema, filters: Optional[dict] = None, partial_match: bool = False, - ordering: bool = True, + order: OrderModel | None = None, **kwargs: Any, ) -> None: self.schema = schema self.filters = filters self.partial_match = partial_match self._variables_to_track = ["n", "rb"] - self.ordering = ordering self._validate_filters() + # Force disabling order when `limit` is 1 as it simplifies the query a lot. + if "limit" in kwargs and kwargs["limit"] == 1: + if order is None: + order = OrderModel(disable=True) + else: + order = copy(order) + order.disable = True + + self.order = order + super().__init__(**kwargs) def _validate_filters(self) -> None: @@ -878,7 +889,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: self.params["uuid"] = self.filters["id"] if not self.filters and not self.schema.order_by: use_simple = True - self.order_by = ["n.uuid"] + if self.order is None or self.order.disable is not True: + self.order_by = ["n.uuid"] if use_simple: if where_clause_elements: self.add_to_query(" AND " + " AND ".join(where_clause_elements)) @@ -893,9 +905,11 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: await self._add_node_filter_attributes( db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter ) - await self._add_node_order_attributes( - db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter - ) + should_order 
= self.schema.order_by and (self.order is None or self.order.disable is not True) + if should_order: + await self._add_node_order_attributes( + db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter + ) if use_profiles: await self._add_profiles_per_node_query(db=db, branch_filter=branch_filter) @@ -905,15 +919,15 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: await self._add_profile_rollups(field_attribute_requirements=field_attribute_requirements) self._add_final_filter(field_attribute_requirements=field_attribute_requirements) - self.order_by = [] - for far in field_attribute_requirements: - if not far.is_order: - continue - if far.supports_profile: - self.order_by.append(far.final_value_query_variable) - continue - self.order_by.append(far.node_value_query_variable) - self.order_by.append("n.uuid") + if should_order: + for far in field_attribute_requirements: + if not far.is_order: + continue + if far.supports_profile: + self.order_by.append(far.final_value_query_variable) + continue + self.order_by.append(far.node_value_query_variable) + self.order_by.append("n.uuid") async def _add_node_filter_attributes( self, @@ -1182,7 +1196,9 @@ def _get_field_requirements(self) -> list[FieldAttributeRequirement]: types=[FieldAttributeRequirementType.FILTER], ) index += 1 - if not self.schema.order_by or not self.ordering: + + disable_order = self.order.disable if self.order is not None else False + if not self.schema.order_by or disable_order: return list(field_requirements_map.values()) for order_by_path in self.schema.order_by: diff --git a/backend/infrahub/graphql/manager.py b/backend/infrahub/graphql/manager.py index 4a1b0a92b3..60c5103e5c 100644 --- a/backend/infrahub/graphql/manager.py +++ b/backend/infrahub/graphql/manager.py @@ -79,6 +79,10 @@ class DeleteInput(graphene.InputObjectType): ] +class OrderInput(graphene.InputObjectType): + disable = graphene.Boolean(required=False) + + @dataclass 
class GraphqlMutations: create: type[InfrahubMutation] @@ -864,7 +868,7 @@ def generate_filters( dict: A Dictionary containing all the filters with their name as the key and their Type as value """ - filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int()} + filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int(), "order": OrderInput()} default_filters: list[str] = list(filters.keys()) filters["ids"] = graphene.List(graphene.ID) diff --git a/backend/infrahub/graphql/models.py b/backend/infrahub/graphql/models.py new file mode 100644 index 0000000000..8346b4054e --- /dev/null +++ b/backend/infrahub/graphql/models.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +# Corresponds to infrahub.graphql.manager.OrderInput +class OrderModel(BaseModel): + disable: bool | None = None diff --git a/backend/infrahub/graphql/mutations/account.py b/backend/infrahub/graphql/mutations/account.py index ff97853eb1..e116fcb68f 100644 --- a/backend/infrahub/graphql/mutations/account.py +++ b/backend/infrahub/graphql/mutations/account.py @@ -15,6 +15,7 @@ from infrahub.database import InfrahubDatabase, retry_db_transaction from infrahub.exceptions import NodeNotFoundError, PermissionDeniedError +from ..models import OrderModel from ..types import InfrahubObjectType if TYPE_CHECKING: @@ -112,7 +113,10 @@ async def delete_token( token_id = str(data.get("id")) results = await NodeManager.query( - schema=InternalAccountToken, filters={"account_ids": [account.id], "ids": [token_id]}, db=db + schema=InternalAccountToken, + filters={"account_ids": [account.id], "ids": [token_id]}, + db=db, + order=OrderModel(disable=True), ) if not results: diff --git a/backend/infrahub/graphql/resolvers/resolver.py b/backend/infrahub/graphql/resolvers/resolver.py index 86e4b01671..895dd8a79e 100644 --- a/backend/infrahub/graphql/resolvers/resolver.py +++ b/backend/infrahub/graphql/resolvers/resolver.py @@ -9,6 +9,7 @@ from infrahub.core.query.node import 
NodeGetHierarchyQuery from infrahub.exceptions import NodeNotFoundError +from ..models import OrderModel from ..parser import extract_selection from ..permissions import get_permissions from ..types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED @@ -33,6 +34,7 @@ async def account_resolver( filters={"ids": [context.account_session.account_id]}, fields=fields, db=db, + order=OrderModel(disable=True), ) if results: account_profile = await results[0].to_graphql(db=db, fields=fields) @@ -132,6 +134,7 @@ async def default_paginated_list_resolver( info: GraphQLResolveInfo, offset: int | None = None, limit: int | None = None, + order: OrderModel | None = None, partial_match: bool = False, **kwargs: dict[str, Any], ) -> dict[str, Any]: @@ -173,6 +176,7 @@ async def default_paginated_list_resolver( include_source=True, include_owner=True, partial_match=partial_match, + order=order, ) if "count" in fields: diff --git a/backend/tests/helpers/test_app.py b/backend/tests/helpers/test_app.py index 975e7c8053..c30aa5b8c1 100644 --- a/backend/tests/helpers/test_app.py +++ b/backend/tests/helpers/test_app.py @@ -67,7 +67,7 @@ def bus_simulator(self, db: InfrahubDatabase) -> Generator[BusSimulator, None, N config.OVERRIDE.message_bus = original @pytest.fixture(scope="class", autouse=True) - async def workflow_local(self, prefect: Generator[str, None, None]) -> AsyncGenerator[WorkflowLocalExecution, None]: + async def workflow_local(self) -> AsyncGenerator[WorkflowLocalExecution, None]: original = config.OVERRIDE.workflow workflow = WorkflowLocalExecution() await setup_task_manager() diff --git a/backend/tests/query_benchmark/conftest.py b/backend/tests/query_benchmark/conftest.py index c2b01a439a..2146b5bd17 100644 --- a/backend/tests/query_benchmark/conftest.py +++ b/backend/tests/query_benchmark/conftest.py @@ -21,6 +21,7 @@ async def car_person_schema_root() -> SchemaRoot: "default_filter": "name__value", "display_labels": ["name__value", "color__value"], 
"uniqueness_constraints": [["name__value"]], + "order_by": ["name__value"], "branch": BranchSupportType.AWARE.value, "attributes": [ {"name": "name", "kind": "Text", "unique": True}, @@ -58,6 +59,7 @@ async def car_person_schema_root() -> SchemaRoot: "display_labels": ["name__value"], "branch": BranchSupportType.AWARE.value, "uniqueness_constraints": [["name__value"]], + "order_by": ["name__value"], "attributes": [ {"name": "name", "kind": "Text", "unique": True}, {"name": "height", "kind": "Number", "optional": True}, @@ -82,6 +84,7 @@ async def car_person_schema_root() -> SchemaRoot: "namespace": "Test", "default_filter": "name__value", "display_labels": ["name__value"], + "order_by": ["name__value"], "branch": BranchSupportType.AWARE.value, "uniqueness_constraints": [["name__value"]], "attributes": [ diff --git a/backend/tests/query_benchmark/test_diff_query.py b/backend/tests/query_benchmark/test_diff_query.py index a80d35ee1e..cec289709b 100644 --- a/backend/tests/query_benchmark/test_diff_query.py +++ b/backend/tests/query_benchmark/test_diff_query.py @@ -20,8 +20,6 @@ log = get_logger() -# pytestmark = pytest.mark.skip("Not relevant to test this currently.") - @pytest.mark.timeout(36000) # 10 hours @pytest.mark.parametrize( diff --git a/backend/tests/query_benchmark/test_node_get_list.py b/backend/tests/query_benchmark/test_node_get_list.py new file mode 100644 index 0000000000..1d9a5e15c1 --- /dev/null +++ b/backend/tests/query_benchmark/test_node_get_list.py @@ -0,0 +1,79 @@ +import inspect +from pathlib import Path + +import pytest + +from infrahub.core import registry +from infrahub.core.query.node import NodeGetListQuery +from infrahub.database.constants import Neo4jRuntime +from infrahub.log import get_logger +from tests.helpers.constants import NEO4J_ENTERPRISE_IMAGE +from tests.helpers.query_benchmark.benchmark_config import BenchmarkConfig +from tests.helpers.query_benchmark.car_person_generators import ( + CarGenerator, +) +from 
tests.helpers.query_benchmark.data_generator import load_data_and_profile +from tests.query_benchmark.conftest import RESULTS_FOLDER +from tests.query_benchmark.utils import start_db_and_create_default_branch + +log = get_logger() + + +@pytest.mark.timeout(36000) # 10 hours +@pytest.mark.parametrize( + "benchmark_config, ordering", + [ + ( + BenchmarkConfig( + neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False + ), + False, + ), + ( + BenchmarkConfig( + neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False + ), + True, + ), + ], +) +async def test_node_get_list_ordering( + benchmark_config, car_person_schema_root, graph_generator, increase_query_size_limit, ordering +): + # Initialization + db_profiling_queries, default_branch = await start_db_and_create_default_branch( + neo4j_image=benchmark_config.neo4j_image, + load_indexes=benchmark_config.load_db_indexes, + ) + registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name) + + # Build function to profile + async def init_and_execute(): + car_node_schema = registry.get_node_schema(name="TestCar", branch=default_branch.name) + query = await NodeGetListQuery.init( + db=db_profiling_queries, + schema=car_node_schema, + branch=default_branch, + ordering=ordering, + ) + res = await query.execute(db=db_profiling_queries) + print(f"{len(res.get_node_ids())=}") + return res + + nb_cars = 10_000 + cars_generator = CarGenerator(db=db_profiling_queries) + test_name = inspect.currentframe().f_code.co_name + module_name = Path(__file__).stem + graph_output_location = RESULTS_FOLDER / module_name / test_name + + test_label = str(benchmark_config) + "_ordering_" + str(ordering) + + await load_data_and_profile( + data_generator=cars_generator, + func_call=init_and_execute, + profile_frequency=1_000, + nb_elements=nb_cars, + graphs_output_location=graph_output_location, + test_label=test_label, + 
graph_generator=graph_generator, + ) diff --git a/backend/tests/query_benchmark/test_node_unique_attribute_constraint.py b/backend/tests/query_benchmark/test_node_unique_attribute_constraint.py index c1c17f4607..d70e9c536e 100644 --- a/backend/tests/query_benchmark/test_node_unique_attribute_constraint.py +++ b/backend/tests/query_benchmark/test_node_unique_attribute_constraint.py @@ -25,8 +25,6 @@ log = get_logger() -# pytestmark = pytest.mark.skip("Not relevant to test this currently.") - async def benchmark_uniqueness_query( query_request, diff --git a/backend/tests/unit/core/test_node_get_list_query.py b/backend/tests/unit/core/test_node_get_list_query.py index bdb40197e7..ddc552d5dc 100644 --- a/backend/tests/unit/core/test_node_get_list_query.py +++ b/backend/tests/unit/core/test_node_get_list_query.py @@ -16,6 +16,7 @@ from infrahub.core.schema import SchemaRoot from infrahub.core.schema.relationship_schema import RelationshipSchema from infrahub.database import InfrahubDatabase +from infrahub.graphql.models import OrderModel from tests.helpers.schema import WIDGET @@ -384,10 +385,10 @@ async def test_query_NodeGetListQuery_order_by_disabled( db=db, branch=branch, schema=schema, - ordering=False, + order=OrderModel(disable=True), ) await query.execute(db=db) - assert query.get_node_ids() == sorted([car_camry_main.id, car_yaris_main.id, car_accord_main.id, car_volt_main.id]) + assert set(query.get_node_ids()) == {car_camry_main.id, car_yaris_main.id, car_accord_main.id, car_volt_main.id} async def test_query_NodeGetListQuery_order_by_optional_relationship_nulls( diff --git a/backend/tests/unit/graphql/queries/test_order.py b/backend/tests/unit/graphql/queries/test_order.py new file mode 100644 index 0000000000..703ae1934f --- /dev/null +++ b/backend/tests/unit/graphql/queries/test_order.py @@ -0,0 +1,37 @@ +from infrahub.core.branch import Branch +from infrahub.core.node import Node +from infrahub.database import InfrahubDatabase +from tests.helpers.graphql 
import graphql_query +from tests.helpers.test_app import TestInfrahubApp + + +class TestQueryOrder(TestInfrahubApp): + async def test_query_default_order( + self, db: InfrahubDatabase, default_branch: Branch, register_core_models_schema, session_admin, client + ): + for i in range(5, 0, -1): + node = await Node.init(db=db, schema="BuiltinTag") + await node.new(db=db, name=f"tag-{i}") + await node.save(db=db) + + for disable_order in [True, False, None]: + variables = {"order": {"disable": disable_order}} if disable_order is not None else {"order": None} + + query = """ + query($order: OrderInput) { + BuiltinTag(order: $order) { + edges { + node { + name { value } + } + } + } + } + """ + + res = await graphql_query(query=query, db=db, branch=default_branch, variables=variables) + + node_names = [edge["node"]["name"]["value"] for edge in res.data["BuiltinTag"]["edges"]] + if disable_order is True: + node_names = sorted(node_names) + assert node_names == [f"tag-{i}" for i in range(1, 6)] diff --git a/backend/tests/unit/graphql/test_manager.py b/backend/tests/unit/graphql/test_manager.py index 6928414cad..1adc045336 100644 --- a/backend/tests/unit/graphql/test_manager.py +++ b/backend/tests/unit/graphql/test_manager.py @@ -181,6 +181,7 @@ async def test_generate_filters(db: InfrahubDatabase, default_branch: Branch, da expected_filters = [ "offset", "limit", + "order", "partial_match", "ids", "any__is_protected", diff --git a/changelog/add.5376.md b/changelog/add.5376.md new file mode 100644 index 0000000000..8be80da7c7 --- /dev/null +++ b/changelog/add.5376.md @@ -0,0 +1 @@ +Add `order` parameter to GraphQL queries so node ordering can be disabled to improve performance.