diff --git a/backend/infrahub/core/manager.py b/backend/infrahub/core/manager.py index b3e59af753..81d09a9292 100644 --- a/backend/infrahub/core/manager.py +++ b/backend/infrahub/core/manager.py @@ -23,6 +23,7 @@ from infrahub.core.query.relationship import RelationshipGetPeerQuery from infrahub.core.registry import registry from infrahub.core.relationship import Relationship, RelationshipManager +from infrahub.core.relationship.utils import query_peers_relationships from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema from infrahub.core.timestamp import Timestamp from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError @@ -352,25 +353,21 @@ async def query_peers( branch = await registry.get_branch(branch=branch, db=db) at = Timestamp(at) - rel = Relationship(schema=schema, branch=branch, node_id="PLACEHOLDER") - - query = await RelationshipGetPeerQuery.init( + rels = await query_peers_relationships( db=db, source_ids=ids, source_kind=source_kind, - schema=schema, + rel_schema=schema, filters=filters, - rel=rel, offset=offset, limit=limit, at=at, + branch=branch, branch_agnostic=branch_agnostic, ) - await query.execute(db=db) - peers_info = list(query.get_peers()) - if not peers_info: - return [] + if not fetch_peers: + return rels # if display_label has been requested we need to ensure we are querying the right fields if fields and "display_label" in fields: @@ -386,26 +383,15 @@ async def query_peers( if hfid_fields: fields = deep_merge_dict(dicta=fields, dictb=hfid_fields) - if fetch_peers: - peer_ids = [peer.peer_id for peer in peers_info] - peer_nodes = await cls.get_many( - db=db, ids=peer_ids, fields=fields, at=at, branch=branch, branch_agnostic=branch_agnostic - ) + peer_ids = [str(rel.data.peer_id) for rel in rels] + peer_nodes = await cls.get_many( + db=db, ids=peer_ids, fields=fields, at=at, branch=branch, branch_agnostic=branch_agnostic + ) - results = [] - for peer in peers_info: - result = await Relationship(schema=schema, branch=branch, at=at, node_id=peer.source_id).load( - db=db, - id=peer.rel_node_id, - db_id=peer.rel_node_db_id, - updated_at=peer.updated_at, - data=peer, - ) - if fetch_peers: - await result.set_peer(value=peer_nodes[peer.peer_id]) - results.append(result) + for rel in rels: + await rel.set_peer(value=peer_nodes[str(rel.data.peer_id)]) - return results + return rels @classmethod async def count_hierarchy( diff --git a/backend/infrahub/core/node/__init__.py b/backend/infrahub/core/node/__init__.py index f12f74dec5..9482d3e2b1 100644 --- a/backend/infrahub/core/node/__init__.py +++ b/backend/infrahub/core/node/__init__.py @@ -7,7 +7,13 @@ from infrahub_sdk.uuidt import UUIDT from infrahub.core import registry -from infrahub.core.constants import BranchSupportType, ComputedAttributeKind, InfrahubKind, RelationshipCardinality +from infrahub.core.constants import ( + GLOBAL_BRANCH_NAME, + BranchSupportType, + ComputedAttributeKind, + InfrahubKind, + RelationshipCardinality, +) from infrahub.core.constants.schema import SchemaElementPathType from infrahub.core.protocols import CoreNumberPool from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeDeleteQuery, NodeGetListQuery @@ -19,7 +25,9 @@ from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME from ...graphql.models import OrderModel +from ..query.relationship import RelationshipGetByIdentifierQuery from ..relationship import RelationshipManager +from ..relationship.utils import query_peers_relationships from ..utils import update_relationships_to from .base import BaseNode, BaseNodeMeta, BaseNodeOptions @@ -601,8 +609,15 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) -> # Go over the list of relationships and update them one by one for name in self._relationships: - rel: RelationshipManager = getattr(self, name) - await rel.delete(at=delete_at, db=db) + rel_manager: RelationshipManager = getattr(self, name) + await rel_manager.delete(at=delete_at, db=db) + + schema_branch = registry.schema.get_schema_branch(name=self._branch.name) + if ( + self.get_kind() in schema_branch.unidirectional_relationships + and schema_branch.unidirectional_relationships[self.get_kind()] + ): + await self._delete_unidirectional_relationships(at, db, schema_branch) # Need to check if there are some unidirectional relationship as well # For example, if we delete a tag, we must check the permissions and update all the relationships pointing at it @@ -620,12 +635,57 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) -> await query.execute(db=db) result = query.get_result() + # result.get("rb_id") actually only returns root_id, so we update `to` only for root here. if result and result.get("rb.branch") == branch.name: await update_relationships_to([result.get("rb_id")], to=delete_at, db=db) query = await NodeDeleteQuery.init(db=db, node=self, at=delete_at) await query.execute(db=db) + async def _delete_unidirectional_relationships(self, at, db, schema_branch) -> None: + """ + Unidirectional incoming relationships require special handling as they do not belong to the schema node + of the object being deleted. Thus, we look for existing ones within SchemaBranch, query for them from the db, + and remove them if some exist. + """ + + query = await RelationshipGetByIdentifierQuery.init( + db=db, + branch=self._branch, + at=at, + identifiers=schema_branch.unidirectional_relationships[self.get_kind()], + excluded_namespaces=[], + ) + await query.execute(db=db) + for peer in query.get_peers(): + if peer.source_kind == self.get_kind(): + peer_kind = peer.destination_kind + peer_id = peer.destination_id + else: + peer_kind = peer.source_kind + peer_id = peer.source_id + + node_schema = schema_branch.get_node(name=peer_kind) + rel_schemas = [ + rel_schema for rel_schema in node_schema.relationships if rel_schema.identifier == peer.identifier + ] + if len(rel_schemas) > 1: + raise ValueError(f"Relationship {peer.identifier} is duplicated") + + rels = await query_peers_relationships( + db=db, + source_ids=[str(peer_id)], + source_kind=peer_kind, + rel_schema=rel_schemas[0], + filters={}, + at=at, + branch=self._branch, + branch_agnostic=self.get_branch_based_on_support_type().name == GLOBAL_BRANCH_NAME, + ) + + for rel in rels: + await rel.delete(db=db, at=at) + async def to_graphql( self, db: InfrahubDatabase, diff --git a/backend/infrahub/core/relationship/utils.py b/backend/infrahub/core/relationship/utils.py new file mode 100644 index 0000000000..0ce8e6df9f --- /dev/null +++ b/backend/infrahub/core/relationship/utils.py @@ -0,0 +1,46 @@ +from infrahub.core.query.relationship import RelationshipGetPeerQuery +from infrahub.core.relationship import Relationship +from infrahub.core.schema import RelationshipSchema +from infrahub.core.timestamp import Timestamp +from infrahub.database import InfrahubDatabase + + +async def query_peers_relationships( # type: ignore[no-untyped-def] + db: InfrahubDatabase, + source_ids: list[str], + source_kind: str, + rel_schema: RelationshipSchema, + filters: dict, + at: Timestamp, + branch, # from infrahub.core.branch import Branch leads to circular import and cannot be used within TYPE_CHECKING block. Why? + branch_agnostic: bool = False, + offset: int | None = None, + limit: int | None = None, +) -> list[Relationship]: + rel = Relationship(schema=rel_schema, branch=branch, node_id="PLACEHOLDER") + + query = await RelationshipGetPeerQuery.init( + db=db, + source_ids=source_ids, + source_kind=source_kind, + schema=rel_schema, + filters=filters, + rel=rel, + offset=offset, + limit=limit, + at=at, + branch_agnostic=branch_agnostic, + ) + await query.execute(db=db) + + rels = [ + await Relationship(schema=rel_schema, branch=branch, at=at, node_id=str(peer.source_id)).load( + db=db, + id=peer.rel_node_id, + db_id=peer.rel_node_db_id, + updated_at=peer.updated_at, + data=peer, + ) + for peer in query.get_peers() + ] + return rels diff --git a/backend/infrahub/core/schema/schema_branch.py b/backend/infrahub/core/schema/schema_branch.py index bf5e9a0b5a..e5c7efe9de 100644 --- a/backend/infrahub/core/schema/schema_branch.py +++ b/backend/infrahub/core/schema/schema_branch.py @@ -72,6 +72,7 @@ def __init__( name: str | None = None, data: dict[str, dict[str, str]] | None = None, computed_attributes: ComputedAttributes | None = None, + unidirectional_relationships: dict[str, list[str]] | None = None, ): self._cache: dict[str, Union[NodeSchema, GenericSchema]] = cache self.name: str | None = name @@ -80,6 +81,11 @@ def __init__( self.profiles: dict[str, str] = {} self.computed_attributes = computed_attributes or ComputedAttributes() + # node kind to relationship identifier + self.unidirectional_relationships: dict[str, list[str]] = ( + unidirectional_relationships if unidirectional_relationships is not None else defaultdict(list) + ) + if data: self.nodes = data.get("nodes", {}) self.generics = data.get("generics", {}) @@ -263,6 +269,7 @@ def duplicate(self, name: Optional[str] = None) -> SchemaBranch: data=copy.deepcopy(self.to_dict()), cache=self._cache, computed_attributes=self.computed_attributes.duplicate(), + unidirectional_relationships=self.unidirectional_relationships, ) def set(self, name: str, schema: MainSchemaTypes) -> str: @@ -1105,6 +1112,14 @@ def process_relationships(self) -> None: schema_to_update: Optional[Union[NodeSchema, GenericSchema]] = None for relationship in node.relationships: + # Fill unidirectional relationships mapping so we can delete them while deleting corresponding nodes. + peer_schema = self.get(name=relationship.peer) + for peer_rel in peer_schema.relationships: + if peer_rel.identifier == relationship.identifier: + break + else: + self.unidirectional_relationships[peer_schema.kind].append(relationship.identifier) + if relationship.on_delete is not None: continue if not schema_to_update: diff --git a/backend/tests/unit/core/diff/test_diff_and_merge.py b/backend/tests/unit/core/diff/test_diff_and_merge.py index 7e72f297b0..6e01d4de42 100644 --- a/backend/tests/unit/core/diff/test_diff_and_merge.py +++ b/backend/tests/unit/core/diff/test_diff_and_merge.py @@ -90,8 +90,14 @@ async def test_diff_and_merge_schema_with_default_values( car_person_schema: SchemaBranch, ): schema_main = registry.schema.get_schema_branch(name=default_branch.name) + # As any node has a `group_member` relationship, we need to load CoreGroup to avoid SchemaNotFoundError + # when we check for unidirectional relationships. Also add CoreNode as CoreGroup has some CoreNode relationships. await registry.schema.update_schema_branch( - db=db, branch=default_branch, schema=schema_main, limit=["TestCar", "TestPerson"], update_db=True + db=db, + branch=default_branch, + schema=schema_main, + limit=["TestCar", "TestPerson", "CoreGroup", "CoreNode"], + update_db=True, ) branch2 = await create_branch(db=db, branch_name="branch2") schema_branch = registry.schema.get_schema_branch(name=branch2.name) @@ -100,10 +106,15 @@ async def test_diff_and_merge_schema_with_default_values( car_schema_branch.attributes.append(AttributeSchema(name="num_cupholders", kind="Number", default_value=15)) car_schema_branch.attributes.append(AttributeSchema(name="is_cool", kind="Boolean", default_value=False)) car_schema_branch.attributes.append(AttributeSchema(name="nickname", kind="Text", default_value="car")) + # car_schema_branch.relationships = [rel for rel in car_schema_branch.relationships if rel != ""] schema_branch.set(name="TestCar", schema=car_schema_branch) schema_branch.process() await registry.schema.update_schema_branch( - db=db, branch=branch2, schema=schema_branch, limit=["TestCar", "TestPerson"], update_db=True + db=db, + branch=branch2, + schema=schema_branch, + limit=["TestCar", "TestPerson", "CoreGroup", "CoreNode"], + update_db=True, ) at = Timestamp() @@ -112,6 +123,11 @@ async def test_diff_and_merge_schema_with_default_values( diff_merger = await self._get_diff_merger(db=db, branch=branch2) await diff_merger.merge_graph(at=at) + # TODO it fails here, loaded schema only has TestCar/Person and corresponding profiles and CoreProfile. + # what is happening above with the `limit`? Are we saying we want only TestCar and TestPerson in branch2 + # then merge in main branch so there are only these schemas? + # 1. Is it acceptable considering TestCar/Person have a group_member link to CoreGroup? + # 2. Or is it something that can only happen internally while testing? How to solve it except adding every missing schemas (cumbersome)? updated_schema = await registry.schema.load_schema_from_db(db=db, branch=default_branch) car_schema_main = updated_schema.get(name="TestCar", duplicate=False) new_int_attr = car_schema_main.get_attribute(name="num_cupholders") diff --git a/backend/tests/unit/core/test_node_manager_delete.py b/backend/tests/unit/core/test_node_manager_delete.py index 88ed38b8b0..e41923b449 100644 --- a/backend/tests/unit/core/test_node_manager_delete.py +++ b/backend/tests/unit/core/test_node_manager_delete.py @@ -11,6 +11,9 @@ from infrahub.core.schema.schema_branch import SchemaBranch from infrahub.database import InfrahubDatabase from infrahub.exceptions import ValidationError +from tests.constants import TestKind +from tests.helpers.schema import CAR_SCHEMA, load_schema +from tests.helpers.test_app import TestInfrahubApp async def test_delete_succeeds( @@ -202,3 +205,31 @@ async def test_delete_with_cascade_on_generic_allowed(db, default_branch, depend assert {d.id for d in deleted} == {human.id, dog.id} node_map = await NodeManager.get_many(db=db, ids=[human.id, dog.id]) assert node_map == {} + + +class TestDeleteUnidirectionalRelationship(TestInfrahubApp): + async def test_delete_unidirectional_optional_relationship(self, db, client, default_branch): + await load_schema(db, schema=CAR_SCHEMA) + + owner = await Node.init(schema=TestKind.PERSON, db=db) + await owner.new(db=db, name="John Doe", height=175) + await owner.save(db=db) + + previous_owner = await Node.init(schema=TestKind.PERSON, db=db) + await previous_owner.new(db=db, name="Eric", height=175) + await previous_owner.save(db=db) + + koenigsegg = await Node.init(schema=TestKind.MANUFACTURER, db=db) + await koenigsegg.new(db=db, name="Koenigsegg") + await koenigsegg.save(db=db) + + car = await Node.init(schema=TestKind.CAR, db=db) + await car.new( + db=db, name="Jesko", color="Red", owner=owner, manufacturer=koenigsegg, previous_owner=previous_owner + ) + await car.save(db=db) + + await previous_owner.delete(db=db) + res = await NodeManager.get_many(db=db, ids=[car.id]) + rels = await res[car.id].previous_owner.get_relationships(db=db) + assert len(rels) == 0