diff --git a/backend/infrahub/core/diff/merger/__init__.py b/backend/infrahub/core/diff/merger/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/backend/infrahub/core/diff/merger/merger.py b/backend/infrahub/core/diff/merger/merger.py
new file mode 100644
index 0000000000..0792f2549f
--- /dev/null
+++ b/backend/infrahub/core/diff/merger/merger.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from infrahub.core.diff.model.path import BranchTrackingId
+from infrahub.core.diff.query.merge import DiffMergeQuery
+
+if TYPE_CHECKING:
+    from infrahub.core.branch import Branch
+    from infrahub.core.diff.repository.repository import DiffRepository
+    from infrahub.core.timestamp import Timestamp
+    from infrahub.database import InfrahubDatabase
+
+    from .serializer import DiffMergeSerializer
+
+
+class DiffMerger:
+    """Merge the stored diff of ``source_branch`` into ``destination_branch``."""
+
+    def __init__(
+        self,
+        db: InfrahubDatabase,
+        source_branch: Branch,
+        destination_branch: Branch,
+        diff_repository: DiffRepository,
+        serializer: DiffMergeSerializer,
+    ):
+        """Keep references to the database, both branches, and the collaborators used by ``merge_graph``."""
+        self.db = db
+        self.source_branch = source_branch
+        self.destination_branch = destination_branch
+        self.diff_repository = diff_repository
+        self.serializer = serializer
+
+    async def merge_graph(self, at: Timestamp) -> None:
+        """Run the merge query for the source branch's tracked diff, then update ``branched_from``.
+
+        The diff is looked up in the repository by the source branch name, serialized, and handed to
+        ``DiffMergeQuery``. Afterwards ``branched_from`` of the source branch is set to ``at`` and saved.
+        """
+        source_name = self.source_branch.name
+        tracking_id = BranchTrackingId(name=source_name)
+        stored_diff = await self.diff_repository.get_one(diff_branch_name=source_name, tracking_id=tracking_id)
+        serialized_nodes = await self.serializer.serialize(diff=stored_diff)
+        query = await DiffMergeQuery.init(
+            db=self.db,
+            branch=self.source_branch,
+            at=at,
+            target_branch=self.destination_branch,
+            node_diff_dicts=serialized_nodes,
+        )
+        await query.execute(db=self.db)
+
+        # Record ``at`` as the source branch's new ``branched_from`` value and persist it.
+        self.source_branch.branched_from = at.to_string()
+        await self.source_branch.save(db=self.db)
diff --git a/backend/infrahub/core/diff/merger/serializer.py b/backend/infrahub/core/diff/merger/serializer.py
new file mode 100644
index 0000000000..52aad8ad28
--- /dev/null
+++ b/backend/infrahub/core/diff/merger/serializer.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+from infrahub.core.constants import DiffAction
+
+from ..model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot
+
+
+class DiffMergeSerializer:
+    """Convert an enriched diff into the plain parameters consumed by the merge query."""
+
+    def _get_action(self, action: DiffAction, conflict: EnrichedDiffConflict | None) -> DiffAction:
+        """Return the action to merge for a node.
+
+        Without a conflict the diff's own action is used; otherwise the action of the selected branch.
+
+        Raises:
+            ValueError: if the conflict's selected branch is neither the base nor the diff branch.
+        """
+        if not conflict:
+            return action
+        if conflict.selected_branch is ConflictSelection.BASE_BRANCH:
+            return conflict.base_branch_action
+        if conflict.selected_branch is ConflictSelection.DIFF_BRANCH:
+            return conflict.diff_branch_action
+        raise ValueError(f"conflict {conflict.uuid} does not have a branch selection")
+
+    async def serialize(self, diff: EnrichedDiffRoot) -> list[dict[str, Any]]:
+        """Return one ``{"action", "uuid"}`` mapping per node in ``diff``, with the action upper-cased."""
+        return [
+            {
+                "action": str(self._get_action(action=node.action, conflict=node.conflict).value).upper(),
+                "uuid": node.uuid,
+            }
+            for node in diff.nodes
+        ]
diff --git a/backend/infrahub/core/diff/query/merge.py b/backend/infrahub/core/diff/query/merge.py
new file mode 100644
index 0000000000..b9084431db
--- /dev/null
+++ b/backend/infrahub/core/diff/query/merge.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from infrahub.core.query import Query, QueryType
+
+if TYPE_CHECKING:
+    from infrahub.core.branch import Branch
+    from infrahub.core.timestamp import Timestamp
+    from infrahub.database import InfrahubDatabase
+
+
+class DiffMergeQuery(Query):
+    """Write query applying ADDED/REMOVED node diffs to the IS_PART_OF relationships of ``target_branch``."""
+
+    name = "diff_merge"
+    type = QueryType.WRITE
+    insert_return = False
+
+    def __init__(
+        self,
+        node_diff_dicts: list[dict[str, Any]],
+        at: Timestamp,
+        target_branch: Branch,
+        **kwargs: Any,
+    ) -> None:
+        """Store the merge inputs.
+
+        Args:
+            node_diff_dicts: one mapping per node with "action" and "uuid" keys; the query UNWINDs this list.
+            at: timestamp applied to every relationship closed or created by the merge.
+            target_branch: branch that receives the changes.
+            **kwargs: forwarded to ``Query.__init__``; ``self.branch`` is read afterwards as the source branch.
+        """
+        super().__init__(**kwargs)
+        self.node_diff_dicts = node_diff_dicts
+        self.at = at
+        self.target_branch = target_branch
+        self.source_branch_name = self.branch.name
+
+    async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
+        """Populate the query parameters and register the Cypher statement."""
+        self.params = {
+            "node_diff_dicts": self.node_diff_dicts,
+            "at": self.at.to_string(),
+            "branch_level": self.target_branch.hierarchy_level,
+            "target_branch": self.target_branch.name,
+            "source_branch": self.source_branch_name,
+        }
+        query = """
+UNWIND $node_diff_dicts AS node_diff_map
+CALL {
+    WITH node_diff_map
+    WITH node_diff_map, CASE
+        WHEN node_diff_map.action = "ADDED" THEN "active"
+        WHEN node_diff_map.action = "REMOVED" THEN "deleted"
+        ELSE NULL
+    END AS node_rel_status
+    CALL {
+        // ------------------------------
+        // only make IS_PART_OF updates if node is ADDED or REMOVED
+        // ------------------------------
+        WITH node_diff_map, node_rel_status
+        WITH node_diff_map, node_rel_status
+        WHERE node_rel_status IS NOT NULL
+        MATCH (root:Root)
+        MATCH (n:Node {uuid: node_diff_map.uuid})
+        // ------------------------------
+        // check if IS_PART_OF relationship with node_rel_status already exists on the target branch
+        // ------------------------------
+        CALL {
+            WITH root, n, node_rel_status
+            OPTIONAL MATCH (root)<-[r_root:IS_PART_OF {branch: $target_branch}]-(n)
+            WHERE r_root.status = node_rel_status
+            AND r_root.from <= $at
+            AND (r_root.to >= $at OR r_root.to IS NULL)
+            RETURN r_root
+        }
+        // ------------------------------
+        // set IS_PART_OF.to on source branch and, optionally, target branch
+        // ------------------------------
+        WITH root, r_root, n, node_rel_status
+        CALL {
+            WITH root, n, node_rel_status
+            OPTIONAL MATCH (root)<-[source_r_root:IS_PART_OF {branch: $source_branch, status: node_rel_status}]-(n)
+            WHERE source_r_root.from <= $at AND source_r_root.to IS NULL
+            SET source_r_root.to = $at
+        }
+        WITH root, r_root, n, node_rel_status
+        CALL {
+            WITH root, n, node_rel_status
+            OPTIONAL MATCH (root)<-[target_r_root:IS_PART_OF {branch: $target_branch, status: "active"}]-(n)
+            WHERE node_rel_status = "deleted"
+            AND target_r_root.from <= $at AND target_r_root.to IS NULL
+            SET target_r_root.to = $at
+        }
+        // ------------------------------
+        // create new IS_PART_OF relationship on target_branch
+        // ------------------------------
+        WITH root, r_root, n, node_rel_status
+        WHERE r_root IS NULL
+        CREATE (root)<-[:IS_PART_OF { branch: $target_branch, branch_level: $branch_level, from: $at, status: node_rel_status }]-(n)
+    }
+}
+        """
+        self.add_to_query(query=query)
diff --git a/backend/infrahub/core/node/__init__.py b/backend/infrahub/core/node/__init__.py
index ad81ec9411..85db92eda9 100644
--- a/backend/infrahub/core/node/__init__.py
+++ b/backend/infrahub/core/node/__init__.py
@@ -66,6 +66,10 @@ def get_id(self) -> str:
 
         raise InitializationError("The node has not been saved yet and doesn't have an id")
 
