Skip to content

Commit

Permalink
Add node_get_list benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Jan 6, 2025
1 parent 97454d3 commit 772f747
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 6 deletions.
4 changes: 4 additions & 0 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: dict[str, Any] | None = ...,
) -> list[Any]: ...

@overload
Expand All @@ -161,6 +162,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: dict[str, Any] | None = ...,
) -> list[SchemaProtocol]: ...

@classmethod
Expand All @@ -180,6 +182,7 @@ async def query(
account=None,
partial_match: bool = False,
branch_agnostic: bool = False,
order: dict[str, Any] | None = None,
) -> list[Any]:
"""Query one or multiple nodes of a given type based on filter arguments.
Expand Down Expand Up @@ -227,6 +230,7 @@ async def query(
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=order,
)
await query.execute(db=db)
node_ids = query.get_node_ids()
Expand Down
8 changes: 5 additions & 3 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,14 +808,14 @@ def __init__(
schema: NodeSchema,
filters: Optional[dict] = None,
partial_match: bool = False,
ordering: bool = True,
order: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.schema = schema
self.filters = filters
self.partial_match = partial_match
self._variables_to_track = ["n", "rb"]
self.ordering = ordering
self.order = order
self._validate_filters()

super().__init__(**kwargs)
Expand Down Expand Up @@ -1182,7 +1182,9 @@ def _get_field_requirements(self) -> list[FieldAttributeRequirement]:
types=[FieldAttributeRequirementType.FILTER],
)
index += 1
if not self.schema.order_by or not self.ordering:

disable_order = self.order["disable"] if self.order is not None and "disable" in self.order else False
if not self.schema.order_by or disable_order:
return list(field_requirements_map.values())

for order_by_path in self.schema.order_by:
Expand Down
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class DeleteInput(graphene.InputObjectType):
]


class OrderInput(graphene.InputObjectType):
disable = graphene.Boolean(required=False)


@dataclass
class GraphqlMutations:
create: type[InfrahubMutation]
Expand Down Expand Up @@ -864,7 +868,7 @@ def generate_filters(
dict: A Dictionary containing all the filters with their name as the key and their Type as value
"""

filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int()}
filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int(), "order": OrderInput()}
default_filters: list[str] = list(filters.keys())

filters["ids"] = graphene.List(graphene.ID)
Expand Down
2 changes: 2 additions & 0 deletions backend/infrahub/graphql/resolvers/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ async def default_paginated_list_resolver(
info: GraphQLResolveInfo,
offset: int | None = None,
limit: int | None = None,
order: dict[str, Any] | None = None,
partial_match: bool = False,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
Expand Down Expand Up @@ -173,6 +174,7 @@ async def default_paginated_list_resolver(
include_source=True,
include_owner=True,
partial_match=partial_match,
order=order,
)

if "count" in fields:
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/helpers/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def bus_simulator(self, db: InfrahubDatabase) -> Generator[BusSimulator, None, N
config.OVERRIDE.message_bus = original

@pytest.fixture(scope="class", autouse=True)
async def workflow_local(self, prefect: Generator[str, None, None]) -> AsyncGenerator[WorkflowLocalExecution, None]:
async def workflow_local(self) -> AsyncGenerator[WorkflowLocalExecution, None]:
original = config.OVERRIDE.workflow
workflow = WorkflowLocalExecution()
await setup_task_manager()
Expand Down
3 changes: 3 additions & 0 deletions backend/tests/query_benchmark/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def car_person_schema_root() -> SchemaRoot:
"default_filter": "name__value",
"display_labels": ["name__value", "color__value"],
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
Expand Down Expand Up @@ -58,6 +59,7 @@ async def car_person_schema_root() -> SchemaRoot:
"display_labels": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "height", "kind": "Number", "optional": True},
Expand All @@ -82,6 +84,7 @@ async def car_person_schema_root() -> SchemaRoot:
"namespace": "Test",
"default_filter": "name__value",
"display_labels": ["name__value"],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"attributes": [
Expand Down
81 changes: 81 additions & 0 deletions backend/tests/query_benchmark/test_node_get_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import inspect
from pathlib import Path

import pytest

from infrahub.core import registry
from infrahub.core.query.node import NodeGetListQuery
from infrahub.database.constants import Neo4jRuntime
from infrahub.log import get_logger
from tests.helpers.constants import NEO4J_ENTERPRISE_IMAGE
from tests.helpers.query_benchmark.benchmark_config import BenchmarkConfig
from tests.helpers.query_benchmark.car_person_generators import (
CarGenerator,
)
from tests.helpers.query_benchmark.data_generator import load_data_and_profile
from tests.query_benchmark.conftest import RESULTS_FOLDER
from tests.query_benchmark.utils import start_db_and_create_default_branch

log = get_logger()

# pytestmark = pytest.mark.skip("Not relevant to test this currently.")


@pytest.mark.timeout(36000) # 10 hours
@pytest.mark.parametrize(
"benchmark_config, ordering",
[
(
BenchmarkConfig(
neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False
),
False,
),
(
BenchmarkConfig(
neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False
),
True,
),
],
)
async def test_node_get_list_ordering(
benchmark_config, car_person_schema_root, graph_generator, increase_query_size_limit, ordering
):
# Initialization
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
neo4j_image=benchmark_config.neo4j_image,
load_indexes=benchmark_config.load_db_indexes,
)
registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name)

