Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable default ordering in GraphQL queries #5380

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
from infrahub.graphql.models import OrderModel

if TYPE_CHECKING:
from infrahub.core.branch import Branch
Expand Down Expand Up @@ -141,6 +142,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: OrderModel | None = ...,
) -> list[Any]: ...

@overload
Expand All @@ -161,6 +163,7 @@ async def query(
account=...,
partial_match: bool = ...,
branch_agnostic: bool = ...,
order: OrderModel | None = ...,
) -> list[SchemaProtocol]: ...

@classmethod
Expand All @@ -180,6 +183,7 @@ async def query(
account=None,
partial_match: bool = False,
branch_agnostic: bool = False,
order: OrderModel | None = None,
) -> list[Any]:
"""Query one or multiple nodes of a given type based on filter arguments.
Expand Down Expand Up @@ -227,6 +231,7 @@ async def query(
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=order,
)
await query.execute(db=db)
node_ids = query.get_node_ids()
Expand Down Expand Up @@ -295,6 +300,7 @@ async def count(
at=at,
partial_match=partial_match,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)
return await query.count(db=db)

Expand Down Expand Up @@ -657,6 +663,7 @@ async def get_one_by_default_filter(
prefetch_relationships=prefetch_relationships,
account=account,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)

if len(items) > 1:
Expand Down Expand Up @@ -820,6 +827,7 @@ async def get_one_by_hfid(
prefetch_relationships=prefetch_relationships,
account=account,
branch_agnostic=branch_agnostic,
order=OrderModel(disable=True),
)

if len(items) < 1:
Expand Down
8 changes: 7 additions & 1 deletion backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from infrahub.types import ATTRIBUTE_TYPES

from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
from ...graphql.models import OrderModel
from ..relationship import RelationshipManager
from ..utils import update_relationships_to
from .base import BaseNode, BaseNodeMeta, BaseNodeOptions
Expand Down Expand Up @@ -609,7 +610,12 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->

# Update the relationship to the branch itself
query = await NodeGetListQuery.init(
db=db, schema=self._schema, filters={"id": self.id}, branch=self._branch, at=delete_at
db=db,
schema=self._schema,
filters={"id": self.id},
branch=self._branch,
at=delete_at,
order=OrderModel(disable=True),
)
await query.execute(db=db)
result = query.get_result()
Expand Down
69 changes: 43 additions & 26 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from copy import copy
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import Enum
Expand All @@ -18,6 +19,7 @@
from infrahub.core.schema.attribute_schema import AttributeSchema
from infrahub.core.utils import build_regex_attrs, extract_field_filters
from infrahub.exceptions import QueryError
from infrahub.graphql.models import OrderModel

if TYPE_CHECKING:
from neo4j.graph import Node as Neo4jNode
Expand Down Expand Up @@ -808,14 +810,29 @@ class NodeGetListQuery(Query):
type = QueryType.READ

def __init__(
self, schema: NodeSchema, filters: Optional[dict] = None, partial_match: bool = False, **kwargs: Any
self,
schema: NodeSchema,
filters: Optional[dict] = None,
partial_match: bool = False,
order: OrderModel | None = None,
**kwargs: Any,
) -> None:
self.schema = schema
self.filters = filters
self.partial_match = partial_match
self._variables_to_track = ["n", "rb"]
self._validate_filters()

# Force disabling order when `limit` is 1 as it simplifies the query a lot.
if "limit" in kwargs and kwargs["limit"] == 1:
if order is None:
order = OrderModel(disable=True)
else:
order = copy(order)
order.disable = True

self.order = order

super().__init__(**kwargs)

@property
Expand Down Expand Up @@ -860,7 +877,6 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.order_by = []

self.return_labels = ["n.uuid", "rb.branch", f"{db.get_id_function_name()}(rb) as rb_id"]
where_clause_elements = []

branch_filter, branch_params = self.branch.get_query_filter_path(
at=self.at, branch_agnostic=self.branch_agnostic
Expand Down Expand Up @@ -894,32 +910,41 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
""" % {"branch_filter": branch_filter, "node_kind": self.schema.kind}
self.add_to_query(topquery)

use_simple = False
if self.has_filter_by_id and self.filters:
use_simple = True
where_clause_elements.append("n.uuid = $uuid")
self.params["uuid"] = self.filters["id"]
elif not self.has_filters and not self.schema.order_by:
use_simple = True
self.order_by = ["n.uuid"]
self.add_to_query(" AND n.uuid = $uuid")
return

if use_simple:
if where_clause_elements:
self.add_to_query(" AND " + " AND ".join(where_clause_elements))
disable_order = not self.schema.order_by or (self.order is not None and self.order.disable)
if not self.has_filters and disable_order:
# Always order by uuid to guarantee pagination, see https://github.com/opsmill/infrahub/pull/4704.
self.order_by = ["n.uuid"]
return

if self.filters and "ids" in self.filters:
self.add_to_query("AND n.uuid IN $node_ids")
self.params["node_ids"] = self.filters["ids"]

field_attribute_requirements = self._get_field_requirements()
field_attribute_requirements = self._get_field_requirements(disable_order=disable_order)
use_profiles = any(far for far in field_attribute_requirements if far.supports_profile)
await self._add_node_filter_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)
await self._add_node_order_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)

