From 772f7472c2e923fd5d00545801cc834cf3c3f97b Mon Sep 17 00:00:00 2001 From: Lucas Guillermou Date: Fri, 3 Jan 2025 16:55:48 +0100 Subject: [PATCH] Add node_get_list benchmark --- backend/infrahub/core/manager.py | 4 + backend/infrahub/core/query/node.py | 8 +- backend/infrahub/graphql/manager.py | 6 +- .../infrahub/graphql/resolvers/resolver.py | 2 + backend/tests/helpers/test_app.py | 2 +- backend/tests/query_benchmark/conftest.py | 3 + .../query_benchmark/test_node_get_list.py | 81 +++++++++++++++++++ .../unit/core/test_node_get_list_query.py | 2 +- .../tests/unit/graphql/queries/test_order.py | 50 ++++++++++++ backend/tests/unit/graphql/test_manager.py | 1 + 10 files changed, 153 insertions(+), 6 deletions(-) create mode 100644 backend/tests/query_benchmark/test_node_get_list.py create mode 100644 backend/tests/unit/graphql/queries/test_order.py diff --git a/backend/infrahub/core/manager.py b/backend/infrahub/core/manager.py index 4758a59ba3..d4dc4c875c 100644 --- a/backend/infrahub/core/manager.py +++ b/backend/infrahub/core/manager.py @@ -141,6 +141,7 @@ async def query( account=..., partial_match: bool = ..., branch_agnostic: bool = ..., + order: dict[str, Any] | None = ..., ) -> list[Any]: ... @overload @@ -161,6 +162,7 @@ async def query( account=..., partial_match: bool = ..., branch_agnostic: bool = ..., + order: dict[str, Any] | None = ..., ) -> list[SchemaProtocol]: ... @classmethod @@ -180,6 +182,7 @@ async def query( account=None, partial_match: bool = False, branch_agnostic: bool = False, + order: dict[str, Any] | None = None, ) -> list[Any]: """Query one or multiple nodes of a given type based on filter arguments. @@ -227,6 +230,7 @@ async def query( at=at, partial_match=partial_match, branch_agnostic=branch_agnostic, + order=order, ) await query.execute(db=db) node_ids = query.get_node_ids() diff --git a/backend/infrahub/core/query/node.py b/backend/infrahub/core/query/node.py index e16e64193b..cf49316386 100644 --- a/backend/infrahub/core/query/node.py +++ b/backend/infrahub/core/query/node.py @@ -808,14 +808,14 @@ def __init__( schema: NodeSchema, filters: Optional[dict] = None, partial_match: bool = False, - ordering: bool = True, + order: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.schema = schema self.filters = filters self.partial_match = partial_match self._variables_to_track = ["n", "rb"] - self.ordering = ordering + self.order = order self._validate_filters() super().__init__(**kwargs) @@ -1182,7 +1182,9 @@ def _get_field_requirements(self) -> list[FieldAttributeRequirement]: types=[FieldAttributeRequirementType.FILTER], ) index += 1 - if not self.schema.order_by or not self.ordering: + + disable_order = self.order["disable"] if self.order is not None and "disable" in self.order else False + if not self.schema.order_by or disable_order: return list(field_requirements_map.values()) for order_by_path in self.schema.order_by: diff --git a/backend/infrahub/graphql/manager.py b/backend/infrahub/graphql/manager.py index 4a1b0a92b3..60c5103e5c 100644 --- a/backend/infrahub/graphql/manager.py +++ b/backend/infrahub/graphql/manager.py @@ -79,6 +79,10 @@ class DeleteInput(graphene.InputObjectType): ] +class OrderInput(graphene.InputObjectType): + disable = graphene.Boolean(required=False) + + @dataclass class GraphqlMutations: create: type[InfrahubMutation] @@ -864,7 +868,7 @@ def generate_filters( dict: A Dictionary containing all the filters with their name as the key and their Type as value """ - filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int()} + filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int(), "order": OrderInput()} default_filters: list[str] = list(filters.keys()) filters["ids"] = graphene.List(graphene.ID) diff --git a/backend/infrahub/graphql/resolvers/resolver.py b/backend/infrahub/graphql/resolvers/resolver.py index 86e4b01671..f91de8cc01 100644 --- a/backend/infrahub/graphql/resolvers/resolver.py +++ b/backend/infrahub/graphql/resolvers/resolver.py @@ -132,6 +132,7 @@ async def default_paginated_list_resolver( info: GraphQLResolveInfo, offset: int | None = None, limit: int | None = None, + order: dict[str, Any] | None = None, partial_match: bool = False, **kwargs: dict[str, Any], ) -> dict[str, Any]: @@ -173,6 +174,7 @@ async def default_paginated_list_resolver( include_source=True, include_owner=True, partial_match=partial_match, + order=order, ) if "count" in fields: diff --git a/backend/tests/helpers/test_app.py b/backend/tests/helpers/test_app.py index 975e7c8053..c30aa5b8c1 100644 --- a/backend/tests/helpers/test_app.py +++ b/backend/tests/helpers/test_app.py @@ -67,7 +67,7 @@ def bus_simulator(self, db: InfrahubDatabase) -> Generator[BusSimulator, None, N config.OVERRIDE.message_bus = original @pytest.fixture(scope="class", autouse=True) - async def workflow_local(self, prefect: Generator[str, None, None]) -> AsyncGenerator[WorkflowLocalExecution, None]: + async def workflow_local(self) -> AsyncGenerator[WorkflowLocalExecution, None]: original = config.OVERRIDE.workflow workflow = WorkflowLocalExecution() await setup_task_manager() diff --git a/backend/tests/query_benchmark/conftest.py b/backend/tests/query_benchmark/conftest.py index c2b01a439a..2146b5bd17 100644 --- a/backend/tests/query_benchmark/conftest.py +++ b/backend/tests/query_benchmark/conftest.py @@ -21,6 +21,7 @@ async def car_person_schema_root() -> SchemaRoot: "default_filter": "name__value", "display_labels": ["name__value", "color__value"], "uniqueness_constraints": [["name__value"]], + "order_by": ["name__value"], "branch": BranchSupportType.AWARE.value, "attributes": [ {"name": "name", "kind": "Text", "unique": True}, @@ -58,6 +59,7 @@ async def car_person_schema_root() -> SchemaRoot: "display_labels": ["name__value"], "branch": BranchSupportType.AWARE.value, "uniqueness_constraints": [["name__value"]], + "order_by": ["name__value"], "attributes": [ {"name": "name", "kind": "Text", "unique": True}, {"name": "height", "kind": "Number", "optional": True}, @@ -82,6 +84,7 @@ async def car_person_schema_root() -> SchemaRoot: "namespace": "Test", "default_filter": "name__value", "display_labels": ["name__value"], + "order_by": ["name__value"], "branch": BranchSupportType.AWARE.value, "uniqueness_constraints": [["name__value"]], "attributes": [ diff --git a/backend/tests/query_benchmark/test_node_get_list.py b/backend/tests/query_benchmark/test_node_get_list.py new file mode 100644 index 0000000000..fa762e31c9 --- /dev/null +++ b/backend/tests/query_benchmark/test_node_get_list.py @@ -0,0 +1,81 @@ +import inspect +from pathlib import Path + +import pytest + +from infrahub.core import registry +from infrahub.core.query.node import NodeGetListQuery +from infrahub.database.constants import Neo4jRuntime +from infrahub.log import get_logger +from tests.helpers.constants import NEO4J_ENTERPRISE_IMAGE +from tests.helpers.query_benchmark.benchmark_config import BenchmarkConfig +from tests.helpers.query_benchmark.car_person_generators import ( + CarGenerator, +) +from tests.helpers.query_benchmark.data_generator import load_data_and_profile +from tests.query_benchmark.conftest import RESULTS_FOLDER +from tests.query_benchmark.utils import start_db_and_create_default_branch + +log = get_logger() + +# pytestmark = pytest.mark.skip("Not relevant to test this currently.") + + +@pytest.mark.timeout(36000) # 10 hours +@pytest.mark.parametrize( + "benchmark_config, ordering", + [ + ( + BenchmarkConfig( + neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False + ), + False, + ), + ( + BenchmarkConfig( + neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False + ), + True, + ), + ], +) +async def test_node_get_list_ordering( + benchmark_config, car_person_schema_root, graph_generator, increase_query_size_limit, ordering +): + # Initialization + db_profiling_queries, default_branch = await start_db_and_create_default_branch( + neo4j_image=benchmark_config.neo4j_image, + load_indexes=benchmark_config.load_db_indexes, + ) + registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name) + + # Build function to profile + async def init_and_execute(): + car_node_schema = registry.get_node_schema(name="TestCar", branch=default_branch.name) + query = await NodeGetListQuery.init( + db=db_profiling_queries, + schema=car_node_schema, + branch=default_branch, + ordering=ordering, + ) + res = await query.execute(db=db_profiling_queries) + print(f"{len(res.get_node_ids())=}") + return res + + nb_cars = 10_000 + cars_generator = CarGenerator(db=db_profiling_queries) + test_name = inspect.currentframe().f_code.co_name + module_name = Path(__file__).stem + graph_output_location = RESULTS_FOLDER / module_name / test_name + + test_label = str(benchmark_config) + "_ordering_" + str(ordering) + + await load_data_and_profile( + data_generator=cars_generator, + func_call=init_and_execute, + profile_frequency=1_000, + nb_elements=nb_cars, + graphs_output_location=graph_output_location, + test_label=test_label, + graph_generator=graph_generator, + ) diff --git a/backend/tests/unit/core/test_node_get_list_query.py b/backend/tests/unit/core/test_node_get_list_query.py index bdb40197e7..27e4993d30 100644 --- a/backend/tests/unit/core/test_node_get_list_query.py +++ b/backend/tests/unit/core/test_node_get_list_query.py @@ -384,7 +384,7 @@ async def test_query_NodeGetListQuery_order_by_disabled( db=db, branch=branch, schema=schema, - ordering=False, + order={"disable": True}, ) await query.execute(db=db) assert query.get_node_ids() == sorted([car_camry_main.id, car_yaris_main.id, car_accord_main.id, car_volt_main.id]) diff --git a/backend/tests/unit/graphql/queries/test_order.py b/backend/tests/unit/graphql/queries/test_order.py new file mode 100644 index 0000000000..10fa644b50 --- /dev/null +++ b/backend/tests/unit/graphql/queries/test_order.py @@ -0,0 +1,50 @@ +from infrahub.core.branch import Branch +from infrahub.core.node import Node +from infrahub.database import InfrahubDatabase +from tests.helpers.graphql import graphql_query +from tests.helpers.test_app import TestInfrahubApp + + +class TestQueryOrder(TestInfrahubApp): + async def test_query_default_order( + self, db: InfrahubDatabase, default_branch: Branch, register_core_models_schema, session_admin, client + ): + for i in range(5, 0, -1): + node = await Node.init(db=db, schema="BuiltinTag") + await node.new(db=db, name=f"tag-{i}") + await node.save(db=db) + + for disable_order in [True, False, None]: + if disable_order is True: + order_param_query = "( order: { disable: true } )" + elif disable_order is False: + order_param_query = "( order: { disable: false } )" + else: + order_param_query = "" + + query = ( + """ + query { + BuiltinTag""" + + order_param_query + + """ { + edges { + node { + name { value } + } + } + } + } + """ + ) + + res = await graphql_query( + query=query, + db=db, + branch=default_branch, + ) + + node_names = [edge["node"]["name"]["value"] for edge in res.data["BuiltinTag"]["edges"]] + if disable_order is True: + node_names = sorted(node_names) + assert node_names == [f"tag-{i}" for i in range(1, 6)] diff --git a/backend/tests/unit/graphql/test_manager.py b/backend/tests/unit/graphql/test_manager.py index 6928414cad..1adc045336 100644 --- a/backend/tests/unit/graphql/test_manager.py +++ b/backend/tests/unit/graphql/test_manager.py @@ -181,6 +181,7 @@ async def test_generate_filters(db: InfrahubDatabase, default_branch: Branch, da expected_filters = [ "offset", "limit", + "order", "partial_match", "ids", "any__is_protected",