# Build function to profile
async def init_and_execute():
car_node_schema = registry.get_node_schema(name="TestCar", branch=default_branch.name)
query = await NodeGetListQuery.init(
db=db_profiling_queries,
schema=car_node_schema,
branch=default_branch,
ordering=ordering,
)
res = await query.execute(db=db_profiling_queries)
print(f"{len(res.get_node_ids())=}")
return res

nb_cars = 10_000
cars_generator = CarGenerator(db=db_profiling_queries)
test_name = inspect.currentframe().f_code.co_name
module_name = Path(__file__).stem
graph_output_location = RESULTS_FOLDER / module_name / test_name

test_label = str(benchmark_config) + "_ordering_" + str(ordering)

await load_data_and_profile(
data_generator=cars_generator,
func_call=init_and_execute,
profile_frequency=1_000,
nb_elements=nb_cars,
graphs_output_location=graph_output_location,
test_label=test_label,
graph_generator=graph_generator,
)
2 changes: 1 addition & 1 deletion backend/tests/unit/core/test_node_get_list_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ async def test_query_NodeGetListQuery_order_by_disabled(
db=db,
branch=branch,
schema=schema,
ordering=False,
order={"disable": True},
)
await query.execute(db=db)
assert query.get_node_ids() == sorted([car_camry_main.id, car_yaris_main.id, car_accord_main.id, car_volt_main.id])
Expand Down
50 changes: 50 additions & 0 deletions backend/tests/unit/graphql/queries/test_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from infrahub.core.branch import Branch
from infrahub.core.node import Node
from infrahub.database import InfrahubDatabase
from tests.helpers.graphql import graphql_query
from tests.helpers.test_app import TestInfrahubApp


class TestQueryOrder(TestInfrahubApp):
    """Integration tests for the GraphQL `order` input argument on node list queries."""

    async def test_query_default_order(
        self, db: InfrahubDatabase, default_branch: Branch, register_core_models_schema, session_admin, client
    ):
        """Query BuiltinTag with `order` explicitly disabled, explicitly enabled, and omitted.

        Tags are created in descending name order so that, when ordering is active,
        results coming back in ascending name order demonstrates server-side ordering
        rather than insertion order.
        """
        # Create tag-5 .. tag-1 so insertion order differs from sorted-by-name order.
        for i in range(5, 0, -1):
            node = await Node.init(db=db, schema="BuiltinTag")
            await node.new(db=db, name=f"tag-{i}")
            await node.save(db=db)

        # True -> ordering explicitly disabled, False -> explicitly enabled,
        # None -> `order` argument omitted entirely.
        for disable_order in [True, False, None]:
            if disable_order is True:
                order_param_query = "( order: { disable: true } )"
            elif disable_order is False:
                order_param_query = "( order: { disable: false } )"
            else:
                order_param_query = ""

            query = (
                """
            query {
                BuiltinTag"""
                + order_param_query
                + """ {
                edges {
                    node {
                        name { value }
                    }
                }
            }
            }
            """
            )

            res = await graphql_query(
                query=query,
                db=db,
                branch=default_branch,
            )

            node_names = [edge["node"]["name"]["value"] for edge in res.data["BuiltinTag"]["edges"]]
            # With ordering disabled the result order is unspecified, so sort first
            # and only verify the returned contents; otherwise assert the server
            # returned the tags already sorted by name.
            if disable_order is True:
                node_names = sorted(node_names)
            assert node_names == [f"tag-{i}" for i in range(1, 6)]
1 change: 1 addition & 0 deletions backend/tests/unit/graphql/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ async def test_generate_filters(db: InfrahubDatabase, default_branch: Branch, da
expected_filters = [
"offset",
"limit",
"order",
"partial_match",
"ids",
"any__is_protected",
Expand Down

0 comments on commit 772f747

Please sign in to comment.