diff --git a/backend/infrahub/core/branch/tasks.py b/backend/infrahub/core/branch/tasks.py index 41f5796d05..ff5661e61d 100644 --- a/backend/infrahub/core/branch/tasks.py +++ b/backend/infrahub/core/branch/tasks.py @@ -70,7 +70,7 @@ async def rebase_branch(branch: str) -> None: service=service, ) diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj) + enriched_diff = await diff_coordinator.update_branch_diff_and_return(base_branch=base_branch, diff_branch=obj) if enriched_diff.get_all_conflicts(): raise ValidationError( f"Branch {obj.name} contains conflicts with the default branch that must be addressed." @@ -185,6 +185,7 @@ async def merge_branch(branch: str) -> None: try: await merger.merge() except Exception as exc: + log.exception("Merge failed, beginning rollback") await merger.rollback() raise MergeFailedError(branch_name=branch) from exc await merger.update_schema() diff --git a/backend/infrahub/core/diff/combiner.py b/backend/infrahub/core/diff/combiner.py index 602762e742..880b8fb170 100644 --- a/backend/infrahub/core/diff/combiner.py +++ b/backend/infrahub/core/diff/combiner.py @@ -320,6 +320,7 @@ def _combine_relationships( combined_relationship = EnrichedDiffRelationship( name=later_relationship.name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=later_relationship.cardinality, changed_at=later_relationship.changed_at or earlier_relationship.changed_at, action=combined_action, diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 88d237d40c..76119b5d0d 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -1,17 +1,19 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Iterable, Literal, Sequence, overload +from uuid import uuid4 from infrahub import lock -from infrahub.core import registry from infrahub.core.timestamp import Timestamp from infrahub.log import get_logger from .model.path import ( BranchTrackingId, EnrichedDiffRoot, + EnrichedDiffRootMetadata, EnrichedDiffs, + EnrichedDiffsMetadata, NameTrackingId, NodeFieldSpecifier, TrackingId, @@ -41,8 +43,17 @@ class EnrichedDiffRequest: diff_branch: Branch from_time: Timestamp to_time: Timestamp + tracking_id: TrackingId | None = field(default=None) node_field_specifiers: set[NodeFieldSpecifier] = field(default_factory=set) + def __repr__(self) -> str: + return ( + f"EnrichedDiffRequest(base_branch_name={self.base_branch.name}, diff_branch_name={self.diff_branch.name}," + f" from_time={self.from_time.to_string()}, to_time={self.to_time.to_string()}," + f" tracking_id={self.tracking_id.serialize() if self.tracking_id else None})," + f" num_node_field_specifiers={len(self.node_field_specifiers)}" + ) + class DiffCoordinator: lock_namespace = "diff-update" @@ -77,10 +88,11 @@ async def run_update( from_time: str | None = None, to_time: str | None = None, name: str | None = None, - ) -> EnrichedDiffRoot: + ) -> None: # we are updating a diff that tracks the full lifetime of a branch if not name and not from_time and not to_time: - return await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) + await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) + return if from_time: from_timestamp = Timestamp(from_time) @@ -90,7 +102,7 @@ async def run_update( to_timestamp = Timestamp(to_time) else: to_timestamp = Timestamp() - return await self.create_or_update_arbitrary_timeframe_diff( + await self.create_or_update_arbitrary_timeframe_diff( base_branch=base_branch, diff_branch=diff_branch, from_time=from_timestamp, @@ -104,8 +116,28 @@ def _get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_increm lock_name += "__incremental" return lock_name - async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot: - log.debug(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}") + async def update_branch_diff_and_return(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot: + enriched_diff = await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) + if isinstance(enriched_diff, EnrichedDiffRoot): + return enriched_diff + return await self._finalize_diff_root_metadata(diff_root_metadata=enriched_diff) + + async def _finalize_diff_root_metadata(self, diff_root_metadata: EnrichedDiffRootMetadata) -> EnrichedDiffRoot: + # if this is EnrichedDiffMetadata, we need to retrieve the full diff and set its metadata to match + full_enriched_diff = await self.diff_repo.get_one( + diff_branch_name=diff_root_metadata.diff_branch_name, diff_id=diff_root_metadata.uuid + ) + full_enriched_diff.update_metadata( + from_time=diff_root_metadata.from_time, + to_time=diff_root_metadata.to_time, + tracking_id=diff_root_metadata.tracking_id, + ) + return full_enriched_diff + + async def update_branch_diff( + self, base_branch: Branch, diff_branch: Branch + ) -> EnrichedDiffRoot | EnrichedDiffRootMetadata: + log.info(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}") incremental_lock_name = self._get_lock_name( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=True ) @@ -113,9 +145,9 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> name=incremental_lock_name, namespace=self.lock_namespace ) if existing_incremental_lock and await existing_incremental_lock.locked(): - log.debug(f"Branch diff update for {base_branch.name} - {diff_branch.name} already in progress") + log.info(f"Branch diff update for {base_branch.name} - {diff_branch.name} already in progress") async with self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace): - log.debug(f"Existing branch diff update for {base_branch.name} - {diff_branch.name} complete") + log.info(f"Existing branch diff update for {base_branch.name} - {diff_branch.name} complete") return await self.diff_repo.get_one( tracking_id=BranchTrackingId(name=diff_branch.name), diff_branch_name=diff_branch.name ) @@ -129,19 +161,24 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace), self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace), ): - log.debug(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}") + log.info(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}") enriched_diffs = await self._update_diffs( base_branch=base_branch, diff_branch=diff_branch, from_time=from_time, to_time=to_time, tracking_id=tracking_id, + force_branch_refresh=False, ) + if not isinstance(enriched_diffs, EnrichedDiffs): + await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) + return enriched_diffs.diff_branch_diff + await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) - log.debug(f"Branch diff update complete for {base_branch.name} - {diff_branch.name}") + log.info(f"Branch diff update complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff async def create_or_update_arbitrary_timeframe_diff( @@ -159,19 +196,25 @@ async def create_or_update_arbitrary_timeframe_diff( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False ) async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace): - log.debug(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}") + log.info(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}") enriched_diffs = await self._update_diffs( base_branch=base_branch, diff_branch=diff_branch, from_time=from_time, to_time=to_time, tracking_id=tracking_id, + force_branch_refresh=False, ) + # metadata-only diff, so no nodes to enrich + if not isinstance(enriched_diffs, EnrichedDiffs): + await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) + return await self._finalize_diff_root_metadata(diff_root_metadata=enriched_diffs.diff_branch_diff) + await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) - log.debug(f"Arbitrary diff update complete for {base_branch.name} - {diff_branch.name}") + log.info(f"Arbitrary diff update complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff async def recalculate( @@ -184,7 +227,7 @@ async def recalculate( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False ) async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace): - log.debug(f"Acquired lock to recalculate diff for {base_branch.name} - {diff_branch.name}") + log.info(f"Acquired lock to recalculate diff for {base_branch.name} - {diff_branch.name}") current_branch_diff = await self.diff_repo.get_one(diff_branch_name=diff_branch.name, diff_id=diff_id) current_base_diff = await self.diff_repo.get_one( diff_branch_name=base_branch.name, diff_id=current_branch_diff.partner_uuid @@ -205,7 +248,6 @@ async def recalculate( tracking_id=current_branch_diff.tracking_id, force_branch_refresh=True, ) - if current_branch_diff: await self.conflict_transferer.transfer( earlier=current_branch_diff, later=enriched_diffs.diff_branch_diff @@ -215,16 +257,16 @@ async def recalculate( await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) - log.debug(f"Diff recalculation complete for {base_branch.name} - {diff_branch.name}") + log.info(f"Diff recalculation complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff def _get_ordered_diff_pairs( - self, diff_pairs: Iterable[EnrichedDiffs], allow_overlap: bool = False - ) -> list[EnrichedDiffs]: + self, diff_pairs: Iterable[EnrichedDiffsMetadata], allow_overlap: bool = False + ) -> list[EnrichedDiffsMetadata]: ordered_diffs = sorted(diff_pairs, key=lambda d: d.diff_branch_diff.from_time) if allow_overlap: return ordered_diffs - ordered_diffs_no_overlaps: list[EnrichedDiffs] = [] + ordered_diffs_no_overlaps: list[EnrichedDiffsMetadata] = [] for candidate_diff_pair in ordered_diffs: if not ordered_diffs_no_overlaps: ordered_diffs_no_overlaps.append(candidate_diff_pair) @@ -242,6 +284,54 @@ def _get_ordered_diff_pairs( ordered_diffs_no_overlaps[-1] = candidate_diff_pair return ordered_diffs_no_overlaps + def _build_enriched_diffs_with_no_nodes(self, diff_request: EnrichedDiffRequest) -> EnrichedDiffs: + base_uuid = str(uuid4()) + branch_uuid = str(uuid4()) + return EnrichedDiffs( + base_branch_name=diff_request.base_branch.name, + diff_branch_name=diff_request.diff_branch.name, + base_branch_diff=EnrichedDiffRoot( + base_branch_name=diff_request.base_branch.name, + diff_branch_name=diff_request.base_branch.name, + from_time=diff_request.from_time, + to_time=diff_request.to_time, + tracking_id=diff_request.tracking_id, + uuid=base_uuid, + partner_uuid=branch_uuid, + ), + diff_branch_diff=EnrichedDiffRoot( + base_branch_name=diff_request.base_branch.name, + diff_branch_name=diff_request.diff_branch.name, + from_time=diff_request.from_time, + to_time=diff_request.to_time, + tracking_id=diff_request.tracking_id, + uuid=branch_uuid, + partner_uuid=base_uuid, + ), + ) + + @overload + async def _update_diffs( + self, + base_branch: Branch, + diff_branch: Branch, + from_time: Timestamp, + to_time: Timestamp, + tracking_id: TrackingId | None = None, + force_branch_refresh: Literal[True] = ..., + ) -> EnrichedDiffs: ... + + @overload + async def _update_diffs( + self, + base_branch: Branch, + diff_branch: Branch, + from_time: Timestamp, + to_time: Timestamp, + tracking_id: TrackingId | None = None, + force_branch_refresh: Literal[False] = ..., + ) -> EnrichedDiffs | EnrichedDiffsMetadata: ... + async def _update_diffs( self, base_branch: Branch, @@ -250,29 +340,45 @@ async def _update_diffs( to_time: Timestamp, tracking_id: TrackingId | None = None, force_branch_refresh: bool = False, - ) -> EnrichedDiffs: - diff_uuids_to_delete = [] - retrieved_enriched_diffs = await self.diff_repo.get_pairs( - base_branch_name=base_branch.name, - diff_branch_name=diff_branch.name, + ) -> EnrichedDiffs | EnrichedDiffsMetadata: + # start with empty diffs b/c we only care about their metadata for now, hydrate them with data as needed + diff_pairs_metadata = await self.diff_repo.get_diff_pairs_metadata( + base_branch_names=[base_branch.name], + diff_branch_names=[diff_branch.name], from_time=from_time, to_time=to_time, ) - for enriched_diffs in retrieved_enriched_diffs: - if tracking_id: - if enriched_diffs.base_branch_diff.tracking_id: - diff_uuids_to_delete.append(enriched_diffs.base_branch_diff.uuid) - if enriched_diffs.diff_branch_diff.tracking_id: - diff_uuids_to_delete.append(enriched_diffs.diff_branch_diff.uuid) - aggregated_enriched_diffs = await self._get_aggregated_enriched_diffs( + aggregated_enriched_diffs = await self._aggregate_enriched_diffs( diff_request=EnrichedDiffRequest( base_branch=base_branch, diff_branch=diff_branch, from_time=from_time, to_time=to_time, + tracking_id=tracking_id, ), - partial_enriched_diffs=retrieved_enriched_diffs if not force_branch_refresh else [], + partial_enriched_diffs=diff_pairs_metadata if not force_branch_refresh else None, ) + if tracking_id: + diff_uuids_to_delete: list[str] = [] + for diff_pair in diff_pairs_metadata: + if ( + diff_pair.base_branch_diff.tracking_id == tracking_id + and diff_pair.base_branch_diff.uuid != aggregated_enriched_diffs.base_branch_diff.uuid + ): + diff_uuids_to_delete.append(diff_pair.base_branch_diff.uuid) + if ( + diff_pair.diff_branch_diff.tracking_id == tracking_id + and diff_pair.diff_branch_diff.uuid != aggregated_enriched_diffs.diff_branch_diff.uuid + ): + diff_uuids_to_delete.append(diff_pair.diff_branch_diff.uuid) + + if diff_uuids_to_delete: + await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) + + # this is an EnrichedDiffsMetadata, so there are no nodes to enrich + if not isinstance(aggregated_enriched_diffs, EnrichedDiffs): + aggregated_enriched_diffs.update_metadata(from_time=from_time, to_time=to_time, tracking_id=tracking_id) + return aggregated_enriched_diffs await self.conflicts_enricher.add_conflicts_to_branch_diff( base_diff_root=aggregated_enriched_diffs.base_branch_diff, @@ -282,64 +388,170 @@ async def _update_diffs( enriched_diff_root=aggregated_enriched_diffs.diff_branch_diff, conflicts_only=True ) - if tracking_id: - aggregated_enriched_diffs.base_branch_diff.tracking_id = tracking_id - aggregated_enriched_diffs.diff_branch_diff.tracking_id = tracking_id - if diff_uuids_to_delete: - await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) return aggregated_enriched_diffs - async def _get_aggregated_enriched_diffs( - self, diff_request: EnrichedDiffRequest, partial_enriched_diffs: list[EnrichedDiffs] - ) -> EnrichedDiffs: + @overload + async def _aggregate_enriched_diffs( + self, + diff_request: EnrichedDiffRequest, + partial_enriched_diffs: list[EnrichedDiffsMetadata], + ) -> EnrichedDiffs | EnrichedDiffsMetadata: ... + + @overload + async def _aggregate_enriched_diffs( + self, + diff_request: EnrichedDiffRequest, + partial_enriched_diffs: None, + ) -> EnrichedDiffs: ... + + async def _aggregate_enriched_diffs( + self, + diff_request: EnrichedDiffRequest, + partial_enriched_diffs: list[EnrichedDiffsMetadata] | None, + ) -> EnrichedDiffs | EnrichedDiffsMetadata: + """ + If return is an EnrichedDiffsMetadata, it acts as a pointer to a diff in the database that has all the + necessary data for this diff_request. Might have a different time range and/or tracking_id + """ + aggregated_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata | None = None if not partial_enriched_diffs: - return await self._get_enriched_diff(diff_request=diff_request, is_incremental_diff=False) - - ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False) - ordered_diff_reprs = [repr(d) for d in ordered_diffs] - log.debug(f"Ordered diffs for aggregation: {ordered_diff_reprs}") - current_time = diff_request.from_time - previous_diffs: EnrichedDiffs | None = None - while current_time < diff_request.to_time: - if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time: - current_diffs = ordered_diffs.pop(0) - else: + # no existing diffs to use in calculating this diff, so calculate the whole thing and return it + aggregated_enriched_diffs = await self._calculate_enriched_diff( + diff_request=diff_request, is_incremental_diff=False + ) + + if partial_enriched_diffs is not None and not aggregated_enriched_diffs: + ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False) + ordered_diff_reprs = [repr(d) for d in ordered_diffs] + log.info(f"Ordered diffs for aggregation: {ordered_diff_reprs}") + incremental_diffs_and_requests: list[EnrichedDiffsMetadata | EnrichedDiffRequest | None] = [] + current_time = diff_request.from_time + while current_time < diff_request.to_time: + # the next diff to include has already been calculated + if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time: + current_diff = ordered_diffs.pop(0) + incremental_diffs_and_requests.append(current_diff) + current_time = current_diff.diff_branch_diff.to_time + continue + # set the end time to the start of the next calculated diff or the end of the time range if ordered_diffs: end_time = ordered_diffs[0].diff_branch_diff.from_time else: end_time = diff_request.to_time - if previous_diffs is None: - node_field_specifiers = set() - else: - node_field_specifiers = self._get_node_field_specifiers( - enriched_diff=previous_diffs.diff_branch_diff - ) - inner_diff_request = EnrichedDiffRequest( - base_branch=diff_request.base_branch, - diff_branch=diff_request.diff_branch, + # if there are no changes on either branch in this time range, then there cannot be a diff + log.info(f"Checking number of changes on branches for {diff_request!r}") + num_changes_by_branch = await self.diff_repo.get_num_changes_in_time_range_by_branch( + branch_names=[diff_request.base_branch.name, diff_request.diff_branch.name], from_time=current_time, to_time=end_time, - node_field_specifiers=node_field_specifiers, ) - is_incremental_diff = current_time != diff_request.from_time - current_diffs = await self._get_enriched_diff( - diff_request=inner_diff_request, is_incremental_diff=is_incremental_diff + log.info(f"Number of changes: {num_changes_by_branch}") + might_have_changes_in_time_range = any(num_changes_by_branch.values()) + if not might_have_changes_in_time_range: + incremental_diffs_and_requests.append(None) + current_time = end_time + continue + + incremental_diffs_and_requests.append( + EnrichedDiffRequest( + base_branch=diff_request.base_branch, + diff_branch=diff_request.diff_branch, + from_time=current_time, + to_time=end_time, + ) ) + current_time = end_time + + aggregated_enriched_diffs = await self._concatenate_diffs_and_requests( + diff_or_request_list=incremental_diffs_and_requests, full_diff_request=diff_request + ) + + # no changes during this time period, so generate an EnrichedDiffs with no nodes + if not aggregated_enriched_diffs: + return self._build_enriched_diffs_with_no_nodes(diff_request=diff_request) + + # metadata-only diff, means that a diff exists in the database that covers at least + # part of this time period, but it might need to have its start or end time extended + # to cover time ranges with no changes + if not isinstance(aggregated_enriched_diffs, EnrichedDiffs): + return aggregated_enriched_diffs - if previous_diffs: - current_diffs = await self.diff_combiner.combine( - earlier_diffs=previous_diffs, later_diffs=current_diffs + # a new diff (with nodes) covering the time period + aggregated_enriched_diffs.update_metadata( + from_time=diff_request.from_time, to_time=diff_request.to_time, tracking_id=diff_request.tracking_id + ) + aggregated_enriched_diffs.set_fresh_uuids() + return aggregated_enriched_diffs + + async def _concatenate_diffs_and_requests( + self, + diff_or_request_list: Sequence[EnrichedDiffsMetadata | EnrichedDiffRequest | None], + full_diff_request: EnrichedDiffRequest, + ) -> EnrichedDiffs | EnrichedDiffsMetadata | None: + """ + Returns None if diff_or_request_list is empty or all Nones + meaning there are no changes for the diff during this time period + Returns EnrichedDiffsMetadata if diff_or_request_list includes one EnrichedDiffsMetadata and no EnrichedDiffRequests + meaning no diffs needed to be hydrated and combined + Otherwise, returns EnrichedDiffs + meaning multiple diffs (some that may have been freshly calculated) were combined + """ + previous_diff_pair: EnrichedDiffs | EnrichedDiffsMetadata | None = None + for diff_or_request in diff_or_request_list: + if isinstance(diff_or_request, EnrichedDiffRequest): + if previous_diff_pair: + log.info(f"Getting node field specifiers diff uuid={previous_diff_pair.diff_branch_diff.uuid}") + node_field_specifiers = await self.diff_repo.get_node_field_specifiers( + diff_id=previous_diff_pair.diff_branch_diff.uuid, + ) + log.info(f"Number node field specifiers: {len(node_field_specifiers)}") + diff_or_request.node_field_specifiers = node_field_specifiers + is_incremental_diff = diff_or_request.from_time != full_diff_request.from_time + single_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata = await self._calculate_enriched_diff( + diff_request=diff_or_request, is_incremental_diff=is_incremental_diff ) - previous_diffs = current_diffs - current_time = current_diffs.diff_branch_diff.to_time + elif isinstance(diff_or_request, EnrichedDiffsMetadata): + single_enriched_diffs = diff_or_request + else: + continue + + if previous_diff_pair is None: + previous_diff_pair = single_enriched_diffs + continue - return current_diffs + previous_diff_pair = await self._combine_diffs(earlier=previous_diff_pair, later=single_enriched_diffs) - async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: + return previous_diff_pair + + async def _combine_diffs( + self, earlier: EnrichedDiffs | EnrichedDiffsMetadata, later: EnrichedDiffs | EnrichedDiffsMetadata + ) -> EnrichedDiffs | EnrichedDiffsMetadata: + # if one of the diffs is hydrated and has no data, we can combine them without hydrating the other + if isinstance(earlier, EnrichedDiffs) and earlier.is_empty: + later.base_branch_diff.from_time = earlier.base_branch_diff.from_time + later.diff_branch_diff.from_time = earlier.diff_branch_diff.from_time + return later + if isinstance(later, EnrichedDiffs) and later.is_empty: + earlier.base_branch_diff.to_time = later.base_branch_diff.to_time + earlier.diff_branch_diff.to_time = later.diff_branch_diff.to_time + return earlier + + # hydrate the diffs to combine, if necessary + if not isinstance(earlier, EnrichedDiffs): + earlier = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=earlier) + if not isinstance(later, EnrichedDiffs): + later = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=later) + + return await self.diff_combiner.combine(earlier_diffs=earlier, later_diffs=later) + + async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMetadata) -> list[Node]: return await self.data_check_synchronizer.synchronize(enriched_diff=enriched_diff) - async def _get_enriched_diff(self, diff_request: EnrichedDiffRequest, is_incremental_diff: bool) -> EnrichedDiffs: + async def _calculate_enriched_diff( + self, diff_request: EnrichedDiffRequest, is_incremental_diff: bool + ) -> EnrichedDiffs: + log.info(f"Calculating diff for {diff_request!r}, include_unchanged={is_incremental_diff}") calculated_diff_pair = await self.diff_calculator.calculate_diff( base_branch=diff_request.base_branch, diff_branch=diff_request.diff_branch, @@ -348,20 +560,7 @@ async def _get_enriched_diff(self, diff_request: EnrichedDiffRequest, is_increme include_unchanged=is_incremental_diff, previous_node_specifiers=diff_request.node_field_specifiers, ) + log.info("Calculation complete. Enriching diff...") enriched_diff_pair = await self.diff_enricher.enrich(calculated_diffs=calculated_diff_pair) + log.info("Enrichment complete") return enriched_diff_pair - - def _get_node_field_specifiers(self, enriched_diff: EnrichedDiffRoot) -> set[NodeFieldSpecifier]: - specifiers: set[NodeFieldSpecifier] = set() - schema_branch = registry.schema.get_schema_branch(name=enriched_diff.diff_branch_name) - for node in enriched_diff.nodes: - specifiers.update( - NodeFieldSpecifier(node_uuid=node.uuid, field_name=attribute.name) for attribute in node.attributes - ) - if not node.relationships: - continue - node_schema = schema_branch.get_node(name=node.kind, duplicate=False) - for relationship in node.relationships: - relationship_schema = node_schema.get_relationship(name=relationship.name) - specifiers.add(NodeFieldSpecifier(node_uuid=node.uuid, field_name=relationship_schema.get_identifier())) - return specifiers diff --git a/backend/infrahub/core/diff/data_check_synchronizer.py b/backend/infrahub/core/diff/data_check_synchronizer.py index 624fe31597..178f437b8f 100644 --- a/backend/infrahub/core/diff/data_check_synchronizer.py +++ b/backend/infrahub/core/diff/data_check_synchronizer.py @@ -9,7 +9,9 @@ from infrahub.proposed_change.constants import ProposedChangeState from .conflicts_extractor import DiffConflictsExtractor -from .model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot +from .model.diff import DataConflict +from .model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot, EnrichedDiffRootMetadata +from .repository.repository import DiffRepository class DiffDataCheckSynchronizer: @@ -18,12 +20,28 @@ def __init__( db: InfrahubDatabase, conflicts_extractor: DiffConflictsExtractor, conflict_recorder: ObjectConflictValidatorRecorder, + diff_repository: DiffRepository, ): self.db = db self.conflicts_extractor = conflicts_extractor self.conflict_recorder = conflict_recorder + self.diff_repository = diff_repository + self._enriched_conflicts_map: dict[str, EnrichedDiffConflict] | None = None + self._data_conflicts: list[DataConflict] | None = None - async def synchronize(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: + def _get_enriched_conflicts_map(self, enriched_diff: EnrichedDiffRoot) -> dict[str, EnrichedDiffConflict]: + if self._enriched_conflicts_map is None: + self._enriched_conflicts_map = enriched_diff.get_all_conflicts() + return self._enriched_conflicts_map + + async def _get_data_conflicts(self, enriched_diff: EnrichedDiffRoot) -> list[DataConflict]: + if self._data_conflicts is None: + self._data_conflicts = await self.conflicts_extractor.get_data_conflicts(enriched_diff_root=enriched_diff) + return self._data_conflicts + + async def synchronize(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMetadata) -> list[Node]: + self._enriched_conflicts_map = None + self._data_conflicts = None try: proposed_changes = await NodeManager.query( db=self.db, @@ -35,10 +53,21 @@ async def synchronize(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: proposed_changes = [] if not proposed_changes: return [] - enriched_conflicts_map = enriched_diff.get_all_conflicts() - data_conflicts = await self.conflicts_extractor.get_data_conflicts(enriched_diff_root=enriched_diff) all_data_checks = [] for pc in proposed_changes: + # if the enriched_diff is EnrichedDiffRootMetadata, then it has no new data in it + if not isinstance(enriched_diff, EnrichedDiffRoot): + has_validator = bool(await self.conflict_recorder.get_validator(proposed_change=pc)) + # if this pc does not have a validator, then it is a new ProposedChange + if has_validator: + continue + # if this is a new ProposedChange, we need to hydrate then EnrichedDiffRoot so that we can get the conflicts from it + enriched_diff = await self.diff_repository.get_one( + diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid + ) + + data_conflicts = await self._get_data_conflicts(enriched_diff=enriched_diff) + enriched_conflicts_map = self._get_enriched_conflicts_map(enriched_diff=enriched_diff) core_data_checks = await self.conflict_recorder.record_conflicts( proposed_change_id=pc.get_id(), conflicts=data_conflicts ) diff --git a/backend/infrahub/core/diff/enricher/hierarchy.py b/backend/infrahub/core/diff/enricher/hierarchy.py index 79c6ba44c0..f4b47730cd 100644 --- a/backend/infrahub/core/diff/enricher/hierarchy.py +++ b/backend/infrahub/core/diff/enricher/hierarchy.py @@ -92,6 +92,7 @@ async def _enrich_hierarchical_nodes( parent_kind=ancestor.kind, parent_label="", parent_rel_name=parent_rel.name, + parent_rel_identifier=parent_rel.get_identifier(), parent_rel_cardinality=parent_rel.cardinality, parent_rel_label=parent_rel.label or "", ) @@ -150,6 +151,7 @@ async def _enrich_nodes_with_parent( parent_kind=peer_parent.peer_kind, parent_label="", parent_rel_name=parent_rel.name, + parent_rel_identifier=parent_rel.get_identifier(), parent_rel_cardinality=parent_rel.cardinality, parent_rel_label=parent_rel.label or "", ) diff --git a/backend/infrahub/core/diff/merger/merger.py b/backend/infrahub/core/diff/merger/merger.py index 02afb698da..fc858cd744 100644 --- a/backend/infrahub/core/diff/merger/merger.py +++ b/backend/infrahub/core/diff/merger/merger.py @@ -31,7 +31,7 @@ def __init__( self.serializer = serializer async def merge_graph(self, at: Timestamp) -> None: - enriched_diffs = await self.diff_repository.get_empty_roots( + enriched_diffs = await self.diff_repository.get_roots_metadata( diff_branch_names=[self.source_branch.name], base_branch_names=[self.destination_branch.name] ) latest_diff = None diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index 8cb5e5cc50..eb83a10c83 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -1,8 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass, field, replace +from dataclasses import asdict, dataclass, field, replace from enum import Enum from typing import TYPE_CHECKING, Any, Optional +from uuid import uuid4 from infrahub.core.constants import ( BranchSupportType, @@ -239,6 +240,7 @@ def from_calculated_element(cls, calculated_element: DiffSingleRelationship) -> @dataclass class EnrichedDiffRelationship(BaseSummary): name: str + identifier: str label: str cardinality: RelationshipCardinality path_identifier: str = field(default="", kw_only=True) @@ -270,6 +272,7 @@ def include_in_response(self) -> bool: def from_calculated_relationship(cls, calculated_relationship: DiffRelationship) -> EnrichedDiffRelationship: return EnrichedDiffRelationship( name=calculated_relationship.name, + identifier=calculated_relationship.identifier, label="", cardinality=calculated_relationship.cardinality, changed_at=calculated_relationship.changed_at, @@ -403,7 +406,7 @@ def from_calculated_node(cls, calculated_node: DiffNode) -> EnrichedDiffNode: @dataclass -class EnrichedDiffRoot(BaseSummary): +class EnrichedDiffRootMetadata(BaseSummary): base_branch_name: str diff_branch_name: str from_time: Timestamp @@ -411,6 +414,35 @@ class EnrichedDiffRoot(BaseSummary): uuid: str partner_uuid: str tracking_id: TrackingId | None = field(default=None, kw_only=True) + + def __hash__(self) -> int: + return hash(self.uuid) + + @property + def time_range(self) -> Interval: + return self.to_time.obj - self.from_time.obj + + def update_metadata( + self, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + tracking_id: TrackingId | None = None, + ) -> bool: + is_changed = False + if from_time and self.from_time != from_time: + self.from_time = from_time + is_changed = True + if to_time and self.to_time != to_time: + self.to_time = to_time + is_changed = True + if self.tracking_id != tracking_id: + self.tracking_id = tracking_id + is_changed = True + return is_changed + + +@dataclass +class EnrichedDiffRoot(EnrichedDiffRootMetadata): nodes: set[EnrichedDiffNode] = field(default_factory=set) def __hash__(self) -> int: @@ -446,6 +478,10 @@ def get_all_conflicts(self) -> dict[str, EnrichedDiffConflict]: all_conflicts.update(node.get_all_conflicts()) return all_conflicts + @classmethod + def from_root_metadata(cls, empty_root: EnrichedDiffRootMetadata) -> EnrichedDiffRoot: + return EnrichedDiffRoot(**asdict(empty_root)) + @classmethod def from_calculated_diff( cls, calculated_diff: DiffRoot, base_branch_name: str, partner_uuid: str @@ -467,6 +503,7 @@ def add_parent( parent_kind: str, parent_label: str, parent_rel_name: str, + parent_rel_identifier: str, parent_rel_cardinality: RelationshipCardinality, parent_rel_label: str = "", ) -> EnrichedDiffNode: @@ -491,6 +528,7 @@ def add_parent( node.relationships.add( EnrichedDiffRelationship( name=parent_rel_name, + identifier=parent_rel_identifier, label=parent_rel_label, cardinality=parent_rel_cardinality, changed_at=None, @@ -503,9 +541,48 @@ def add_parent( @dataclass -class EnrichedDiffs: +class EnrichedDiffsMetadata: base_branch_name: str diff_branch_name: str + base_branch_diff: EnrichedDiffRootMetadata + diff_branch_diff: EnrichedDiffRootMetadata + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"branch_uuid={self.diff_branch_diff.uuid}," + f"base_uuid={self.base_branch_diff.uuid}," + f"branch_name={self.diff_branch_name}," + f"base_name={self.base_branch_name}," + f"from_time={self.diff_branch_diff.from_time}," + f"to_time={self.diff_branch_diff.to_time})" + ) + + def update_metadata( + self, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + tracking_id: TrackingId | None = None, + ) -> bool: + is_changed = self.base_branch_diff.update_metadata( + from_time=from_time, to_time=to_time, tracking_id=tracking_id + ) + is_changed |= self.diff_branch_diff.update_metadata( + from_time=from_time, to_time=to_time, tracking_id=tracking_id + ) + return is_changed + + def set_fresh_uuids(self) -> None: + base_uuid = str(uuid4()) + branch_uuid = str(uuid4()) + self.base_branch_diff.uuid = base_uuid + self.base_branch_diff.partner_uuid = branch_uuid + self.diff_branch_diff.uuid = branch_uuid + self.diff_branch_diff.partner_uuid = base_uuid + + +@dataclass +class EnrichedDiffs(EnrichedDiffsMetadata): base_branch_diff: EnrichedDiffRoot diff_branch_diff: EnrichedDiffRoot @@ -539,6 +616,10 @@ def from_calculated_diffs(cls, calculated_diffs: CalculatedDiffs) -> EnrichedDif diff_branch_diff=diff_branch_diff, ) + @property + def is_empty(self) -> bool: + return len(self.base_branch_diff.nodes) == 0 and len(self.diff_branch_diff.nodes) == 0 + @dataclass class CalculatedDiffs: @@ -577,6 +658,7 @@ class DiffSingleRelationship: @dataclass class DiffRelationship: name: str + identifier: str cardinality: RelationshipCardinality changed_at: Timestamp action: DiffAction diff --git a/backend/infrahub/core/diff/query/empty_roots.py b/backend/infrahub/core/diff/query/empty_roots.py deleted file mode 100644 index 64ccbfa38a..0000000000 --- a/backend/infrahub/core/diff/query/empty_roots.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Generator - -from neo4j.graph import Node as Neo4jNode - -from infrahub.core.query import Query, QueryType -from infrahub.database import InfrahubDatabase - - -class EnrichedDiffEmptyRootsQuery(Query): - name = "enriched_diff_empty_roots" - type = QueryType.READ - - def __init__( - self, diff_branch_names: list[str] | None = None, base_branch_names: list[str] | None = None, **kwargs: Any - ) -> None: - super().__init__(**kwargs) - self.diff_branch_names = diff_branch_names - self.base_branch_names = base_branch_names - - async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: - self.params = {"diff_branch_names": self.diff_branch_names, "base_branch_names": self.base_branch_names} - - query = """ - MATCH (diff_root:DiffRoot) - WHERE ($diff_branch_names IS NULL OR diff_root.diff_branch IN $diff_branch_names) - AND ($base_branch_names IS NULL OR diff_root.base_branch IN $base_branch_names) - """ - self.return_labels = ["diff_root"] - self.add_to_query(query=query) - - def get_empty_root_nodes(self) -> Generator[Neo4jNode, None, None]: - for result in self.get_results(): - yield result.get_node("diff_root") diff --git a/backend/infrahub/core/diff/query/field_specifiers.py b/backend/infrahub/core/diff/query/field_specifiers.py new file mode 100644 index 0000000000..2325d6d58f --- /dev/null +++ b/backend/infrahub/core/diff/query/field_specifiers.py @@ -0,0 +1,35 @@ +from typing import Any, Generator + +from infrahub.core.query import Query, QueryType +from infrahub.database import InfrahubDatabase + + +class EnrichedDiffFieldSpecifiersQuery(Query): + name = "enriched_diff_field_specifiers" + type = QueryType.READ + + def __init__(self, diff_id: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.diff_id = diff_id + + async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: + self.params["diff_id"] = self.diff_id + query = """ +CALL { + MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(attr:DiffAttribute) + RETURN node.uuid AS node_uuid, attr.name AS field_name + UNION + MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(rel:DiffRelationship) + RETURN node.uuid AS node_uuid, rel.identifier AS field_name +} + """ + self.add_to_query(query=query) + self.return_labels = ["node_uuid", "field_name"] + self.order_by = ["node_uuid", "field_name"] + + def get_node_field_specifier_tuples(self) -> Generator[tuple[str, str], None, None]: + for result in self.get_results(): + node_uuid = result.get_as_str("node_uuid") + field_name = result.get_as_str("field_name") + if node_uuid and field_name: + yield (node_uuid, field_name) diff --git a/backend/infrahub/core/diff/query/roots_metadata.py b/backend/infrahub/core/diff/query/roots_metadata.py new file mode 100644 index 0000000000..0d0d321af7 --- /dev/null +++ b/backend/infrahub/core/diff/query/roots_metadata.py @@ -0,0 +1,48 @@ +from typing import Any, Generator + +from neo4j.graph import Node as Neo4jNode + +from infrahub.core.query import Query, QueryType +from infrahub.core.timestamp import Timestamp +from infrahub.database import InfrahubDatabase + + +class EnrichedDiffRootsMetadataQuery(Query): + name = "enriched_diff_roots_metadata" + type = QueryType.READ + + def __init__( + self, + diff_branch_names: list[str] | None = None, + base_branch_names: list[str] | None = None, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.diff_branch_names = diff_branch_names + self.base_branch_names = base_branch_names + self.from_time = from_time + self.to_time = to_time + + async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: + self.params = { + "diff_branch_names": self.diff_branch_names, + "base_branch_names": self.base_branch_names, + "from_time": self.from_time.to_string() if self.from_time else None, + "to_time": self.to_time.to_string() if self.to_time else None, + } + + query = """ + MATCH (diff_root:DiffRoot) + WHERE ($diff_branch_names IS NULL OR diff_root.diff_branch IN $diff_branch_names) + AND ($base_branch_names IS NULL OR diff_root.base_branch IN $base_branch_names) + AND ($from_time IS NULL OR diff_root.from_time >= $from_time) + AND ($to_time IS NULL OR diff_root.to_time <= $to_time) + """ + self.return_labels = ["diff_root"] + self.add_to_query(query=query) + + def get_root_nodes_metadata(self) -> Generator[Neo4jNode, None, None]: + for result in self.get_results(): + yield result.get_node("diff_root") diff --git a/backend/infrahub/core/diff/query/save.py b/backend/infrahub/core/diff/query/save.py index ab09736faf..d5bd01a7ed 100644 --- a/backend/infrahub/core/diff/query/save.py +++ b/backend/infrahub/core/diff/query/save.py @@ -261,6 +261,7 @@ def _build_diff_relationship_params(self, enriched_relationship: EnrichedDiffRel return { "node_properties": { "name": enriched_relationship.name, + "identifier": enriched_relationship.identifier, "label": enriched_relationship.label, "cardinality": enriched_relationship.cardinality.value, "changed_at": enriched_relationship.changed_at.to_string() diff --git a/backend/infrahub/core/diff/query_parser.py b/backend/infrahub/core/diff/query_parser.py index afb18716e7..74122b77d4 100644 --- a/backend/infrahub/core/diff/query_parser.py +++ b/backend/infrahub/core/diff/query_parser.py @@ -374,6 +374,7 @@ def to_diff_relationship(self, include_unchanged: bool) -> DiffRelationship: action = actions.pop() return DiffRelationship( name=self.name, + identifier=self.identifier, changed_at=last_changed_at, action=action, relationships=single_relationships, diff --git a/backend/infrahub/core/diff/repository/deserializer.py b/backend/infrahub/core/diff/repository/deserializer.py index 41ccc72f88..246375dec4 100644 --- a/backend/infrahub/core/diff/repository/deserializer.py +++ b/backend/infrahub/core/diff/repository/deserializer.py @@ -16,6 +16,7 @@ EnrichedDiffProperty, EnrichedDiffRelationship, EnrichedDiffRoot, + EnrichedDiffRootMetadata, EnrichedDiffSingleRelationship, deserialize_tracking_id, ) @@ -135,6 +136,7 @@ def _deserialize_parents(self, result: QueryResult, enriched_root: EnrichedDiffR parent_kind=parent.get("kind"), parent_label=parent.get("label"), parent_rel_name=rel.get("name"), + parent_rel_identifier=rel.get("identifier"), parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")), parent_rel_label=rel.get("label"), ) @@ -149,19 +151,20 @@ def _deserialize_diff_root(self, root_node: Neo4jNode) -> EnrichedDiffRoot: root_uuid = str(root_node.get("uuid")) if root_uuid in self._diff_root_map: return self._diff_root_map[root_uuid] - enriched_root = self.build_diff_root(root_node=root_node) + root_empty = self.build_diff_root_metadata(root_node=root_node) + enriched_root = EnrichedDiffRoot.from_root_metadata(empty_root=root_empty) self._diff_root_map[root_uuid] = enriched_root return enriched_root @classmethod - def build_diff_root(cls, root_node: Neo4jNode) -> EnrichedDiffRoot: + def build_diff_root_metadata(cls, root_node: Neo4jNode) -> EnrichedDiffRootMetadata: from_time = Timestamp(str(root_node.get("from_time"))) to_time = Timestamp(str(root_node.get("to_time"))) tracking_id_str = cls._get_str_or_none_property_value(node=root_node, property_name="tracking_id") tracking_id = None if tracking_id_str: tracking_id = deserialize_tracking_id(tracking_id_str=tracking_id_str) - return EnrichedDiffRoot( + return EnrichedDiffRootMetadata( base_branch_name=str(root_node.get("base_branch")), diff_branch_name=str(root_node.get("diff_branch")), from_time=from_time, @@ -234,6 +237,7 @@ def _deserialize_diff_relationship_group( timestamp_str = relationship_group_node.get("changed_at") enriched_relationship = EnrichedDiffRelationship( name=relationship_group_node.get("name"), + identifier=relationship_group_node.get("identifier"), label=relationship_group_node.get("label"), cardinality=RelationshipCardinality(relationship_group_node.get("cardinality")), changed_at=Timestamp(timestamp_str) if timestamp_str else None, diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index dd738b6763..82b751cbf0 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -3,17 +3,22 @@ from infrahub import config from infrahub.core import registry from infrahub.core.diff.query.field_summary import EnrichedDiffNodeFieldSummaryQuery +from infrahub.core.query.diff import DiffCountChanges from infrahub.core.timestamp import Timestamp from infrahub.database import InfrahubDatabase, retry_db_transaction from infrahub.exceptions import ResourceNotFoundError +from infrahub.log import get_logger from ..model.path import ( ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot, + EnrichedDiffRootMetadata, EnrichedDiffs, + EnrichedDiffsMetadata, EnrichedNodeCreateRequest, NodeDiffFieldSummary, + NodeFieldSpecifier, TimeRange, TrackingId, ) @@ -21,15 +26,18 @@ from ..query.diff_get import EnrichedDiffGetQuery from ..query.diff_summary import DiffSummaryCounters, DiffSummaryQuery from ..query.drop_tracking_id import EnrichedDiffDropTrackingIdQuery -from ..query.empty_roots import EnrichedDiffEmptyRootsQuery +from ..query.field_specifiers import EnrichedDiffFieldSpecifiersQuery from ..query.filters import EnrichedDiffQueryFilters from ..query.get_conflict_query import EnrichedDiffConflictQuery from ..query.has_conflicts_query import EnrichedDiffHasConflictQuery +from ..query.roots_metadata import EnrichedDiffRootsMetadataQuery from ..query.save import EnrichedDiffRootsCreateQuery, EnrichedNodeBatchCreateQuery, EnrichedNodesLinkQuery from ..query.time_range_query import EnrichedDiffTimeRangeQuery from ..query.update_conflict_query import EnrichedDiffConflictUpdateQuery from .deserializer import EnrichedDiffDeserializer +log = get_logger() + class DiffRepository: MAX_SAVE_BATCH_SIZE: int = 100 @@ -118,6 +126,22 @@ async def get_pairs( for dbr in diff_branch_roots ] + async def hydrate_diff_pair(self, enriched_diffs_metadata: EnrichedDiffsMetadata) -> EnrichedDiffs: + hydrated_base_diff = await self.get_one( + diff_branch_name=enriched_diffs_metadata.base_branch_name, + diff_id=enriched_diffs_metadata.base_branch_diff.uuid, + ) + hydrated_branch_diff = await self.get_one( + diff_branch_name=enriched_diffs_metadata.diff_branch_name, + diff_id=enriched_diffs_metadata.diff_branch_diff.uuid, + ) + return EnrichedDiffs( + base_branch_name=enriched_diffs_metadata.base_branch_name, + diff_branch_name=enriched_diffs_metadata.diff_branch_name, + base_branch_diff=hydrated_base_diff, + diff_branch_diff=hydrated_branch_diff, + ) + async def get_one( self, diff_branch_name: str, @@ -161,6 +185,7 @@ def _get_node_create_request_batch( @retry_db_transaction(name="enriched_diff_save") async def save(self, enriched_diffs: EnrichedDiffs) -> None: + log.info("Saving diff...") root_query = await EnrichedDiffRootsCreateQuery.init(db=self.db, enriched_diffs=enriched_diffs) await root_query.execute(db=self.db) for node_create_batch in self._get_node_create_request_batch(enriched_diffs=enriched_diffs): @@ -168,6 +193,7 @@ async def save(self, enriched_diffs: EnrichedDiffs) -> None: await node_query.execute(db=self.db) link_query = await EnrichedNodesLinkQuery.init(db=self.db, enriched_diffs=enriched_diffs) await link_query.execute(db=self.db) + log.info("Diff saved.") async def summary( self, @@ -211,18 +237,55 @@ async def get_time_ranges( await query.execute(db=self.db) return await query.get_time_ranges() - async def get_empty_roots( + async def get_diff_pairs_metadata( self, diff_branch_names: list[str] | None = None, base_branch_names: list[str] | None = None, - ) -> list[EnrichedDiffRoot]: - query = await EnrichedDiffEmptyRootsQuery.init( - db=self.db, diff_branch_names=diff_branch_names, base_branch_names=base_branch_names + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + ) -> list[EnrichedDiffsMetadata]: + if diff_branch_names and base_branch_names: + diff_branch_names += base_branch_names + empty_roots = await self.get_roots_metadata( + diff_branch_names=diff_branch_names, + base_branch_names=base_branch_names, + from_time=from_time, + to_time=to_time, + ) + roots_by_id = {root.uuid: root for root in empty_roots} + pairs: list[EnrichedDiffsMetadata] = [] + for branch_root in empty_roots: + if branch_root.base_branch_name == branch_root.diff_branch_name: + continue + base_root = roots_by_id[branch_root.partner_uuid] + pairs.append( + EnrichedDiffsMetadata( + base_branch_name=branch_root.base_branch_name, + diff_branch_name=branch_root.diff_branch_name, + base_branch_diff=base_root, + diff_branch_diff=branch_root, + ) + ) + return pairs + + async def get_roots_metadata( + self, + diff_branch_names: list[str] | None = None, + base_branch_names: list[str] | None = None, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + ) -> list[EnrichedDiffRootMetadata]: + query = await EnrichedDiffRootsMetadataQuery.init( + db=self.db, + diff_branch_names=diff_branch_names, + base_branch_names=base_branch_names, + from_time=from_time, + to_time=to_time, ) await query.execute(db=self.db) diff_roots = [] - for neo4j_node in query.get_empty_root_nodes(): - diff_roots.append(self.deserializer.build_diff_root(root_node=neo4j_node)) + for neo4j_node in query.get_root_nodes_metadata(): + diff_roots.append(self.deserializer.build_diff_root_metadata(root_node=neo4j_node)) return diff_roots async def diff_has_conflicts( @@ -267,3 +330,31 @@ async def get_node_field_summaries( async def drop_tracking_ids(self, tracking_ids: list[TrackingId]) -> None: query = await EnrichedDiffDropTrackingIdQuery.init(db=self.db, tracking_ids=tracking_ids) await query.execute(db=self.db) + + async def get_num_changes_in_time_range_by_branch( + self, branch_names: list[str], from_time: Timestamp, to_time: Timestamp + ) -> dict[str, int]: + query = await DiffCountChanges.init(db=self.db, branch_names=branch_names, diff_from=from_time, diff_to=to_time) + await query.execute(db=self.db) + return query.get_num_changes_by_branch() + + async def get_node_field_specifiers(self, diff_id: str) -> set[NodeFieldSpecifier]: + limit = 5000 + offset = 0 + specifiers: set[NodeFieldSpecifier] = set() + while True: + query = await EnrichedDiffFieldSpecifiersQuery.init(db=self.db, diff_id=diff_id, offset=offset, limit=limit) + await query.execute(db=self.db) + + new_specifiers = { + NodeFieldSpecifier( + node_uuid=field_specifier_tuple[0], + field_name=field_specifier_tuple[1], + ) + for field_specifier_tuple in query.get_node_field_specifier_tuples() + } + if not new_specifiers: + break + specifiers |= new_specifiers + offset += limit + return specifiers diff --git a/backend/infrahub/core/diff/tasks.py b/backend/infrahub/core/diff/tasks.py index 7dba986e7c..9301241837 100644 --- a/backend/infrahub/core/diff/tasks.py +++ b/backend/infrahub/core/diff/tasks.py @@ -57,7 +57,7 @@ async def refresh_diff_all(branch_name: str) -> None: component_registry = get_component_registry() default_branch = registry.get_branch_from_registry() diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) - diff_roots_to_refresh = await diff_repository.get_empty_roots(diff_branch_names=[branch_name]) + diff_roots_to_refresh = await diff_repository.get_roots_metadata(diff_branch_names=[branch_name]) for diff_root in diff_roots_to_refresh: if diff_root.base_branch_name != diff_root.diff_branch_name: diff --git a/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py b/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py index eaf3d71eb0..889854825b 100644 --- a/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py +++ b/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py @@ -94,13 +94,18 @@ async def record_conflicts(self, proposed_change_id: str, conflicts: Sequence[Ob await self.finalize_validator(validator, is_success) return current_checks - async def get_or_create_validator(self, proposed_change: CoreProposedChange) -> Node: + async def get_validator(self, proposed_change: CoreProposedChange) -> Node | None: validations = await proposed_change.validations.get_peers(db=self.db, branch_agnostic=True) for validation in validations.values(): if validation.get_kind() == self.validator_kind: return validation + return None + async def get_or_create_validator(self, proposed_change: CoreProposedChange) -> Node: + validator_obj = await self.get_validator(proposed_change=proposed_change) + if validator_obj: + return validator_obj validator_obj = await Node.init(db=self.db, schema=self.validator_kind) await validator_obj.new( db=self.db, diff --git a/backend/infrahub/core/merge.py b/backend/infrahub/core/merge.py index 5fe3c1c8e6..32232e277a 100644 --- a/backend/infrahub/core/merge.py +++ b/backend/infrahub/core/merge.py @@ -174,7 +174,7 @@ async def merge( if self.source_branch.name == registry.default_branch: raise ValidationError(f"Unable to merge the branch '{self.source_branch.name}' into itself") - enriched_diff = await self.diff_coordinator.update_branch_diff( + enriched_diff = await self.diff_coordinator.update_branch_diff_and_return( base_branch=self.destination_branch, diff_branch=self.source_branch ) conflict_map = enriched_diff.get_all_conflicts() diff --git a/backend/infrahub/core/migrations/graph/m015_diff_format_update.py b/backend/infrahub/core/migrations/graph/m015_diff_format_update.py index 2fa2e772ea..ba8d3d1b2e 100644 --- a/backend/infrahub/core/migrations/graph/m015_diff_format_update.py +++ b/backend/infrahub/core/migrations/graph/m015_diff_format_update.py @@ -31,6 +31,6 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult: component_registry = get_component_registry() diff_repo = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) - diff_roots = await diff_repo.get_empty_roots() + diff_roots = await diff_repo.get_roots_metadata() await diff_repo.delete_diff_roots(diff_root_uuids=[d.uuid for d in diff_roots]) return MigrationResult() diff --git a/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py b/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py index 57a1bd0550..484ca214e8 100644 --- a/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py +++ b/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py @@ -31,6 +31,6 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult: component_registry = get_component_registry() diff_repo = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) - diff_roots = await diff_repo.get_empty_roots() + diff_roots = await diff_repo.get_roots_metadata() await diff_repo.delete_diff_roots(diff_root_uuids=[d.uuid for d in diff_roots]) return MigrationResult() diff --git a/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py b/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py index 3250724e0a..427e5e22a8 100644 --- a/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py +++ b/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py @@ -3,6 +3,7 @@ from .conflicts_extractor import DiffConflictsExtractorDependency from .data_check_conflict_recorder import DataCheckConflictRecorderDependency +from .repository import DiffRepositoryDependency class DiffDataCheckSynchronizerDependency(DependencyBuilder[DiffDataCheckSynchronizer]): @@ -12,4 +13,5 @@ def build(cls, context: DependencyBuilderContext) -> DiffDataCheckSynchronizer: db=context.db, conflicts_extractor=DiffConflictsExtractorDependency.build(context=context), conflict_recorder=DataCheckConflictRecorderDependency.build(context=context), + diff_repository=DiffRepositoryDependency.build(context=context), ) diff --git a/backend/infrahub/graphql/mutations/tasks.py b/backend/infrahub/graphql/mutations/tasks.py index 3eb1c8a187..b3d571fc76 100644 --- a/backend/infrahub/graphql/mutations/tasks.py +++ b/backend/infrahub/graphql/mutations/tasks.py @@ -31,7 +31,7 @@ async def merge_branch_mutation(branch: str) -> None: diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj) diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj) diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj) + enriched_diff = await diff_coordinator.update_branch_diff_and_return(base_branch=base_branch, diff_branch=obj) if enriched_diff.get_all_conflicts(): raise ValidationError( f"Branch {obj.name} contains conflicts with the default branch." diff --git a/backend/infrahub/message_bus/operations/event/branch.py b/backend/infrahub/message_bus/operations/event/branch.py index 32c5d63657..a501da69a7 100644 --- a/backend/infrahub/message_bus/operations/event/branch.py +++ b/backend/infrahub/message_bus/operations/event/branch.py @@ -29,7 +29,7 @@ async def merge(message: messages.EventBranchMerge, service: InfrahubServices) - default_branch = registry.get_branch_from_registry() diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) # send diff update requests for every branch-tracking diff - branch_diff_roots = await diff_repository.get_empty_roots(base_branch_names=[message.target_branch]) + branch_diff_roots = await diff_repository.get_roots_metadata(base_branch_names=[message.target_branch]) await service.workflow.submit_workflow( workflow=TRIGGER_ARTIFACT_DEFINITION_GENERATE, diff --git a/backend/tests/integration/diff/test_diff_incremental_addition.py b/backend/tests/integration/diff/test_diff_incremental_addition.py index 3263a25282..71abada137 100644 --- a/backend/tests/integration/diff/test_diff_incremental_addition.py +++ b/backend/tests/integration/diff/test_diff_incremental_addition.py @@ -191,7 +191,9 @@ async def test_remove_on_main( diff_coordinator: DiffCoordinator, data_01_remove_on_main, ) -> None: - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) assert len(enriched_diff.nodes) == 0 @@ -263,7 +265,7 @@ async def test_update_previous_owner_on_branch( initial_dataset, data_02_previous_owner_on_branch, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_02(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) @@ -271,7 +273,7 @@ async def test_update_previous_owner_on_branch( base_branch=default_branch, diff_branch=diff_branch, from_time=incremental_diff.from_time, - to_time=incremental_diff.to_time, + to_time=Timestamp(), name=str(uuid4()), ) await self.validate_diff_data_02(db=db, enriched_diff=full_diff, initial_dataset=initial_dataset) @@ -344,7 +346,7 @@ async def test_add_new_peer_on_main( diff_coordinator: DiffCoordinator, data_03_new_peer_on_main, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_03( @@ -426,7 +428,7 @@ async def test_update_previous_owner_protected_on_branch( diff_coordinator: DiffCoordinator, data_04_update_previous_owner_protected_on_branch, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_04(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) @@ -500,7 +502,7 @@ async def test_remove_previous_owner_on_branch( diff_coordinator: DiffCoordinator, data_05_remove_previous_owner_on_branch, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_05(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) @@ -570,7 +572,7 @@ async def test_remove_previous_owner_on_main_again( diff_coordinator: DiffCoordinator, data_06_remove_previous_owner_on_main_again, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_06(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) diff --git a/backend/tests/integration/diff/test_diff_merge.py b/backend/tests/integration/diff/test_diff_merge.py index 0bdde6683f..089a7bc19c 100644 --- a/backend/tests/integration/diff/test_diff_merge.py +++ b/backend/tests/integration/diff/test_diff_merge.py @@ -117,7 +117,9 @@ async def test_select_cardinality_one_resolution_and_merge( delorean_id = initial_dataset["delorean"].get_id() marty_id = initial_dataset["marty"].get_id() - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 owner_conflict = list(conflicts_map.values())[0] @@ -169,7 +171,9 @@ async def test_node_delete_conflict( await new_car.save(db=db) # check that the expected node-level conflict exists - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) conflicts_map = enriched_diff.get_all_conflicts() assert set(conflicts_map.keys()) == {f"data/{person_updated.id}"} conflict = conflicts_map[f"data/{person_updated.id}"] @@ -190,7 +194,9 @@ async def test_node_delete_conflict( await car_branch.save(db=db) # check that the conflict is gone - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 0 diff --git a/backend/tests/integration_docker/test_propose_change_repository.py b/backend/tests/integration_docker/test_propose_change_repository.py index 36984e2c20..96d88eddbf 100644 --- a/backend/tests/integration_docker/test_propose_change_repository.py +++ b/backend/tests/integration_docker/test_propose_change_repository.py @@ -84,5 +84,3 @@ async def test_create_propose_change(self, client: InfrahubClient, default_branc kind=CoreProposedChange, name="pc1", source_branch=branch.name, destination_branch=default_branch ) await pc.save() - - # breakpoint() diff --git a/backend/tests/unit/core/diff/query/test_read.py b/backend/tests/unit/core/diff/query/test_read.py index cd25b17e09..f0c549c069 100644 --- a/backend/tests/unit/core/diff/query/test_read.py +++ b/backend/tests/unit/core/diff/query/test_read.py @@ -163,7 +163,7 @@ async def load_data(self, db: InfrahubDatabase, default_branch: Branch, hierarch diff_coordinator.data_check_synchronizer = AsyncMock(spec=DiffDataCheckSynchronizer) diff_coordinator.data_check_synchronizer.synchronize.return_value = [] - enriched_diff = await diff_coordinator.update_branch_diff( + enriched_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch, ) diff --git a/backend/tests/unit/core/diff/test_coordinator.py b/backend/tests/unit/core/diff/test_coordinator.py index a1ed15c255..7bee68ca06 100644 --- a/backend/tests/unit/core/diff/test_coordinator.py +++ b/backend/tests/unit/core/diff/test_coordinator.py @@ -1,8 +1,13 @@ +from typing import Any +from unittest.mock import AsyncMock, call + from infrahub.core.branch import Branch from infrahub.core.constants import DiffAction from infrahub.core.constants.database import DatabaseEdgeType +from infrahub.core.diff.calculator import DiffCalculator +from infrahub.core.diff.combiner import DiffCombiner from infrahub.core.diff.coordinator import DiffCoordinator -from infrahub.core.diff.model.path import BranchTrackingId +from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRootMetadata, NodeFieldSpecifier from infrahub.core.diff.repository.repository import DiffRepository from infrahub.core.initialization import create_branch from infrahub.core.manager import NodeManager @@ -13,6 +18,27 @@ class TestDiffCoordinator: + async def get_wrapped_diff_coordinator( + self, + db: InfrahubDatabase, + branch: Branch, + ) -> DiffCoordinator: + component_registry = get_component_registry() + diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch) + real_repository = await component_registry.get_component(DiffRepository, db=db, branch=branch) + real_calculator = await component_registry.get_component(DiffCalculator, db=db, branch=branch) + real_combiner = await component_registry.get_component(DiffCombiner, db=db, branch=branch) + diff_coordinator.diff_repo = AsyncMock(wraps=real_repository) + diff_coordinator.diff_calculator = AsyncMock(wraps=real_calculator) + diff_coordinator.diff_combiner = AsyncMock(wraps=real_combiner) + return diff_coordinator + + def reset_mocks(self, reset_it: Any) -> None: + for attr_name in dir(reset_it): + attr = getattr(reset_it, attr_name) + if isinstance(attr, AsyncMock): + attr.reset_mock() + async def test_node_deleted_after_branching( self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node ): @@ -24,7 +50,7 @@ async def test_node_deleted_after_branching( component_registry = get_component_registry() diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch) - diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch) + diff = await diff_coordinator.update_branch_diff_and_return(base_branch=default_branch, diff_branch=branch) assert diff.base_branch_name == default_branch.name assert diff.diff_branch_name == branch.name @@ -109,3 +135,97 @@ async def test_overlapping_diffs(self, db: InfrahubDatabase, default_branch: Bra assert diff_property.action is DiffAction.UPDATED assert diff_property.previous_value == str(original_height) assert diff_property.new_value == "3" + + async def test_no_changes_skips_expensive_operations( + self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node + ): + branch = await create_branch(db=db, branch_name="branch") + wrapped_diff_coordinator = await self.get_wrapped_diff_coordinator(db=db, branch=branch) + + time1 = Timestamp() + person_john_branch = await NodeManager.get_one(db=db, branch=branch, id=person_john_main.id) + person_john_branch.height.value += 1 + await person_john_branch.save(db=db) + time2 = Timestamp() + + # calculate this diff in the middle of change timeframe + diff_with_data = await wrapped_diff_coordinator.create_or_update_arbitrary_timeframe_diff( + base_branch=default_branch, diff_branch=branch, from_time=time1, to_time=time2 + ) + self.reset_mocks(wrapped_diff_coordinator) + + # get the whole diff with no-change time periods before and after the calculated diff + no_changes_diff = await wrapped_diff_coordinator.update_branch_diff( + base_branch=default_branch, diff_branch=branch + ) + assert type(no_changes_diff) is EnrichedDiffRootMetadata + assert no_changes_diff.uuid == diff_with_data.uuid + assert no_changes_diff.from_time == Timestamp(branch.get_branched_from()) + assert no_changes_diff.from_time < diff_with_data.from_time + assert no_changes_diff.to_time > diff_with_data.to_time + wrapped_diff_coordinator.diff_calculator.calculate_diff.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.get_one.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.save.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.hydrate_diff_pair.assert_not_awaited() + + async def test_unrelated_changes_skip_some_expensive_operations( + self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node + ): + branch = await create_branch(db=db, branch_name="branch") + wrapped_diff_coordinator = await self.get_wrapped_diff_coordinator(db=db, branch=branch) + + # unrelated change on main before + john_main = await NodeManager.get_one(db=db, id=person_john_main.id) + john_main.name.value = "Before John" + await john_main.save(db=db) + + # change on branch for the diff + time1 = Timestamp() + person_john_branch = await NodeManager.get_one(db=db, branch=branch, id=person_john_main.id) + person_john_branch.height.value += 1 + await person_john_branch.save(db=db) + time2 = Timestamp() + + # unrelated change on main after + john_main = await NodeManager.get_one(db=db, id=person_john_main.id) + john_main.name.value = "After John" + await john_main.save(db=db) + + # calculate this diff in the middle of change timeframe + diff_with_data = await wrapped_diff_coordinator.create_or_update_arbitrary_timeframe_diff( + base_branch=default_branch, diff_branch=branch, from_time=time1, to_time=time2 + ) + self.reset_mocks(wrapped_diff_coordinator) + + # get the whole diff with no-change time periods before and after the calculated diff + no_changes_diff = await wrapped_diff_coordinator.update_branch_diff( + base_branch=default_branch, diff_branch=branch + ) + assert type(no_changes_diff) is EnrichedDiffRootMetadata + assert no_changes_diff.uuid == diff_with_data.uuid + assert no_changes_diff.from_time == Timestamp(branch.get_branched_from()) + assert no_changes_diff.from_time < diff_with_data.from_time + assert no_changes_diff.to_time > diff_with_data.to_time + wrapped_diff_coordinator.diff_calculator.calculate_diff.assert_has_awaits( + [ + call( + base_branch=default_branch, + diff_branch=branch, + from_time=Timestamp(branch.get_branched_from()), + to_time=diff_with_data.from_time, + include_unchanged=False, + previous_node_specifiers=set(), + ), + call( + base_branch=default_branch, + diff_branch=branch, + from_time=diff_with_data.to_time, + to_time=no_changes_diff.to_time, + include_unchanged=True, + previous_node_specifiers={NodeFieldSpecifier(node_uuid=person_john_branch.id, field_name="height")}, + ), + ] + ) + wrapped_diff_coordinator.diff_repo.get_one.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.save.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.hydrate_diff_pair.assert_not_awaited() diff --git a/backend/tests/unit/core/diff/test_coordinator_lock.py b/backend/tests/unit/core/diff/test_coordinator_lock.py index 21aeb4ad1c..74f748859c 100644 --- a/backend/tests/unit/core/diff/test_coordinator_lock.py +++ b/backend/tests/unit/core/diff/test_coordinator_lock.py @@ -1,4 +1,5 @@ import asyncio +from datetime import timedelta from unittest.mock import AsyncMock from uuid import uuid4 @@ -78,16 +79,16 @@ async def test_arbitrary_diff_locks_queue_up( ) assert len(results) == 2 assert results[0].to_time != results[1].to_time - assert results[0].uuid != results[1].uuid - assert results[0].partner_uuid != results[1].partner_uuid + assert results[0].uuid == results[1].uuid + assert results[0].partner_uuid == results[1].partner_uuid results[0].to_time = results[1].to_time - results[0].uuid = results[1].uuid - results[0].partner_uuid = results[1].partner_uuid assert results[0] == results[1] - # called once to calculate diff on main and once to calculate diff on the branch - assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 2 - # not called because diffs are calculated both times - diff_coordinator.diff_repo.get_one.assert_not_awaited() + # second diff uses first diff for its data and is not calculated + assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 1 + # confirm that we retrieve the first diff to use when calculating the second, overlapping diff + diff_coordinator.diff_repo.get_one.assert_called_once_with( + diff_branch_name=diff_branch.name, diff_id=results[0].uuid + ) async def test_arbitrary_diff_blocks_incremental_diff( self, db: InfrahubDatabase, default_branch: Branch, branch_with_data: Branch @@ -102,22 +103,22 @@ async def test_arbitrary_diff_blocks_incremental_diff( from_time=Timestamp(branch_with_data.branched_from), to_time=Timestamp(), ), - diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch), + diff_coordinator.update_branch_diff_and_return(base_branch=default_branch, diff_branch=diff_branch), ) assert len(results) == 2 assert results[0].to_time != results[1].to_time - assert results[0].uuid != results[1].uuid - assert results[0].partner_uuid != results[1].partner_uuid + assert results[0].uuid == results[1].uuid + assert results[0].partner_uuid == results[1].partner_uuid assert results[0].tracking_id != results[1].tracking_id results[0].to_time = results[1].to_time - results[0].uuid = results[1].uuid - results[0].partner_uuid = results[1].partner_uuid results[0].tracking_id = results[1].tracking_id assert results[0] == results[1] - # called once to calculate diff on main and once to calculate diff on the branch - assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 2 - # not called because diffs are calculated both times - diff_coordinator.diff_repo.get_one.assert_not_awaited() + # second diff uses first diff for its data and is not calculated + assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 1 + # confirm that we retrieve the first diff to use when calculating the second, overlapping diff + diff_coordinator.diff_repo.get_one.assert_called_once_with( + diff_branch_name=diff_branch.name, diff_id=results[0].uuid + ) async def test_incremental_diff_blocks_arbitrary_diff( self, db: InfrahubDatabase, default_branch: Branch, branch_with_data: Branch @@ -125,26 +126,28 @@ async def test_incremental_diff_blocks_arbitrary_diff( diff_branch = branch_with_data diff_coordinator = await self.get_diff_coordinator(db=db, diff_branch=diff_branch) + arbitrary_to_time = Timestamp() + arbitrary_to_time.obj += timedelta(seconds=5) results = await asyncio.gather( - diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch), + diff_coordinator.update_branch_diff_and_return(base_branch=default_branch, diff_branch=diff_branch), diff_coordinator.create_or_update_arbitrary_timeframe_diff( base_branch=default_branch, diff_branch=diff_branch, from_time=Timestamp(branch_with_data.branched_from), - to_time=Timestamp(), + to_time=arbitrary_to_time, ), ) assert len(results) == 2 assert results[0].to_time != results[1].to_time - assert results[0].uuid != results[1].uuid - assert results[0].partner_uuid != results[1].partner_uuid + assert results[0].uuid == results[1].uuid + assert results[0].partner_uuid == results[1].partner_uuid assert results[0].tracking_id != results[1].tracking_id results[0].to_time = results[1].to_time - results[0].uuid = results[1].uuid - results[0].partner_uuid = results[1].partner_uuid results[0].tracking_id = results[1].tracking_id assert results[0] == results[1] - # called once to calculate diff on main and once to calculate diff on the branch - assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 2 - # not called because diffs are calculated both times - diff_coordinator.diff_repo.get_one.assert_not_awaited() + # second diff uses first diff for its data and is not calculated + assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 1 + # confirm that we retrieve the first diff to use when calculating the second, overlapping diff + diff_coordinator.diff_repo.get_one.assert_called_once_with( + diff_branch_name=diff_branch.name, diff_id=results[0].uuid + ) diff --git a/backend/tests/unit/core/diff/test_diff_and_merge.py b/backend/tests/unit/core/diff/test_diff_and_merge.py index ec22fb7ca7..19f2f1febd 100644 --- a/backend/tests/unit/core/diff/test_diff_and_merge.py +++ b/backend/tests/unit/core/diff/test_diff_and_merge.py @@ -132,7 +132,9 @@ async def test_diff_and_merge_with_attribute_value_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 conflict = next(iter(conflicts_map.values())) @@ -174,7 +176,9 @@ async def test_diff_and_merge_with_relationship_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 conflict = next(iter(conflicts_map.values())) @@ -220,7 +224,9 @@ async def test_diff_and_merge_with_attribute_property_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 conflict = next(iter(conflicts_map.values())) @@ -268,7 +274,9 @@ async def test_diff_and_merge_with_relationship_property_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() # conflict on both sides of the relationship assert len(conflicts_map) == 2 @@ -302,7 +310,9 @@ async def test_single_attribute_update( await person_branch.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) node = enriched_diff.get_node(node_uuid=person_jane_main.id) assert node.action is DiffAction.UPDATED @@ -329,7 +339,9 @@ async def test_relationship_set_to_null(self, db: InfrahubDatabase, default_bran await dog_branch.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) dog_node = enriched_diff.get_node(node_uuid=dog_main.id) assert dog_node.action is DiffAction.UPDATED friend_node = enriched_diff.get_node(node_uuid=friend_main.id) @@ -357,7 +369,9 @@ async def test_local_and_aware_nodes_added_on_branch( await car.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) diff_person = enriched_diff.get_node(node_uuid=person.id) assert diff_person.action is DiffAction.ADDED # validate car is not in the diff @@ -422,7 +436,9 @@ async def test_agnostic_and_aware_nodes_added_on_branch( await car.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) diff_person = enriched_diff.get_node(node_uuid=person.id) assert diff_person.action is DiffAction.UPDATED diff_car = enriched_diff.get_node(node_uuid=car.id) @@ -560,7 +576,9 @@ async def test_branch_delete_with_added_base_relationship( await car_main.save(db=db) # check that the conflict is removed - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 0 diff --git a/backend/tests/unit/core/diff/test_diff_combiner.py b/backend/tests/unit/core/diff/test_diff_combiner.py index 9e2bd7d709..2d4088c168 100644 --- a/backend/tests/unit/core/diff/test_diff_combiner.py +++ b/backend/tests/unit/core/diff/test_diff_combiner.py @@ -450,6 +450,7 @@ async def test_relationship_one_combined(self, with_schema_manager): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=RelationshipCardinality.ONE, changed_at=later_relationship.changed_at, action=DiffAction.ADDED, @@ -618,6 +619,7 @@ async def test_relationship_many_combined(self, with_schema_manager): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=relationship_group_2.label, + identifier=relationship_group_2.identifier, cardinality=RelationshipCardinality.MANY, changed_at=relationship_group_2.changed_at, action=DiffAction.UPDATED, @@ -682,6 +684,7 @@ async def test_relationship_with_only_nodes(self, with_schema_manager): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=RelationshipCardinality.MANY, changed_at=later_relationship.changed_at, action=DiffAction.ADDED, @@ -851,6 +854,7 @@ async def test_unchanged_parents_correctly_updated(self): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=parent_rel_2.label, + identifier=parent_rel_2.identifier, changed_at=parent_rel_2.changed_at, cardinality=RelationshipCardinality.ONE, path_identifier=parent_rel_2.path_identifier, @@ -923,6 +927,7 @@ async def test_updated_parents_correctly_updated(self): expected_child_rel = EnrichedDiffRelationship( name=relationship_name, label=child_rel_2.label, + identifier=child_rel_2.identifier, changed_at=child_rel_2.changed_at, cardinality=RelationshipCardinality.ONE, path_identifier=child_rel_2.path_identifier, @@ -1083,6 +1088,7 @@ async def test_resetting_relationship_one_makes_it_unchanged(self, with_schema_m expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=RelationshipCardinality.ONE, changed_at=later_relationship.changed_at, action=DiffAction.UPDATED, diff --git a/backend/tests/unit/core/diff/test_diff_merger.py b/backend/tests/unit/core/diff/test_diff_merger.py index 89e7b534c8..5d7b060026 100644 --- a/backend/tests/unit/core/diff/test_diff_merger.py +++ b/backend/tests/unit/core/diff/test_diff_merger.py @@ -326,7 +326,7 @@ async def test_merge_node_added( check_idempotent: bool, ): empty_diff_root.nodes = {added_person_node_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() @@ -380,7 +380,7 @@ async def test_merge_node_deleted( person_branch = await NodeManager.get_one(db=db, branch=source_branch, id=person_node_main.id) await person_branch.delete(db=db) empty_diff_root.nodes = {deleted_person_node_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() @@ -433,7 +433,7 @@ async def test_merge_node_deleted_with_conflict( ) deleted_node_diff.conflict = node_conflict empty_diff_root.nodes = {deleted_node_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() @@ -654,7 +654,7 @@ async def test_merge_node_updated( await car_branch.save(db=db) empty_diff_root.nodes = {updated_person_node_diff, updated_car_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() diff --git a/backend/tests/unit/core/ipam/conftest.py b/backend/tests/unit/core/ipam/conftest.py index 9a2776b17d..4efceedb8c 100644 --- a/backend/tests/unit/core/ipam/conftest.py +++ b/backend/tests/unit/core/ipam/conftest.py @@ -191,7 +191,7 @@ async def ip_dataset_01( ): yield ip_dataset_01_load - all_diff_roots = await diff_repository.get_empty_roots() + all_diff_roots = await diff_repository.get_roots_metadata() root_uuids_to_delete = [] for diff_root in all_diff_roots: if start_time <= diff_root.from_time: diff --git a/backend/tests/unit/graphql/test_diff_tree_query.py b/backend/tests/unit/graphql/test_diff_tree_query.py index dbe4da0a4c..999a2d851f 100644 --- a/backend/tests/unit/graphql/test_diff_tree_query.py +++ b/backend/tests/unit/graphql/test_diff_tree_query.py @@ -214,7 +214,9 @@ async def test_diff_tree_no_changes( diff_coordinator: DiffCoordinator, diff_branch: Branch, ): - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) from_time = datetime.fromisoformat(diff_branch.branched_from) to_time = datetime.fromisoformat(enriched_diff.to_time.to_string()) @@ -297,7 +299,9 @@ async def test_diff_tree_one_attr_change( await branch_crit.save(db=db) after_change_datetime = datetime.now(tz=UTC) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) enriched_conflict_map = enriched_diff.get_all_conflicts() enriched_conflict = list(enriched_conflict_map.values())[0] await diff_repository.update_conflict_by_id( @@ -425,7 +429,9 @@ async def test_diff_tree_one_relationship_change( john_label = await person_john_main.render_display_label(db=db) jane_label = await person_jane_main.render_display_label(db=db) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) params = await prepare_graphql_params( db=db, include_mutation=False, include_subscription=False, branch=default_branch ) @@ -732,7 +738,9 @@ async def test_diff_tree_summary_no_changes( diff_coordinator: DiffCoordinator, diff_branch: Branch, ): - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) from_time = datetime.fromisoformat(diff_branch.branched_from) to_time = datetime.fromisoformat(enriched_diff.to_time.to_string()) @@ -818,7 +826,9 @@ async def test_diff_summary_filters( # ---------------------------- component_registry = get_component_registry() diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=diff_branch) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) params = await prepare_graphql_params( db=db, include_mutation=False, include_subscription=False, branch=default_branch ) diff --git a/backend/tests/unit/message_bus/operations/event/test_branch.py b/backend/tests/unit/message_bus/operations/event/test_branch.py index a4f3bcc868..52317b109c 100644 --- a/backend/tests/unit/message_bus/operations/event/test_branch.py +++ b/backend/tests/unit/message_bus/operations/event/test_branch.py @@ -71,7 +71,7 @@ async def test_merged(default_branch: Branch, init_service: InfrahubServices, pr for _ in range(2) ] diff_repo = AsyncMock(spec=DiffRepository) - diff_repo.get_empty_roots.return_value = untracked_diff_roots + tracked_diff_roots + diff_repo.get_roots_metadata.return_value = untracked_diff_roots + tracked_diff_roots mock_component_registry = Mock(spec=ComponentDependencyRegistry) mock_get_component_registry = MagicMock(return_value=mock_component_registry) mock_component_registry.get_component.return_value = diff_repo @@ -104,7 +104,7 @@ async def test_merged(default_branch: Branch, init_service: InfrahubServices, pr # Use `db=ANY` as a new InfrahubDatabase object is created as we use a new session mock_component_registry.get_component.assert_awaited_once_with(DiffRepository, db=ANY, branch=default_branch) - diff_repo.get_empty_roots.assert_awaited_once_with(base_branch_names=[target_branch_name]) + diff_repo.get_roots_metadata.assert_awaited_once_with(base_branch_names=[target_branch_name]) assert len(service.message_bus.messages) == 1 assert service.message_bus.messages[0] == messages.RefreshRegistryBranches() diff --git a/changelog/+incremental-diff-performance.fixed.md b/changelog/+incremental-diff-performance.fixed.md new file mode 100644 index 0000000000..005a73a2a5 --- /dev/null +++ b/changelog/+incremental-diff-performance.fixed.md @@ -0,0 +1 @@ +Update how we calculate an incremental diff to skip potentially expensive operations if at all possible \ No newline at end of file