Skip to content

Commit

Permalink
add new WIP DiffMerger class (#4505)
Browse files Browse the repository at this point in the history
* IFC-675 start new DiffMerger class

* add support for node conflicts

* neo4j WITH fix
  • Loading branch information
ajtmccarty authored Oct 7, 2024
1 parent eb1f86a commit 74fd1b4
Show file tree
Hide file tree
Showing 6 changed files with 394 additions and 0 deletions.
Empty file.
47 changes: 47 additions & 0 deletions backend/infrahub/core/diff/merger/merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from infrahub.core.diff.model.path import BranchTrackingId
from infrahub.core.diff.query.merge import DiffMergeQuery

if TYPE_CHECKING:
from infrahub.core.branch import Branch
from infrahub.core.diff.repository.repository import DiffRepository
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase

from .serializer import DiffMergeSerializer


class DiffMerger:
def __init__(
self,
db: InfrahubDatabase,
source_branch: Branch,
destination_branch: Branch,
diff_repository: DiffRepository,
serializer: DiffMergeSerializer,
):
self.source_branch = source_branch
self.destination_branch = destination_branch
self.db = db
self.diff_repository = diff_repository
self.serializer = serializer

async def merge_graph(self, at: Timestamp) -> None:
enriched_diff = await self.diff_repository.get_one(
diff_branch_name=self.source_branch.name, tracking_id=BranchTrackingId(name=self.source_branch.name)
)
node_diff_dicts = await self.serializer.serialize(diff=enriched_diff)
merge_query = await DiffMergeQuery.init(
db=self.db,
branch=self.source_branch,
at=at,
target_branch=self.destination_branch,
node_diff_dicts=node_diff_dicts,
)
await merge_query.execute(db=self.db)

self.source_branch.branched_from = at.to_string()
await self.source_branch.save(db=self.db)
28 changes: 28 additions & 0 deletions backend/infrahub/core/diff/merger/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any

from infrahub.core.constants import DiffAction

from ..model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot


class DiffMergeSerializer:
def _get_action(self, action: DiffAction, conflict: EnrichedDiffConflict | None) -> DiffAction:
if not conflict:
return action
if conflict.selected_branch is ConflictSelection.BASE_BRANCH:
return conflict.base_branch_action
if conflict.selected_branch is ConflictSelection.DIFF_BRANCH:
return conflict.diff_branch_action
raise ValueError(f"conflict {conflict.uuid} does not have a branch selection")

async def serialize(self, diff: EnrichedDiffRoot) -> list[dict[str, Any]]:
serialized_node_diffs = []
for node in diff.nodes:
node_action = self._get_action(action=node.action, conflict=node.conflict)
serialized_node_diffs.append(
{
"action": str(node_action.value).upper(),
"uuid": node.uuid,
}
)
return serialized_node_diffs
95 changes: 95 additions & 0 deletions backend/infrahub/core/diff/query/merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from infrahub.core.query import Query, QueryType

if TYPE_CHECKING:
from infrahub.core.branch import Branch
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase


class DiffMergeQuery(Query):
name = "diff_merge"
type = QueryType.WRITE
insert_return = False

def __init__(
self,
node_diff_dicts: dict[str, Any],
at: Timestamp,
target_branch: Branch,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.node_diff_dicts = node_diff_dicts
self.at = at
self.target_branch = target_branch
self.source_branch_name = self.branch.name

async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params = {
"node_diff_dicts": self.node_diff_dicts,
"at": self.at.to_string(),
"branch_level": self.target_branch.hierarchy_level,
"target_branch": self.target_branch.name,
"source_branch": self.source_branch_name,
}
query = """
UNWIND $node_diff_dicts AS node_diff_map
CALL {
WITH node_diff_map
WITH node_diff_map, CASE
WHEN node_diff_map.action = "ADDED" THEN "active"
WHEN node_diff_map.action = "REMOVED" THEN "deleted"
ELSE NULL
END AS node_rel_status
CALL {
// ------------------------------
// only make IS_PART_OF updates if node is ADDED or REMOVED
// ------------------------------
WITH node_diff_map, node_rel_status
WITH node_diff_map, node_rel_status
WHERE node_rel_status IS NOT NULL
MATCH (root:Root)
MATCH (n:Node {uuid: node_diff_map.uuid})
// ------------------------------
// check if IS_PART_OF relationship with node_rel_status already exists on the target branch
// ------------------------------
CALL {
WITH root, n, node_rel_status
OPTIONAL MATCH (root)<-[r_root:IS_PART_OF {branch: $target_branch}]-(n)
WHERE r_root.status = node_rel_status
AND r_root.from <= $at
AND (r_root.to >= $at OR r_root.to IS NULL)
RETURN r_root
}
// ------------------------------
// set IS_PART_OF.to on source branch and, optionally, target branch
// ------------------------------
WITH root, r_root, n, node_rel_status
CALL {
WITH root, n, node_rel_status
OPTIONAL MATCH (root)<-[source_r_root:IS_PART_OF {branch: $source_branch, status: node_rel_status}]-(n)
WHERE source_r_root.from <= $at AND source_r_root.to IS NULL
SET source_r_root.to = $at
}
WITH root, r_root, n, node_rel_status
CALL {
WITH root, n, node_rel_status
OPTIONAL MATCH (root)<-[target_r_root:IS_PART_OF {branch: $target_branch, status: "active"}]-(n)
WHERE node_rel_status = "deleted"
AND target_r_root.from <= $at AND target_r_root.to IS NULL
SET target_r_root.to = $at
}
// ------------------------------
// create new IS_PART_OF relationship on target_branch
// ------------------------------
WITH root, r_root, n, node_rel_status
WHERE r_root IS NULL
CREATE (root)<-[:IS_PART_OF { branch: $target_branch, branch_level: $branch_level, from: $at, status: node_rel_status }]-(n)
}
}
"""
self.add_to_query(query=query)
3 changes: 3 additions & 0 deletions backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def get_id(self) -> str:

raise InitializationError("The node has not been saved yet and doesn't have an id")

def get_updated_at(self) -> Timestamp | None:
return self._updated_at

async def get_hfid(self, db: InfrahubDatabase, include_kind: bool = False) -> Optional[list[str]]:
"""Return the Human friendly id of the node."""
if not self._schema.human_friendly_id:
Expand Down
Loading

0 comments on commit 74fd1b4

Please sign in to comment.