Skip to content

Commit

Permalink
Add node_get_list benchmark
Browse files — browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Jan 6, 2025
1 parent 97454d3 commit 5eb2d19
Show file tree
Hide file tree
Showing 16 changed files with 192 additions and 26 deletions.
8 changes: 8 additions & 0 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
from infrahub.graphql.models import OrderModel

if TYPE_CHECKING:
from infrahub.core.branch import Branch
Expand Down Expand Up @@ -141,6 +142,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: OrderModel | None = ...,
) -> list[Any]: ...

@overload
Expand All @@ -161,6 +163,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: OrderModel | None = ...,
) -> list[SchemaProtocol]: ...

@classmethod
Expand All @@ -180,6 +183,7 @@ async def query(
account=None,
partial_match: bool = False,
branch_agnostic: bool = False,
order: OrderModel | None = None,
) -> list[Any]:
"""Query one or multiple nodes of a given type based on filter arguments.
Expand Down Expand Up @@ -227,6 +231,7 @@ async def query(
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=order,
)
await query.execute(db=db)
node_ids = query.get_node_ids()
Expand Down Expand Up @@ -295,6 +300,7 @@ async def count(
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)
return await query.count(db=db)

Expand Down Expand Up @@ -657,6 +663,7 @@ async def get_one_by_default_filter(
prefetch_relationships=prefetch_relationships,
account=account,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)

if len(items) > 1:
Expand Down Expand Up @@ -820,6 +827,7 @@ async def get_one_by_hfid(
prefetch_relationships=prefetch_relationships,
account=account,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)

if len(items) < 1:
Expand Down
8 changes: 7 additions & 1 deletion backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from infrahub.types import ATTRIBUTE_TYPES

from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
from ...graphql.models import OrderModel
from ..relationship import RelationshipManager
from ..utils import update_relationships_to
from .base import BaseNode, BaseNodeMeta, BaseNodeOptions
Expand Down Expand Up @@ -609,7 +610,12 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->

# Update the relationship to the branch itself
query = await NodeGetListQuery.init(
db=db, schema=self._schema, filters={"id": self.id}, branch=self._branch, at=delete_at
db=db,
schema=self._schema,
filters={"id": self.id},
branch=self._branch,
at=delete_at,
order=OrderModel(disable=True),
)
await query.execute(db=db)
result = query.get_result()
Expand Down
48 changes: 32 additions & 16 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from copy import copy
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import Enum
Expand All @@ -14,6 +15,7 @@
from infrahub.core.schema.attribute_schema import AttributeSchema
from infrahub.core.utils import build_regex_attrs, extract_field_filters
from infrahub.exceptions import QueryError
from infrahub.graphql.models import OrderModel

if TYPE_CHECKING:
from neo4j.graph import Node as Neo4jNode
Expand Down Expand Up @@ -808,16 +810,25 @@ def __init__(
schema: NodeSchema,
filters: Optional[dict] = None,
partial_match: bool = False,
ordering: bool = True,
order: OrderModel | None = None,
**kwargs: Any,
) -> None:
self.schema = schema
self.filters = filters
self.partial_match = partial_match
self._variables_to_track = ["n", "rb"]
self.ordering = ordering
self._validate_filters()

# Force disabling order when `limit` is 1 as it simplifies the query a lot.
if "limit" in kwargs and kwargs["limit"] == 1:
if order is None:
order = OrderModel(disable=True)
else:
order = copy(order)
order.disable = True

self.order = order

super().__init__(**kwargs)

def _validate_filters(self) -> None:
Expand Down Expand Up @@ -878,7 +889,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params["uuid"] = self.filters["id"]
if not self.filters and not self.schema.order_by:
use_simple = True
self.order_by = ["n.uuid"]
if self.order is None or self.order.disable is not True:
self.order_by = ["n.uuid"]
if use_simple:
if where_clause_elements:
self.add_to_query(" AND " + " AND ".join(where_clause_elements))
Expand All @@ -893,9 +905,11 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
await self._add_node_filter_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)
await self._add_node_order_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)
should_order = self.schema.order_by and (self.order is None or self.order.disable is not True)
if should_order:
await self._add_node_order_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)

if use_profiles:
await self._add_profiles_per_node_query(db=db, branch_filter=branch_filter)
Expand All @@ -905,15 +919,15 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
await self._add_profile_rollups(field_attribute_requirements=field_attribute_requirements)

self._add_final_filter(field_attribute_requirements=field_attribute_requirements)
self.order_by = []
for far in field_attribute_requirements:
if not far.is_order:
continue
if far.supports_profile:
self.order_by.append(far.final_value_query_variable)
continue
self.order_by.append(far.node_value_query_variable)
self.order_by.append("n.uuid")
if should_order:
for far in field_attribute_requirements:
if not far.is_order:
continue
if far.supports_profile:
self.order_by.append(far.final_value_query_variable)
continue
self.order_by.append(far.node_value_query_variable)
self.order_by.append("n.uuid")

async def _add_node_filter_attributes(
self,
Expand Down Expand Up @@ -1182,7 +1196,9 @@ def _get_field_requirements(self) -> list[FieldAttributeRequirement]:
types=[FieldAttributeRequirementType.FILTER],
)
index += 1
if not self.schema.order_by or not self.ordering:

disable_order = self.order.disable if self.order is not None else False
if not self.schema.order_by or disable_order:
return list(field_requirements_map.values())

for order_by_path in self.schema.order_by:
Expand Down
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class DeleteInput(graphene.InputObjectType):
]


class OrderInput(graphene.InputObjectType):
    """GraphQL input type controlling result ordering of paginated node queries.

    Mirrored by the pydantic model `infrahub.graphql.models.OrderModel`;
    keep the two definitions in sync.
    """

    # When true, the resolver skips ORDER BY handling for the query.
    disable = graphene.Boolean(required=False)


@dataclass
class GraphqlMutations:
create: type[InfrahubMutation]
Expand Down Expand Up @@ -864,7 +868,7 @@ def generate_filters(
dict: A Dictionary containing all the filters with their name as the key and their Type as value
"""

filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int()}
filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int(), "order": OrderInput()}
default_filters: list[str] = list(filters.keys())

filters["ids"] = graphene.List(graphene.ID)
Expand Down
6 changes: 6 additions & 0 deletions backend/infrahub/graphql/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


# Corresponds to infrahub.graphql.manager.OrderInput
class OrderModel(BaseModel):
    """Internal counterpart of the GraphQL `OrderInput` type.

    Passed down to `NodeManager.query` / `NodeGetListQuery` to control result
    ordering; `disable=True` is used on paths where a deterministic order is
    unnecessary (e.g. count and single-node lookups).
    """

    # True disables ordering entirely; None/False keep the schema's default
    # `order_by` behavior (callers check `disable is not True`).
    disable: bool | None = None
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/mutations/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from infrahub.database import InfrahubDatabase, retry_db_transaction
from infrahub.exceptions import NodeNotFoundError, PermissionDeniedError

from ..models import OrderModel
from ..types import InfrahubObjectType

if TYPE_CHECKING:
Expand Down Expand Up @@ -112,7 +113,10 @@ async def delete_token(
token_id = str(data.get("id"))

results = await NodeManager.query(
schema=InternalAccountToken, filters={"account_ids": [account.id], "ids": [token_id]}, db=db
schema=InternalAccountToken,
filters={"account_ids": [account.id], "ids": [token_id]},
db=db,
order=OrderModel(disable=True),
)

if not results:
Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/graphql/resolvers/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.exceptions import NodeNotFoundError

from ..models import OrderModel
from ..parser import extract_selection
from ..permissions import get_permissions
from ..types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED
Expand All @@ -33,6 +34,7 @@ async def account_resolver(
filters={"ids": [context.account_session.account_id]},
fields=fields,
db=db,
order=OrderModel(disable=True),
)
if results:
account_profile = await results[0].to_graphql(db=db, fields=fields)
Expand Down Expand Up @@ -132,6 +134,7 @@ async def default_paginated_list_resolver(
info: GraphQLResolveInfo,
offset: int | None = None,
limit: int | None = None,
order: OrderModel | None = None,
partial_match: bool = False,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
Expand Down Expand Up @@ -173,6 +176,7 @@ async def default_paginated_list_resolver(
include_source=True,
include_owner=True,
partial_match=partial_match,
order=order,
)

if "count" in fields:
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/helpers/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def bus_simulator(self, db: InfrahubDatabase) -> Generator[BusSimulator, None, N
config.OVERRIDE.message_bus = original

@pytest.fixture(scope="class", autouse=True)
async def workflow_local(self, prefect: Generator[str, None, None]) -> AsyncGenerator[WorkflowLocalExecution, None]:
async def workflow_local(self) -> AsyncGenerator[WorkflowLocalExecution, None]:
original = config.OVERRIDE.workflow
workflow = WorkflowLocalExecution()
await setup_task_manager()
Expand Down
3 changes: 3 additions & 0 deletions backend/tests/query_benchmark/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def car_person_schema_root() -> SchemaRoot:
"default_filter": "name__value",
"display_labels": ["name__value", "color__value"],
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
Expand Down Expand Up @@ -58,6 +59,7 @@ async def car_person_schema_root() -> SchemaRoot:
"display_labels": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "height", "kind": "Number", "optional": True},
Expand All @@ -82,6 +84,7 @@ async def car_person_schema_root() -> SchemaRoot:
"namespace": "Test",
"default_filter": "name__value",
"display_labels": ["name__value"],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"attributes": [
Expand Down
2 changes: 0 additions & 2 deletions backend/tests/query_benchmark/test_diff_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

log = get_logger()

# pytestmark = pytest.mark.skip("Not relevant to test this currently.")


@pytest.mark.timeout(36000) # 10 hours
@pytest.mark.parametrize(
Expand Down
79 changes: 79 additions & 0 deletions backend/tests/query_benchmark/test_node_get_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import inspect
from pathlib import Path

import pytest

from infrahub.core import registry
from infrahub.core.query.node import NodeGetListQuery
from infrahub.database.constants import Neo4jRuntime
from infrahub.graphql.models import OrderModel
from infrahub.log import get_logger
from tests.helpers.constants import NEO4J_ENTERPRISE_IMAGE
from tests.helpers.query_benchmark.benchmark_config import BenchmarkConfig
from tests.helpers.query_benchmark.car_person_generators import (
    CarGenerator,
)
from tests.helpers.query_benchmark.data_generator import load_data_and_profile
from tests.query_benchmark.conftest import RESULTS_FOLDER
from tests.query_benchmark.utils import start_db_and_create_default_branch

log = get_logger()


@pytest.mark.timeout(36000) # 10 hours
@pytest.mark.parametrize(
"benchmark_config, ordering",
[
(
BenchmarkConfig(
neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False
),
False,
),
(
BenchmarkConfig(
neo4j_runtime=Neo4jRuntime.PARALLEL, neo4j_image=NEO4J_ENTERPRISE_IMAGE, load_db_indexes=False
),
True,
),
],
)
async def test_node_get_list_ordering(
benchmark_config, car_person_schema_root, graph_generator, increase_query_size_limit, ordering
):
# Initialization
db_profiling_queries, default_branch = await start_db_and_create_default_branch(
neo4j_image=benchmark_config.neo4j_image,
load_indexes=benchmark_config.load_db_indexes,
)
registry.schema.register_schema(schema=car_person_schema_root, branch=default_branch.name)

# Build function to profile
async def init_and_execute():
car_node_schema = registry.get_node_schema(name="TestCar", branch=default_branch.name)
query = await NodeGetListQuery.init(
db=db_profiling_queries,
schema=car_node_schema,
branch=default_branch,
ordering=ordering,
)
res = await query.execute(db=db_profiling_queries)
print(f"{len(res.get_node_ids())=}")
return res

nb_cars = 10_000
cars_generator = CarGenerator(db=db_profiling_queries)
test_name = inspect.currentframe().f_code.co_name
module_name = Path(__file__).stem
graph_output_location = RESULTS_FOLDER / module_name / test_name

test_label = str(benchmark_config) + "_ordering_" + str(ordering)

await load_data_and_profile(
data_generator=cars_generator,
func_call=init_and_execute,
profile_frequency=1_000,
nb_elements=nb_cars,
graphs_output_location=graph_output_location,
test_label=test_label,
graph_generator=graph_generator,
)
Loading

0 comments on commit 5eb2d19

Please sign in to comment.