Skip to content

Commit

Permalink
Fix node deletion with unidirectional optional relationship
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasG0 committed Feb 19, 2025
1 parent 549ae11 commit 86eebfd
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 30 deletions.
40 changes: 13 additions & 27 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship, RelationshipManager
from infrahub.core.relationship.utils import query_peers_relationships
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
Expand Down Expand Up @@ -352,25 +353,21 @@ async def query_peers(
branch = await registry.get_branch(branch=branch, db=db)
at = Timestamp(at)

rel = Relationship(schema=schema, branch=branch, node_id="PLACEHOLDER")

query = await RelationshipGetPeerQuery.init(
rels = await query_peers_relationships(
db=db,
source_ids=ids,
source_kind=source_kind,
schema=schema,
rel_schema=schema,
filters=filters,
rel=rel,
offset=offset,
limit=limit,
at=at,
branch=branch,
branch_agnostic=branch_agnostic,
)
await query.execute(db=db)

peers_info = list(query.get_peers())
if not peers_info:
return []
if not fetch_peers:
return rels

# if display_label has been requested we need to ensure we are querying the right fields
if fields and "display_label" in fields:
Expand All @@ -386,26 +383,15 @@ async def query_peers(
if hfid_fields:
fields = deep_merge_dict(dicta=fields, dictb=hfid_fields)

if fetch_peers:
peer_ids = [peer.peer_id for peer in peers_info]
peer_nodes = await cls.get_many(
db=db, ids=peer_ids, fields=fields, at=at, branch=branch, branch_agnostic=branch_agnostic
)
peer_ids = [str(rel.data.peer_id) for rel in rels]
peer_nodes = await cls.get_many(
db=db, ids=peer_ids, fields=fields, at=at, branch=branch, branch_agnostic=branch_agnostic
)

results = []
for peer in peers_info:
result = await Relationship(schema=schema, branch=branch, at=at, node_id=peer.source_id).load(
db=db,
id=peer.rel_node_id,
db_id=peer.rel_node_db_id,
updated_at=peer.updated_at,
data=peer,
)
if fetch_peers:
await result.set_peer(value=peer_nodes[peer.peer_id])
results.append(result)
for rel in rels:
await rel.set_peer(value=peer_nodes[str(rel.data.peer_id)])

return results
return rels

@classmethod
async def count_hierarchy(
Expand Down
66 changes: 63 additions & 3 deletions backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from infrahub_sdk.uuidt import UUIDT

from infrahub.core import registry
from infrahub.core.constants import BranchSupportType, ComputedAttributeKind, InfrahubKind, RelationshipCardinality
from infrahub.core.constants import (
GLOBAL_BRANCH_NAME,
BranchSupportType,
ComputedAttributeKind,
InfrahubKind,
RelationshipCardinality,
)
from infrahub.core.constants.schema import SchemaElementPathType
from infrahub.core.protocols import CoreNumberPool
from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeDeleteQuery, NodeGetListQuery
Expand All @@ -19,7 +25,9 @@

from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
from ...graphql.models import OrderModel
from ..query.relationship import RelationshipGetByIdentifierQuery
from ..relationship import RelationshipManager
from ..relationship.utils import query_peers_relationships
from ..utils import update_relationships_to
from .base import BaseNode, BaseNodeMeta, BaseNodeOptions

Expand Down Expand Up @@ -601,8 +609,15 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->

# Go over the list of relationships and update them one by one
for name in self._relationships:
rel: RelationshipManager = getattr(self, name)
await rel.delete(at=delete_at, db=db)
rel_manager: RelationshipManager = getattr(self, name)
await rel_manager.delete(at=delete_at, db=db)

schema_branch = registry.schema.get_schema_branch(name=self._branch.name)
if (
self.get_kind() in schema_branch.unidirectional_relationships
and schema_branch.unidirectional_relationships[self.get_kind()]
):
await self._delete_unidirectional_relationships(at, db, schema_branch)

# Need to check if there are some unidirectional relationship as well
# For example, if we delete a tag, we must check the permissions and update all the relationships pointing at it
Expand All @@ -620,12 +635,57 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->
await query.execute(db=db)
result = query.get_result()

# result.get("rb_id") actually only returns root_id, so we update `to` only for root here.
if result and result.get("rb.branch") == branch.name:
await update_relationships_to([result.get("rb_id")], to=delete_at, db=db)

query = await NodeDeleteQuery.init(db=db, node=self, at=delete_at)
await query.execute(db=db)

async def _delete_unidirectional_relationships(self, at, db, schema_branch) -> None:
"""
Unidirectional incoming relationships require special handling as they do not belong to the schema node
of the object being deleted. Thus, we look for existing ones within SchemaBranch, query for them from the db,
and remove them if some exist.
"""

query = await RelationshipGetByIdentifierQuery.init(
db=db,
branch=self._branch,
at=at,
identifiers=schema_branch.unidirectional_relationships[self.get_kind()],
excluded_namespaces=[],
)
await query.execute(db=db)
for peer in query.get_peers():
if peer.source_kind == self.get_kind():
peer_kind = peer.destination_kind
peer_id = peer.destination_id
else:
peer_kind = peer.source_kind
peer_id = peer.source_id

node_schema = schema_branch.get_node(name=peer_kind)
rel_schemas = [
rel_schema for rel_schema in node_schema.relationships if rel_schema.identifier == peer.identifier
]
if len(rel_schemas) > 1:
raise ValueError(f"Relationship {peer.identifier} is duplicated")

rels = await query_peers_relationships(
db=db,
source_ids=[str(peer_id)],
source_kind=peer_kind,
rel_schema=rel_schemas[0],
filters={},
at=at,
branch=self._branch,
branch_agnostic=self.get_branch_based_on_support_type().name == GLOBAL_BRANCH_NAME,
)

for rel in rels:
await rel.delete(db=db, at=at)

async def to_graphql(
self,
db: InfrahubDatabase,
Expand Down
46 changes: 46 additions & 0 deletions backend/infrahub/core/relationship/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.relationship import Relationship
from infrahub.core.schema import RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase


async def query_peers_relationships( # type: ignore[no-untyped-def]
db: InfrahubDatabase,
source_ids: list[str],
source_kind: str,
rel_schema: RelationshipSchema,
filters: dict,
at: Timestamp,
branch, # from infrahub.core.branch import Branch leads to circular import and cannot be used within TYPE_CHECKING block. Why?
branch_agnostic: bool = False,
offset: int | None = None,
limit: int | None = None,
) -> list[Relationship]:
rel = Relationship(schema=rel_schema, branch=branch, node_id="PLACEHOLDER")

query = await RelationshipGetPeerQuery.init(
db=db,
source_ids=source_ids,
source_kind=source_kind,
schema=rel_schema,
filters=filters,
rel=rel,
offset=offset,
limit=limit,
at=at,
branch_agnostic=branch_agnostic,
)
await query.execute(db=db)

rels = [
await Relationship(schema=rel_schema, branch=branch, at=at, node_id=str(peer.source_id)).load(
db=db,
id=peer.rel_node_id,
db_id=peer.rel_node_db_id,
updated_at=peer.updated_at,
data=peer,
)
for peer in query.get_peers()
]
return rels
15 changes: 15 additions & 0 deletions backend/infrahub/core/schema/schema_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
name: str | None = None,
data: dict[str, dict[str, str]] | None = None,
computed_attributes: ComputedAttributes | None = None,
unidirectional_relationships: dict[str, list[str]] | None = None,
):
self._cache: dict[str, Union[NodeSchema, GenericSchema]] = cache
self.name: str | None = name
Expand All @@ -80,6 +81,11 @@ def __init__(
self.profiles: dict[str, str] = {}
self.computed_attributes = computed_attributes or ComputedAttributes()

# node kind to relationship identifier
self.unidirectional_relationships: dict[str, list[str]] = (
unidirectional_relationships if unidirectional_relationships is not None else defaultdict(list)
)

if data:
self.nodes = data.get("nodes", {})
self.generics = data.get("generics", {})
Expand Down Expand Up @@ -263,6 +269,7 @@ def duplicate(self, name: Optional[str] = None) -> SchemaBranch:
data=copy.deepcopy(self.to_dict()),
cache=self._cache,
computed_attributes=self.computed_attributes.duplicate(),
unidirectional_relationships=self.unidirectional_relationships,
)

def set(self, name: str, schema: MainSchemaTypes) -> str:
Expand Down Expand Up @@ -1105,6 +1112,14 @@ def process_relationships(self) -> None:

schema_to_update: Optional[Union[NodeSchema, GenericSchema]] = None
for relationship in node.relationships:
# Fill unidirectional relationships mapping so we can delete them while deleting corresponding nodes.
peer_schema = self.get(name=relationship.peer)
for peer_rel in peer_schema.relationships:
if peer_rel.identifier == relationship.identifier:
break
else:
self.unidirectional_relationships[peer_schema.kind].append(relationship.identifier)

if relationship.on_delete is not None:
continue
if not schema_to_update:
Expand Down
31 changes: 31 additions & 0 deletions backend/tests/unit/core/test_node_manager_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.database import InfrahubDatabase
from infrahub.exceptions import ValidationError
from tests.constants import TestKind
from tests.helpers.schema import CAR_SCHEMA, load_schema
from tests.helpers.test_app import TestInfrahubApp


async def test_delete_succeeds(
Expand Down Expand Up @@ -202,3 +205,31 @@ async def test_delete_with_cascade_on_generic_allowed(db, default_branch, depend
assert {d.id for d in deleted} == {human.id, dog.id}
node_map = await NodeManager.get_many(db=db, ids=[human.id, dog.id])
assert node_map == {}


class TestDeleteUnidirectionalRelationship(TestInfrahubApp):
async def test_delete_unidirectional_optional_relationship(self, db, client, default_branch):
await load_schema(db, schema=CAR_SCHEMA)

owner = await Node.init(schema=TestKind.PERSON, db=db)
await owner.new(db=db, name="John Doe", height=175)
await owner.save(db=db)

previous_owner = await Node.init(schema=TestKind.PERSON, db=db)
await previous_owner.new(db=db, name="Eric", height=175)
await previous_owner.save(db=db)

koenigsegg = await Node.init(schema=TestKind.MANUFACTURER, db=db)
await koenigsegg.new(db=db, name="Koenigsegg")
await koenigsegg.save(db=db)

car = await Node.init(schema=TestKind.CAR, db=db)
await car.new(
db=db, name="Jesko", color="Red", owner=owner, manufacturer=koenigsegg, previous_owner=previous_owner
)
await car.save(db=db)

await previous_owner.delete(db=db)
res = await NodeManager.get_many(db=db, ids=[car.id])
rels = await res[car.id].previous_owner.get_relationships(db=db)
assert len(rels) == 0

0 comments on commit 86eebfd

Please sign in to comment.