+    def get_updated_at(self) -> Timestamp | None:
+        """Return the timestamp stored in ``_updated_at``, if any."""
+        return self._updated_at
+
     async def get_hfid(self, db: InfrahubDatabase, include_kind: bool = False) -> Optional[list[str]]:
         """Return the Human friendly id of the node."""
         if not self._schema.human_friendly_id:
diff --git a/backend/tests/unit/core/diff/test_diff_merger.py b/backend/tests/unit/core/diff/test_diff_merger.py
new file mode 100644
index 0000000000..f30e3480e3
--- /dev/null
+++ b/backend/tests/unit/core/diff/test_diff_merger.py
@@ -0,0 +1,235 @@
+from unittest.mock import AsyncMock, call
+from uuid import uuid4
+
+import pytest
+
+from infrahub.core.branch import Branch
+from infrahub.core.constants import DiffAction
+from infrahub.core.diff.merger.merger import DiffMerger
+from infrahub.core.diff.merger.serializer import DiffMergeSerializer
+from infrahub.core.diff.model.path import (
+    BranchTrackingId,
+    ConflictSelection,
+    EnrichedDiffNode,
+    EnrichedDiffRoot,
+)
+from infrahub.core.diff.repository.repository import DiffRepository
+from infrahub.core.initialization import create_branch
+from infrahub.core.manager import NodeManager
+from infrahub.core.node import Node
+from infrahub.core.timestamp import Timestamp
+from infrahub.database import InfrahubDatabase
+from infrahub.exceptions import NodeNotFoundError
+from tests.unit.core.diff.factories import EnrichedConflictFactory, EnrichedNodeFactory, EnrichedRootFactory
+
+
+class TestMergeDiff:
+    """Exercise DiffMerger.merge_graph against a database with a mocked diff repository."""
+
+    @pytest.fixture
+    async def source_branch(self, db: InfrahubDatabase, default_branch: Branch) -> Branch:
+        """Branch named "source" that the diffs under test are merged from."""
+        return await create_branch(db=db, branch_name="source")
+
+    @pytest.fixture
+    def mock_diff_repository(self) -> DiffRepository:
+        """Async mock standing in for the diff repository."""
+        return AsyncMock(spec=DiffRepository)
+
+    @pytest.fixture
+    def diff_merger(
+        self, db: InfrahubDatabase, default_branch: Branch, source_branch: Branch, mock_diff_repository: DiffRepository
+    ) -> DiffMerger:
+        """Merger from the source branch into the default branch."""
+        return DiffMerger(
+            db=db,
+            source_branch=source_branch,
+            destination_branch=default_branch,
+            diff_repository=mock_diff_repository,
+            serializer=DiffMergeSerializer(),
+        )
+
+    @pytest.fixture
+    async def person_node_branch(self, db: InfrahubDatabase, source_branch: Branch, car_person_schema) -> Node:
+        """TestPerson node saved on the source branch."""
+        new_node = await Node.init(db=db, schema="TestPerson", branch=source_branch)
+        await new_node.new(db=db, name="Albert", height=172)
+        await new_node.save(db=db)
+        return new_node
+
+    @pytest.fixture
+    async def person_node_main(self, db: InfrahubDatabase, default_branch: Branch, car_person_schema) -> Node:
+        """TestPerson node saved on the default branch."""
+        new_node = await Node.init(db=db, schema="TestPerson", branch=default_branch)
+        await new_node.new(db=db, name="Albert", height=172)
+        await new_node.save(db=db)
+        return new_node
+
+    @pytest.fixture
+    def empty_diff_root(self, default_branch: Branch, source_branch: Branch) -> EnrichedDiffRoot:
+        """Diff root for the source branch with no nodes; each test fills in ``nodes``."""
+        return EnrichedRootFactory.build(
+            base_branch_name=default_branch.name,
+            diff_branch_name=source_branch.name,
+            from_time=Timestamp(source_branch.get_created_at()),
+            to_time=Timestamp(),
+            uuid=str(uuid4()),
+            partner_uuid=str(uuid4()),
+            tracking_id=BranchTrackingId(name=source_branch.name),
+            nodes=set(),
+        )
+
+    def _get_empty_node_diff(self, node: Node, action: DiffAction) -> EnrichedDiffNode:
+        """Build a node diff for ``node`` with the given action and no attribute or relationship changes."""
+        return EnrichedNodeFactory.build(
+            uuid=node.get_id(), action=action, kind=node.get_kind(), label="", attributes=set(), relationships=set()
+        )
+
+    async def test_merge_node_added(
+        self,
+        db: InfrahubDatabase,
+        default_branch: Branch,
+        source_branch: Branch,
+        person_node_branch: Node,
+        mock_diff_repository: DiffRepository,
+        diff_merger: DiffMerger,
+        empty_diff_root: EnrichedDiffRoot,
+    ):
+        """A node added on the source branch is retrievable from the default branch after the merge."""
+        added_node_diff = self._get_empty_node_diff(node=person_node_branch, action=DiffAction.ADDED)
+        empty_diff_root.nodes = {added_node_diff}
+        mock_diff_repository.get_one.return_value = empty_diff_root
+        at = Timestamp()
+
+        await diff_merger.merge_graph(at=at)
+
+        mock_diff_repository.get_one.assert_awaited_once_with(
+            diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)
+        )
+        target_person = await NodeManager.get_one(db=db, branch=default_branch, id=person_node_branch.id)
+        assert target_person.id == person_node_branch.id
+        assert target_person.get_updated_at() == at
+
+    async def test_merge_node_added_idempotent(
+        self,
+        db: InfrahubDatabase,
+        default_branch: Branch,
+        source_branch: Branch,
+        person_node_branch: Node,
+        mock_diff_repository: DiffRepository,
+        diff_merger: DiffMerger,
+        empty_diff_root: EnrichedDiffRoot,
+    ):
+        """Merging the same added-node diff twice leaves the same result as merging once."""
+        added_node_diff = self._get_empty_node_diff(node=person_node_branch, action=DiffAction.ADDED)
+        empty_diff_root.nodes = {added_node_diff}
+        mock_diff_repository.get_one.return_value = empty_diff_root
+        at = Timestamp()
+
+        await diff_merger.merge_graph(at=at)
+        await diff_merger.merge_graph(at=at)
+
+        assert mock_diff_repository.get_one.await_args_list == [
+            call(diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)),
+            call(diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)),
+        ]
+        target_person = await NodeManager.get_one(db=db, branch=default_branch, id=person_node_branch.id)
+        assert target_person.id == person_node_branch.id
+        assert target_person.get_updated_at() == at
+
+    async def test_merge_node_deleted(
+        self,
+        db: InfrahubDatabase,
+        default_branch: Branch,
+        person_node_main: Node,
+        source_branch: Branch,
+        mock_diff_repository: DiffRepository,
+        diff_merger: DiffMerger,
+        empty_diff_root: EnrichedDiffRoot,
+    ):
+        """A node deleted on the source branch can no longer be fetched from the default branch after the merge."""
+        person_node_branch = await NodeManager.get_one(db=db, branch=source_branch, id=person_node_main.id)
+        await person_node_branch.delete(db=db)
+        deleted_node_diff = self._get_empty_node_diff(node=person_node_branch, action=DiffAction.REMOVED)
+        empty_diff_root.nodes = {deleted_node_diff}
+        mock_diff_repository.get_one.return_value = empty_diff_root
+        at = Timestamp()
+
+        await diff_merger.merge_graph(at=at)
+
+        mock_diff_repository.get_one.assert_awaited_once_with(
+            diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)
+        )
+        with pytest.raises(NodeNotFoundError):
+            await NodeManager.get_one(db=db, branch=default_branch, id=person_node_main.id, raise_on_error=True)
+
+    async def test_merge_node_deleted_idempotent(
+        self,
+        db: InfrahubDatabase,
+        default_branch: Branch,
+        person_node_main: Node,
+        source_branch: Branch,
+        mock_diff_repository: DiffRepository,
+        diff_merger: DiffMerger,
+        empty_diff_root: EnrichedDiffRoot,
+    ):
+        """Merging the same removed-node diff twice still leaves the node absent from the default branch."""
+        person_node_branch = await NodeManager.get_one(db=db, branch=source_branch, id=person_node_main.id)
+        await person_node_branch.delete(db=db)
+        deleted_node_diff = self._get_empty_node_diff(node=person_node_branch, action=DiffAction.REMOVED)
+        empty_diff_root.nodes = {deleted_node_diff}
+        mock_diff_repository.get_one.return_value = empty_diff_root
+        at = Timestamp()
+
+        await diff_merger.merge_graph(at=at)
+        await diff_merger.merge_graph(at=at)
+
+        assert mock_diff_repository.get_one.await_args_list == [
+            call(diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)),
+            call(diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)),
+        ]
+        with pytest.raises(NodeNotFoundError):
+            await NodeManager.get_one(db=db, branch=default_branch, id=person_node_main.id, raise_on_error=True)
+
+    @pytest.mark.parametrize(
+        "conflict_selection,expect_deleted",
+        [(ConflictSelection.DIFF_BRANCH, True), (ConflictSelection.BASE_BRANCH, False)],
+    )
+    async def test_merge_node_deleted_with_conflict(
+        self,
+        db: InfrahubDatabase,
+        default_branch: Branch,
+        person_node_main: Node,
+        source_branch: Branch,
+        mock_diff_repository: DiffRepository,
+        diff_merger: DiffMerger,
+        empty_diff_root: EnrichedDiffRoot,
+        conflict_selection: ConflictSelection,
+        expect_deleted: bool,
+    ):
+        """The merge deletes the node only when the conflict selects the diff branch's REMOVED action."""
+        person_node_branch = await NodeManager.get_one(db=db, branch=source_branch, id=person_node_main.id)
+        await person_node_branch.delete(db=db)
+        deleted_node_diff = self._get_empty_node_diff(node=person_node_branch, action=DiffAction.REMOVED)
+        node_conflict = EnrichedConflictFactory.build(
+            base_branch_action=DiffAction.UPDATED,
+            diff_branch_action=DiffAction.REMOVED,
+            selected_branch=conflict_selection,
+        )
+        deleted_node_diff.conflict = node_conflict
+        empty_diff_root.nodes = {deleted_node_diff}
+        mock_diff_repository.get_one.return_value = empty_diff_root
+        at = Timestamp()
+
+        await diff_merger.merge_graph(at=at)
+
+        mock_diff_repository.get_one.assert_awaited_once_with(
+            diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name)
+        )
+        if expect_deleted:
+            with pytest.raises(NodeNotFoundError):
+                await NodeManager.get_one(db=db, branch=default_branch, id=person_node_main.id, raise_on_error=True)
+        else:
+            target_person = await NodeManager.get_one(db=db, branch=default_branch, id=person_node_branch.id)
+            assert target_person.id == person_node_branch.id
+            assert target_person.get_updated_at() < at