Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor diff calculation query to get less information #4376

Merged
merged 10 commits into from
Sep 27, 2024
56 changes: 44 additions & 12 deletions backend/infrahub/core/diff/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,64 @@
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase

from .model.path import CalculatedDiffs
from .model.path import CalculatedDiffs, NodeFieldSpecifier


class DiffCalculator:
def __init__(self, db: InfrahubDatabase) -> None:
self.db = db

async def calculate_diff(
self, base_branch: Branch, diff_branch: Branch, from_time: Timestamp, to_time: Timestamp
self,
base_branch: Branch,
diff_branch: Branch,
from_time: Timestamp,
to_time: Timestamp,
previous_node_specifiers: set[NodeFieldSpecifier] | None = None,
) -> CalculatedDiffs:
diff_query = await DiffAllPathsQuery.init(
if diff_branch.name == registry.default_branch:
diff_branch_create_time = from_time
else:
diff_branch_create_time = Timestamp(diff_branch.get_created_at())
diff_parser = DiffQueryParser(
base_branch=base_branch,
diff_branch=diff_branch,
schema_manager=registry.schema,
from_time=from_time,
to_time=to_time,
)
branch_diff_query = await DiffAllPathsQuery.init(
db=self.db,
branch=diff_branch,
base_branch=base_branch,
diff_branch_create_time=diff_branch_create_time,
diff_from=from_time,
diff_to=to_time,
)
await diff_query.execute(db=self.db)
diff_parser = DiffQueryParser(
diff_query=diff_query,
base_branch_name=base_branch.name,
diff_branch_name=diff_branch.name,
schema_manager=registry.schema,
from_time=from_time,
to_time=to_time,
)
await branch_diff_query.execute(db=self.db)
for query_result in branch_diff_query.get_results():
diff_parser.read_result(query_result=query_result)

if base_branch.name != diff_branch.name:
branch_node_specifiers = diff_parser.get_node_field_specifiers_for_branch(branch_name=diff_branch.name)
new_node_field_specifiers = branch_node_specifiers - (previous_node_specifiers or set())
current_node_field_specifiers = (previous_node_specifiers or set()) - new_node_field_specifiers
base_diff_query = await DiffAllPathsQuery.init(
db=self.db,
branch=base_branch,
base_branch=base_branch,
diff_branch_create_time=diff_branch_create_time,
diff_from=from_time,
diff_to=to_time,
current_node_field_specifiers=[
(nfs.node_uuid, nfs.field_name) for nfs in current_node_field_specifiers
],
new_node_field_specifiers=[(nfs.node_uuid, nfs.field_name) for nfs in new_node_field_specifiers],
)
await base_diff_query.execute(db=self.db)
for query_result in base_diff_query.get_results():
diff_parser.read_result(query_result=query_result)

diff_parser.parse()
return CalculatedDiffs(
base_branch_name=base_branch.name,
Expand Down
42 changes: 30 additions & 12 deletions backend/infrahub/core/diff/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
EnrichedDiffProperty,
EnrichedDiffRelationship,
EnrichedDiffRoot,
EnrichedDiffs,
EnrichedDiffSingleRelationship,
)

Expand Down Expand Up @@ -356,16 +357,33 @@ def _link_child_nodes(self, nodes: Iterable[EnrichedDiffNode]) -> None:
parent_rel = child_node.get_relationship(name=parent_rel_name)
parent_rel.nodes.add(parent_node)

async def combine(self, earlier_diff: EnrichedDiffRoot, later_diff: EnrichedDiffRoot) -> EnrichedDiffRoot:
self._initialize(earlier_diff=earlier_diff, later_diff=later_diff)
filtered_node_pairs = self._filter_nodes_to_keep(earlier_diff=earlier_diff, later_diff=later_diff)
combined_nodes = self._combine_nodes(node_pairs=filtered_node_pairs)
self._link_child_nodes(nodes=combined_nodes)
return EnrichedDiffRoot(
uuid=str(uuid4()),
base_branch_name=later_diff.base_branch_name,
diff_branch_name=later_diff.diff_branch_name,
from_time=earlier_diff.from_time,
to_time=later_diff.to_time,
nodes=combined_nodes,
async def combine(self, earlier_diffs: EnrichedDiffs, later_diffs: EnrichedDiffs) -> EnrichedDiffs:
combined_diffs: list[EnrichedDiffRoot] = []
for earlier, later in (
(earlier_diffs.base_branch_diff, later_diffs.base_branch_diff),
(earlier_diffs.diff_branch_diff, later_diffs.diff_branch_diff),
):
self._initialize(earlier_diff=earlier, later_diff=later)
filtered_node_pairs = self._filter_nodes_to_keep(earlier_diff=earlier, later_diff=later)
combined_nodes = self._combine_nodes(node_pairs=filtered_node_pairs)
self._link_child_nodes(nodes=combined_nodes)
combined_diffs.append(
EnrichedDiffRoot(
uuid=str(uuid4()),
partner_uuid=later.partner_uuid,
base_branch_name=later.base_branch_name,
diff_branch_name=later.diff_branch_name,
from_time=earlier.from_time,
to_time=later.to_time,
nodes=combined_nodes,
)
)
base_branch_diff, diff_branch_diff = combined_diffs # pylint: disable=unbalanced-tuple-unpacking
base_branch_diff.partner_uuid = diff_branch_diff.uuid
diff_branch_diff.partner_uuid = base_branch_diff.uuid
return EnrichedDiffs(
base_branch_name=later_diffs.base_branch_name,
diff_branch_name=later_diffs.diff_branch_name,
base_branch_diff=base_branch_diff,
diff_branch_diff=diff_branch_diff,
)
18 changes: 7 additions & 11 deletions backend/infrahub/core/diff/conflicts_enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from infrahub.core.constants import DiffAction, RelationshipCardinality
from infrahub.core.constants.database import DatabaseEdgeType
from infrahub.database import InfrahubDatabase

from .model.path import (
EnrichedDiffAttribute,
Expand All @@ -16,10 +15,9 @@


class ConflictsEnricher:
def __init__(self, db: InfrahubDatabase) -> None:
def __init__(self) -> None:
self._base_branch_name: str | None = None
self._diff_branch_name: str | None = None
self.schema_manager = db.schema

@property
def base_branch_name(self) -> str:
Expand Down Expand Up @@ -66,7 +64,6 @@ def _add_node_conflicts(self, base_node: EnrichedDiffNode, branch_node: Enriched
base_relationship = base_relationship_map[relationship_name]
branch_relationship = branch_relationship_map[relationship_name]
self._add_relationship_conflicts(
branch_node=branch_node,
base_relationship=base_relationship,
branch_relationship=branch_relationship,
)
Expand Down Expand Up @@ -100,7 +97,11 @@ def _add_attribute_conflicts(
for property_type in common_property_types:
base_property = base_property_map[property_type]
branch_property = branch_property_map[property_type]
if base_property.new_value != branch_property.new_value:
same_value = base_property.new_value == branch_property.new_value or (
base_property.action is DiffAction.UNCHANGED
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this if the action is DELETE and the values would be the same regardless?

and base_property.previous_value == branch_property.previous_value
)
if not same_value:
self._add_property_conflict(
base_property=base_property,
branch_property=branch_property,
Expand All @@ -110,15 +111,10 @@ def _add_attribute_conflicts(

def _add_relationship_conflicts(
self,
branch_node: EnrichedDiffNode,
base_relationship: EnrichedDiffRelationship,
branch_relationship: EnrichedDiffRelationship,
) -> None:
node_schema = self.schema_manager.get_node_schema(
name=branch_node.kind, branch=self.diff_branch_name, duplicate=False
)
relationship_schema = node_schema.get_relationship(name=branch_relationship.name)
is_cardinality_one = relationship_schema.cardinality is RelationshipCardinality.ONE
is_cardinality_one = branch_relationship.cardinality is RelationshipCardinality.ONE
if is_cardinality_one:
if not base_relationship.relationships or not branch_relationship.relationships:
return
Expand Down
Loading
Loading