Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix node deletion with unidirectional optional relationship #5783

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 13 additions & 27 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship, RelationshipManager
from infrahub.core.relationship.utils import query_peers_relationships
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
Expand Down Expand Up @@ -352,25 +353,21 @@ async def query_peers(
branch = await registry.get_branch(branch=branch, db=db)
at = Timestamp(at)

rel = Relationship(schema=schema, branch=branch, node_id="PLACEHOLDER")

query = await RelationshipGetPeerQuery.init(
rels = await query_peers_relationships(
db=db,
source_ids=ids,
source_kind=source_kind,
schema=schema,
rel_schema=schema,
filters=filters,
rel=rel,
offset=offset,
limit=limit,
at=at,
branch=branch,
branch_agnostic=branch_agnostic,
)
await query.execute(db=db)

peers_info = list(query.get_peers())
if not peers_info:
return []
if not fetch_peers:
return rels

# if display_label has been requested we need to ensure we are querying the right fields
if fields and "display_label" in fields:
Expand All @@ -386,26 +383,15 @@ async def query_peers(
if hfid_fields:
fields = deep_merge_dict(dicta=fields, dictb=hfid_fields)

if fetch_peers:
peer_ids = [peer.peer_id for peer in peers_info]
peer_nodes = await cls.get_many(
db=db, ids=peer_ids, fields=fields, at=at, branch=branch, branch_agnostic=branch_agnostic
)
peer_ids = [str(rel.data.peer_id) for rel in rels]
peer_nodes = await cls.get_many(
db=db, ids=peer_ids, fields=fields, at=at, branch=branch, branch_agnostic=branch_agnostic
)

results = []
for peer in peers_info:
result = await Relationship(schema=schema, branch=branch, at=at, node_id=peer.source_id).load(
db=db,
id=peer.rel_node_id,
db_id=peer.rel_node_db_id,
updated_at=peer.updated_at,
data=peer,
)
if fetch_peers:
await result.set_peer(value=peer_nodes[peer.peer_id])
results.append(result)
for rel in rels:
await rel.set_peer(value=peer_nodes[str(rel.data.peer_id)])

return results
return rels

@classmethod
async def count_hierarchy(
Expand Down
66 changes: 63 additions & 3 deletions backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from infrahub_sdk.uuidt import UUIDT

from infrahub.core import registry
from infrahub.core.constants import BranchSupportType, ComputedAttributeKind, InfrahubKind, RelationshipCardinality
from infrahub.core.constants import (
GLOBAL_BRANCH_NAME,
BranchSupportType,
ComputedAttributeKind,
InfrahubKind,
RelationshipCardinality,
)
from infrahub.core.constants.schema import SchemaElementPathType
from infrahub.core.protocols import CoreNumberPool
from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeDeleteQuery, NodeGetListQuery
Expand All @@ -19,7 +25,9 @@

from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
from ...graphql.models import OrderModel
from ..query.relationship import RelationshipGetByIdentifierQuery
from ..relationship import RelationshipManager
from ..relationship.utils import query_peers_relationships
from ..utils import update_relationships_to
from .base import BaseNode, BaseNodeMeta, BaseNodeOptions

Expand Down Expand Up @@ -601,8 +609,15 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->

# Go over the list of relationships and update them one by one
for name in self._relationships:
rel: RelationshipManager = getattr(self, name)
await rel.delete(at=delete_at, db=db)
rel_manager: RelationshipManager = getattr(self, name)
await rel_manager.delete(at=delete_at, db=db)

schema_branch = registry.schema.get_schema_branch(name=self._branch.name)
if (
self.get_kind() in schema_branch.unidirectional_relationships
and schema_branch.unidirectional_relationships[self.get_kind()]
):
await self._delete_unidirectional_relationships(at, db, schema_branch)

# Need to check if there are some unidirectional relationship as well
# For example, if we delete a tag, we must check the permissions and update all the relationships pointing at it
Expand All @@ -620,12 +635,57 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->
await query.execute(db=db)
result = query.get_result()

# result.get("rb_id") actually only returns root_id, so we update `to` only for root here.
if result and result.get("rb.branch") == branch.name:
await update_relationships_to([result.get("rb_id")], to=delete_at, db=db)

query = await NodeDeleteQuery.init(db=db, node=self, at=delete_at)
await query.execute(db=db)

async def _delete_unidirectional_relationships(self, at, db, schema_branch) -> None:
"""
Unidirectional incoming relationships require special handling as they do not belong to the schema node
of the object being deleted. Thus, we look for existing ones within SchemaBranch, query for them from the db,
and remove them if some exist.
"""

query = await RelationshipGetByIdentifierQuery.init(
db=db,
branch=self._branch,
at=at,
identifiers=schema_branch.unidirectional_relationships[self.get_kind()],
excluded_namespaces=[],
)
await query.execute(db=db)
for peer in query.get_peers():
if peer.source_kind == self.get_kind():
peer_kind = peer.destination_kind
peer_id = peer.destination_id
else:
peer_kind = peer.source_kind
peer_id = peer.source_id

node_schema = schema_branch.get_node(name=peer_kind)
rel_schemas = [
rel_schema for rel_schema in node_schema.relationships if rel_schema.identifier == peer.identifier
]
if len(rel_schemas) > 1:
raise ValueError(f"Relationship {peer.identifier} is duplicated")

rels = await query_peers_relationships(
db=db,
source_ids=[str(peer_id)],
source_kind=peer_kind,
rel_schema=rel_schemas[0],
filters={},
at=at,
branch=self._branch,
branch_agnostic=self.get_branch_based_on_support_type().name == GLOBAL_BRANCH_NAME,
)

for rel in rels:
await rel.delete(db=db, at=at)

async def to_graphql(
self,
db: InfrahubDatabase,
Expand Down
46 changes: 46 additions & 0 deletions backend/infrahub/core/relationship/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.relationship import Relationship
from infrahub.core.schema import RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase


async def query_peers_relationships( # type: ignore[no-untyped-def]
db: InfrahubDatabase,
source_ids: list[str],
source_kind: str,
rel_schema: RelationshipSchema,
filters: dict,
at: Timestamp,
branch, # from infrahub.core.branch import Branch leads to circular import and cannot be used within TYPE_CHECKING block. Why?
branch_agnostic: bool = False,
offset: int | None = None,
limit: int | None = None,
) -> list[Relationship]:
rel = Relationship(schema=rel_schema, branch=branch, node_id="PLACEHOLDER")

query = await RelationshipGetPeerQuery.init(
db=db,
source_ids=source_ids,
source_kind=source_kind,
schema=rel_schema,
filters=filters,
rel=rel,
offset=offset,
limit=limit,
at=at,
branch_agnostic=branch_agnostic,
)
await query.execute(db=db)

rels = [
await Relationship(schema=rel_schema, branch=branch, at=at, node_id=str(peer.source_id)).load(
db=db,
id=peer.rel_node_id,
db_id=peer.rel_node_db_id,
updated_at=peer.updated_at,
data=peer,
)
for peer in query.get_peers()
]
return rels
15 changes: 15 additions & 0 deletions backend/infrahub/core/schema/schema_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
name: str | None = None,
data: dict[str, dict[str, str]] | None = None,
computed_attributes: ComputedAttributes | None = None,
unidirectional_relationships: dict[str, list[str]] | None = None,
):
self._cache: dict[str, Union[NodeSchema, GenericSchema]] = cache
self.name: str | None = name
Expand All @@ -80,6 +81,11 @@ def __init__(
self.profiles: dict[str, str] = {}
self.computed_attributes = computed_attributes or ComputedAttributes()

# node kind to relationship identifier
self.unidirectional_relationships: dict[str, list[str]] = (
unidirectional_relationships if unidirectional_relationships is not None else defaultdict(list)
)

if data:
self.nodes = data.get("nodes", {})
self.generics = data.get("generics", {})
Expand Down Expand Up @@ -263,6 +269,7 @@ def duplicate(self, name: Optional[str] = None) -> SchemaBranch:
data=copy.deepcopy(self.to_dict()),
cache=self._cache,
computed_attributes=self.computed_attributes.duplicate(),
unidirectional_relationships=self.unidirectional_relationships,
)

def set(self, name: str, schema: MainSchemaTypes) -> str:
Expand Down Expand Up @@ -1105,6 +1112,14 @@ def process_relationships(self) -> None:

schema_to_update: Optional[Union[NodeSchema, GenericSchema]] = None
for relationship in node.relationships:
# Fill unidirectional relationships mapping so we can delete them while deleting corresponding nodes.
peer_schema = self.get(name=relationship.peer)
for peer_rel in peer_schema.relationships:
if peer_rel.identifier == relationship.identifier:
break
else:
self.unidirectional_relationships[peer_schema.kind].append(relationship.identifier)

if relationship.on_delete is not None:
continue
if not schema_to_update:
Expand Down
20 changes: 18 additions & 2 deletions backend/tests/unit/core/diff/test_diff_and_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ async def test_diff_and_merge_schema_with_default_values(
car_person_schema: SchemaBranch,
):
schema_main = registry.schema.get_schema_branch(name=default_branch.name)
# As any node has a `group_member` relationship, we need to load CoreGroup to avoid SchemaNotFoundError
# when we check for unidirectional relationships. Also add CoreNode as CoreGroup has some CoreNode relationships.
await registry.schema.update_schema_branch(
db=db, branch=default_branch, schema=schema_main, limit=["TestCar", "TestPerson"], update_db=True
db=db,
branch=default_branch,
schema=schema_main,
limit=["TestCar", "TestPerson", "CoreGroup", "CoreNode"],
update_db=True,
)
branch2 = await create_branch(db=db, branch_name="branch2")
schema_branch = registry.schema.get_schema_branch(name=branch2.name)
Expand All @@ -100,10 +106,15 @@ async def test_diff_and_merge_schema_with_default_values(
car_schema_branch.attributes.append(AttributeSchema(name="num_cupholders", kind="Number", default_value=15))
car_schema_branch.attributes.append(AttributeSchema(name="is_cool", kind="Boolean", default_value=False))
car_schema_branch.attributes.append(AttributeSchema(name="nickname", kind="Text", default_value="car"))
# car_schema_branch.relationships = [rel for rel in car_schema_branch.relationships if rel != ""]
schema_branch.set(name="TestCar", schema=car_schema_branch)
schema_branch.process()
await registry.schema.update_schema_branch(
db=db, branch=branch2, schema=schema_branch, limit=["TestCar", "TestPerson"], update_db=True
db=db,
branch=branch2,
schema=schema_branch,
limit=["TestCar", "TestPerson", "CoreGroup", "CoreNode"],
update_db=True,
)

at = Timestamp()
Expand All @@ -112,6 +123,11 @@ async def test_diff_and_merge_schema_with_default_values(
diff_merger = await self._get_diff_merger(db=db, branch=branch2)
await diff_merger.merge_graph(at=at)

# TODO it fails here, loaded schema only has TestCar/Person and corresponding profiles and CoreProfile.
# what is happening above with the `limit`? Are we saying we want only TestCar and TestPerson in branch2
# then merge in main branch so there are only these schemas?
# 1. Is it acceptable considering TestCar/Person have a group_member link to CoreGroup?
# 2. Or is it something that can only happen internally while testing? How to solve it except adding every missing schemas (cumbersome)?
updated_schema = await registry.schema.load_schema_from_db(db=db, branch=default_branch)
car_schema_main = updated_schema.get(name="TestCar", duplicate=False)
new_int_attr = car_schema_main.get_attribute(name="num_cupholders")
Expand Down
31 changes: 31 additions & 0 deletions backend/tests/unit/core/test_node_manager_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.database import InfrahubDatabase
from infrahub.exceptions import ValidationError
from tests.constants import TestKind
from tests.helpers.schema import CAR_SCHEMA, load_schema
from tests.helpers.test_app import TestInfrahubApp


async def test_delete_succeeds(
Expand Down Expand Up @@ -202,3 +205,31 @@ async def test_delete_with_cascade_on_generic_allowed(db, default_branch, depend
assert {d.id for d in deleted} == {human.id, dog.id}
node_map = await NodeManager.get_many(db=db, ids=[human.id, dog.id])
assert node_map == {}


class TestDeleteUnidirectionalRelationship(TestInfrahubApp):
async def test_delete_unidirectional_optional_relationship(self, db, client, default_branch):
await load_schema(db, schema=CAR_SCHEMA)

owner = await Node.init(schema=TestKind.PERSON, db=db)
await owner.new(db=db, name="John Doe", height=175)
await owner.save(db=db)

previous_owner = await Node.init(schema=TestKind.PERSON, db=db)
await previous_owner.new(db=db, name="Eric", height=175)
await previous_owner.save(db=db)

koenigsegg = await Node.init(schema=TestKind.MANUFACTURER, db=db)
await koenigsegg.new(db=db, name="Koenigsegg")
await koenigsegg.save(db=db)

car = await Node.init(schema=TestKind.CAR, db=db)
await car.new(
db=db, name="Jesko", color="Red", owner=owner, manufacturer=koenigsegg, previous_owner=previous_owner
)
await car.save(db=db)

await previous_owner.delete(db=db)
res = await NodeManager.get_many(db=db, ids=[car.id])
rels = await res[car.id].previous_owner.get_relationships(db=db)
assert len(rels) == 0
Loading