if not disable_order:
await self._add_node_order_attributes(
db=db, field_attribute_requirements=field_attribute_requirements, branch_filter=branch_filter
)
for far in field_attribute_requirements:
if not far.is_order:
continue
if far.supports_profile:
self.order_by.append(far.final_value_query_variable)
continue
self.order_by.append(far.node_value_query_variable)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I let this block within should_order branch condition as this block seems only related to ordering.


# Always order by uuid to guarantee pagination, see https://github.com/opsmill/infrahub/pull/4704.
self.order_by.append("n.uuid")

if use_profiles:
await self._add_profiles_per_node_query(db=db, branch_filter=branch_filter)
Expand All @@ -929,15 +954,6 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
await self._add_profile_rollups(field_attribute_requirements=field_attribute_requirements)

self._add_final_filter(field_attribute_requirements=field_attribute_requirements)
self.order_by = []
for far in field_attribute_requirements:
if not far.is_order:
continue
if far.supports_profile:
self.order_by.append(far.final_value_query_variable)
continue
self.order_by.append(far.node_value_query_variable)
self.order_by.append("n.uuid")

async def _add_node_filter_attributes(
self,
Expand Down Expand Up @@ -1184,7 +1200,7 @@ def _add_final_filter(self, field_attribute_requirements: list[FieldAttributeReq
where_str = "WHERE " + " AND ".join(where_parts)
self.add_to_query(where_str)

def _get_field_requirements(self) -> list[FieldAttributeRequirement]:
def _get_field_requirements(self, disable_order: bool) -> list[FieldAttributeRequirement]:
internal_filters = ["any", "attribute", "relationship"]
field_requirements_map: dict[tuple[str, str], FieldAttributeRequirement] = {}
index = 1
Expand All @@ -1206,7 +1222,8 @@ def _get_field_requirements(self) -> list[FieldAttributeRequirement]:
types=[FieldAttributeRequirementType.FILTER],
)
index += 1
if not self.schema.order_by:

if disable_order:
return list(field_requirements_map.values())

for order_by_path in self.schema.order_by:
Expand Down
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class DeleteInput(graphene.InputObjectType):
]


class OrderInput(graphene.InputObjectType):
disable = graphene.Boolean(required=False)


@dataclass
class GraphqlMutations:
create: type[InfrahubMutation]
Expand Down Expand Up @@ -864,7 +868,7 @@ def generate_filters(
dict: A Dictionary containing all the filters with their name as the key and their Type as value
"""

filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int()}
filters: dict[str, Any] = {"offset": graphene.Int(), "limit": graphene.Int(), "order": OrderInput()}
default_filters: list[str] = list(filters.keys())

filters["ids"] = graphene.List(graphene.ID)
Expand Down
6 changes: 6 additions & 0 deletions backend/infrahub/graphql/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


# Corresponds to infrahub.graphql.manager.OrderInput
class OrderModel(BaseModel):
disable: bool | None = None
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/mutations/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from infrahub.database import InfrahubDatabase, retry_db_transaction
from infrahub.exceptions import NodeNotFoundError, PermissionDeniedError

from ..models import OrderModel
from ..types import InfrahubObjectType

if TYPE_CHECKING:
Expand Down Expand Up @@ -112,7 +113,10 @@ async def delete_token(
token_id = str(data.get("id"))

results = await NodeManager.query(
schema=InternalAccountToken, filters={"account_ids": [account.id], "ids": [token_id]}, db=db
schema=InternalAccountToken,
filters={"account_ids": [account.id], "ids": [token_id]},
db=db,
order=OrderModel(disable=True),
)

if not results:
Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/graphql/resolvers/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.exceptions import NodeNotFoundError

from ..models import OrderModel
from ..parser import extract_selection
from ..permissions import get_permissions
from ..types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED
Expand All @@ -33,6 +34,7 @@ async def account_resolver(
filters={"ids": [context.account_session.account_id]},
fields=fields,
db=db,
order=OrderModel(disable=True),
)
if results:
account_profile = await results[0].to_graphql(db=db, fields=fields)
Expand Down Expand Up @@ -132,6 +134,7 @@ async def default_paginated_list_resolver(
info: GraphQLResolveInfo,
offset: int | None = None,
limit: int | None = None,
order: OrderModel | None = None,
partial_match: bool = False,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
Expand Down Expand Up @@ -173,6 +176,7 @@ async def default_paginated_list_resolver(
include_source=True,
include_owner=True,
partial_match=partial_match,
order=order,
)

if "count" in fields:
Expand Down
3 changes: 3 additions & 0 deletions backend/tests/query_benchmark/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def car_person_schema_root() -> SchemaRoot:
"default_filter": "name__value",
"display_labels": ["name__value", "color__value"],
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
Expand Down Expand Up @@ -58,6 +59,7 @@ async def car_person_schema_root() -> SchemaRoot:
"display_labels": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"order_by": ["name__value"],
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "height", "kind": "Number", "optional": True},
Expand All @@ -82,6 +84,7 @@ async def car_person_schema_root() -> SchemaRoot:
"namespace": "Test",
"default_filter": "name__value",
"display_labels": ["name__value"],
"order_by": ["name__value"],
"branch": BranchSupportType.AWARE.value,
"uniqueness_constraints": [["name__value"]],
"attributes": [
Expand Down
2 changes: 0 additions & 2 deletions backend/tests/query_benchmark/test_diff_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

log = get_logger()

# pytestmark = pytest.mark.skip("Not relevant to test this currently.")


@pytest.mark.timeout(36000) # 10 hours
@pytest.mark.parametrize(
Expand Down
Loading
Loading