From deef4cbeddbcf2ef59b59767b62a76f9adcd25d9 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 12 Dec 2024 17:32:06 -0800 Subject: [PATCH 01/18] skip diff update if no changes --- backend/infrahub/core/branch/tasks.py | 2 +- backend/infrahub/core/diff/coordinator.py | 51 ++++++++++++++----- .../core/diff/repository/repository.py | 8 +++ backend/infrahub/core/merge.py | 2 +- backend/infrahub/graphql/mutations/tasks.py | 2 +- .../diff/test_diff_incremental_addition.py | 14 ++--- .../tests/integration/diff/test_diff_merge.py | 4 +- .../tests/unit/core/diff/test_coordinator.py | 2 +- .../unit/core/diff/test_coordinator_lock.py | 2 +- .../unit/core/diff/test_diff_and_merge.py | 32 +++++++++--- .../unit/graphql/test_diff_tree_query.py | 24 +++++---- 11 files changed, 100 insertions(+), 43 deletions(-) diff --git a/backend/infrahub/core/branch/tasks.py b/backend/infrahub/core/branch/tasks.py index 41f5796d05..542f332bab 100644 --- a/backend/infrahub/core/branch/tasks.py +++ b/backend/infrahub/core/branch/tasks.py @@ -70,7 +70,7 @@ async def rebase_branch(branch: str) -> None: service=service, ) diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj) + enriched_diff = await diff_coordinator.update_branch_diff_and_return(base_branch=base_branch, diff_branch=obj) if enriched_diff.get_all_conflicts(): raise ValidationError( f"Branch {obj.name} contains conflicts with the default branch that must be addressed." 
diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 88d237d40c..75f5ec0e8f 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -77,10 +77,11 @@ async def run_update( from_time: str | None = None, to_time: str | None = None, name: str | None = None, - ) -> EnrichedDiffRoot: + ) -> None: # we are updating a diff that tracks the full lifetime of a branch if not name and not from_time and not to_time: - return await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) + await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) + return if from_time: from_timestamp = Timestamp(from_time) @@ -90,7 +91,7 @@ async def run_update( to_timestamp = Timestamp(to_time) else: to_timestamp = Timestamp() - return await self.create_or_update_arbitrary_timeframe_diff( + await self.create_or_update_arbitrary_timeframe_diff( base_branch=base_branch, diff_branch=diff_branch, from_time=from_timestamp, @@ -104,8 +105,16 @@ def _get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_increm lock_name += "__incremental" return lock_name - async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot: - log.debug(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}") + async def update_branch_diff_and_return(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot: + enriched_diff = await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) + if enriched_diff: + return enriched_diff + return await self.diff_repo.get_one( + diff_branch_name=diff_branch.name, tracking_id=BranchTrackingId(name=diff_branch.name) + ) + + async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot | None: + log.info(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}") 
incremental_lock_name = self._get_lock_name( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=True ) @@ -113,9 +122,9 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> name=incremental_lock_name, namespace=self.lock_namespace ) if existing_incremental_lock and await existing_incremental_lock.locked(): - log.debug(f"Branch diff update for {base_branch.name} - {diff_branch.name} already in progress") + log.info(f"Branch diff update for {base_branch.name} - {diff_branch.name} already in progress") async with self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace): - log.debug(f"Existing branch diff update for {base_branch.name} - {diff_branch.name} complete") + log.info(f"Existing branch diff update for {base_branch.name} - {diff_branch.name} complete") return await self.diff_repo.get_one( tracking_id=BranchTrackingId(name=diff_branch.name), diff_branch_name=diff_branch.name ) @@ -129,7 +138,7 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace), self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace), ): - log.debug(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}") + log.info(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}") enriched_diffs = await self._update_diffs( base_branch=base_branch, diff_branch=diff_branch, @@ -137,11 +146,13 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> to_time=to_time, tracking_id=tracking_id, ) + if not enriched_diffs: + return None await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) await 
self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) - log.debug(f"Branch diff update complete for {base_branch.name} - {diff_branch.name}") + log.info(f"Branch diff update complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff async def create_or_update_arbitrary_timeframe_diff( @@ -159,7 +170,7 @@ async def create_or_update_arbitrary_timeframe_diff( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False ) async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace): - log.debug(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}") + log.info(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}") enriched_diffs = await self._update_diffs( base_branch=base_branch, diff_branch=diff_branch, @@ -167,11 +178,14 @@ async def create_or_update_arbitrary_timeframe_diff( to_time=to_time, tracking_id=tracking_id, ) + if not enriched_diffs: + return await self.diff_repo.get_one(diff_branch_name=diff_branch.name, tracking_id=tracking_id) + await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) - log.debug(f"Arbitrary diff update complete for {base_branch.name} - {diff_branch.name}") + log.info(f"Arbitrary diff update complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff async def recalculate( @@ -184,7 +198,7 @@ async def recalculate( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False ) async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace): - log.debug(f"Acquired lock to recalculate diff for 
{base_branch.name} - {diff_branch.name}") + log.info(f"Acquired lock to recalculate diff for {base_branch.name} - {diff_branch.name}") current_branch_diff = await self.diff_repo.get_one(diff_branch_name=diff_branch.name, diff_id=diff_id) current_base_diff = await self.diff_repo.get_one( diff_branch_name=base_branch.name, diff_id=current_branch_diff.partner_uuid @@ -205,6 +219,10 @@ async def recalculate( tracking_id=current_branch_diff.tracking_id, force_branch_refresh=True, ) + if not enriched_diffs: + return await self.diff_repo.get_one( + diff_branch_name=diff_branch.name, tracking_id=current_branch_diff.tracking_id + ) if current_branch_diff: await self.conflict_transferer.transfer( @@ -215,7 +233,7 @@ async def recalculate( await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) - log.debug(f"Diff recalculation complete for {base_branch.name} - {diff_branch.name}") + log.info(f"Diff recalculation complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff def _get_ordered_diff_pairs( @@ -250,7 +268,12 @@ async def _update_diffs( to_time: Timestamp, tracking_id: TrackingId | None = None, force_branch_refresh: bool = False, - ) -> EnrichedDiffs: + ) -> EnrichedDiffs | None: + if not force_branch_refresh and not await self._any_changes_after_last_diff( + base_branch=base_branch, diff_branch=diff_branch, to_time=to_time + ): + return None + diff_uuids_to_delete = [] retrieved_enriched_diffs = await self.diff_repo.get_pairs( base_branch_name=base_branch.name, diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index dd738b6763..963595fcd1 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -3,6 +3,7 @@ from infrahub 
import config from infrahub.core import registry from infrahub.core.diff.query.field_summary import EnrichedDiffNodeFieldSummaryQuery +from infrahub.core.query.diff import DiffCountChanges from infrahub.core.timestamp import Timestamp from infrahub.database import InfrahubDatabase, retry_db_transaction from infrahub.exceptions import ResourceNotFoundError @@ -267,3 +268,10 @@ async def get_node_field_summaries( async def drop_tracking_ids(self, tracking_ids: list[TrackingId]) -> None: query = await EnrichedDiffDropTrackingIdQuery.init(db=self.db, tracking_ids=tracking_ids) await query.execute(db=self.db) + + async def get_num_changes_in_time_range_by_branch( + self, branch_names: list[str], from_time: Timestamp, to_time: Timestamp + ) -> dict[str, int]: + query = await DiffCountChanges.init(db=self.db, branch_names=branch_names, diff_from=from_time, diff_to=to_time) + await query.execute(db=self.db) + return query.get_num_changes_by_branch() diff --git a/backend/infrahub/core/merge.py b/backend/infrahub/core/merge.py index 5fe3c1c8e6..32232e277a 100644 --- a/backend/infrahub/core/merge.py +++ b/backend/infrahub/core/merge.py @@ -174,7 +174,7 @@ async def merge( if self.source_branch.name == registry.default_branch: raise ValidationError(f"Unable to merge the branch '{self.source_branch.name}' into itself") - enriched_diff = await self.diff_coordinator.update_branch_diff( + enriched_diff = await self.diff_coordinator.update_branch_diff_and_return( base_branch=self.destination_branch, diff_branch=self.source_branch ) conflict_map = enriched_diff.get_all_conflicts() diff --git a/backend/infrahub/graphql/mutations/tasks.py b/backend/infrahub/graphql/mutations/tasks.py index 3eb1c8a187..b3d571fc76 100644 --- a/backend/infrahub/graphql/mutations/tasks.py +++ b/backend/infrahub/graphql/mutations/tasks.py @@ -31,7 +31,7 @@ async def merge_branch_mutation(branch: str) -> None: diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj) 
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj) diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj) + enriched_diff = await diff_coordinator.update_branch_diff_and_return(base_branch=base_branch, diff_branch=obj) if enriched_diff.get_all_conflicts(): raise ValidationError( f"Branch {obj.name} contains conflicts with the default branch." diff --git a/backend/tests/integration/diff/test_diff_incremental_addition.py b/backend/tests/integration/diff/test_diff_incremental_addition.py index 3263a25282..40909ede9b 100644 --- a/backend/tests/integration/diff/test_diff_incremental_addition.py +++ b/backend/tests/integration/diff/test_diff_incremental_addition.py @@ -191,7 +191,9 @@ async def test_remove_on_main( diff_coordinator: DiffCoordinator, data_01_remove_on_main, ) -> None: - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) assert len(enriched_diff.nodes) == 0 @@ -263,7 +265,7 @@ async def test_update_previous_owner_on_branch( initial_dataset, data_02_previous_owner_on_branch, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_02(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) @@ -344,7 +346,7 @@ async def test_add_new_peer_on_main( diff_coordinator: DiffCoordinator, data_03_new_peer_on_main, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await 
self.validate_diff_data_03( @@ -426,7 +428,7 @@ async def test_update_previous_owner_protected_on_branch( diff_coordinator: DiffCoordinator, data_04_update_previous_owner_protected_on_branch, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_04(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) @@ -500,7 +502,7 @@ async def test_remove_previous_owner_on_branch( diff_coordinator: DiffCoordinator, data_05_remove_previous_owner_on_branch, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_05(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) @@ -570,7 +572,7 @@ async def test_remove_previous_owner_on_main_again( diff_coordinator: DiffCoordinator, data_06_remove_previous_owner_on_main_again, ) -> None: - incremental_diff = await diff_coordinator.update_branch_diff( + incremental_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) await self.validate_diff_data_06(db=db, enriched_diff=incremental_diff, initial_dataset=initial_dataset) diff --git a/backend/tests/integration/diff/test_diff_merge.py b/backend/tests/integration/diff/test_diff_merge.py index 0bdde6683f..f5e96578ae 100644 --- a/backend/tests/integration/diff/test_diff_merge.py +++ b/backend/tests/integration/diff/test_diff_merge.py @@ -117,7 +117,9 @@ async def test_select_cardinality_one_resolution_and_merge( delorean_id = initial_dataset["delorean"].get_id() marty_id = initial_dataset["marty"].get_id() - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await 
diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 owner_conflict = list(conflicts_map.values())[0] diff --git a/backend/tests/unit/core/diff/test_coordinator.py b/backend/tests/unit/core/diff/test_coordinator.py index a1ed15c255..98a129ad7c 100644 --- a/backend/tests/unit/core/diff/test_coordinator.py +++ b/backend/tests/unit/core/diff/test_coordinator.py @@ -24,7 +24,7 @@ async def test_node_deleted_after_branching( component_registry = get_component_registry() diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch) - diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch) + diff = await diff_coordinator.update_branch_diff_and_return(base_branch=default_branch, diff_branch=branch) assert diff.base_branch_name == default_branch.name assert diff.diff_branch_name == branch.name diff --git a/backend/tests/unit/core/diff/test_coordinator_lock.py b/backend/tests/unit/core/diff/test_coordinator_lock.py index 21aeb4ad1c..045fb74cf7 100644 --- a/backend/tests/unit/core/diff/test_coordinator_lock.py +++ b/backend/tests/unit/core/diff/test_coordinator_lock.py @@ -102,7 +102,7 @@ async def test_arbitrary_diff_blocks_incremental_diff( from_time=Timestamp(branch_with_data.branched_from), to_time=Timestamp(), ), - diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch), + diff_coordinator.update_branch_diff_and_return(base_branch=default_branch, diff_branch=diff_branch), ) assert len(results) == 2 assert results[0].to_time != results[1].to_time diff --git a/backend/tests/unit/core/diff/test_diff_and_merge.py b/backend/tests/unit/core/diff/test_diff_and_merge.py index ec22fb7ca7..4b203d1d5b 100644 --- a/backend/tests/unit/core/diff/test_diff_and_merge.py +++ b/backend/tests/unit/core/diff/test_diff_and_merge.py @@ -132,7 +132,9 @@ 
async def test_diff_and_merge_with_attribute_value_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 conflict = next(iter(conflicts_map.values())) @@ -174,7 +176,9 @@ async def test_diff_and_merge_with_relationship_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 conflict = next(iter(conflicts_map.values())) @@ -220,7 +224,9 @@ async def test_diff_and_merge_with_attribute_property_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 1 conflict = next(iter(conflicts_map.values())) @@ -268,7 +274,9 @@ async def test_diff_and_merge_with_relationship_property_conflict( at = Timestamp() diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) 
conflicts_map = enriched_diff.get_all_conflicts() # conflict on both sides of the relationship assert len(conflicts_map) == 2 @@ -302,7 +310,9 @@ async def test_single_attribute_update( await person_branch.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) node = enriched_diff.get_node(node_uuid=person_jane_main.id) assert node.action is DiffAction.UPDATED @@ -329,7 +339,9 @@ async def test_relationship_set_to_null(self, db: InfrahubDatabase, default_bran await dog_branch.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) dog_node = enriched_diff.get_node(node_uuid=dog_main.id) assert dog_node.action is DiffAction.UPDATED friend_node = enriched_diff.get_node(node_uuid=friend_main.id) @@ -357,7 +369,9 @@ async def test_local_and_aware_nodes_added_on_branch( await car.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) diff_person = enriched_diff.get_node(node_uuid=person.id) assert diff_person.action is DiffAction.ADDED # validate car is not in the diff @@ -422,7 +436,9 @@ async def test_agnostic_and_aware_nodes_added_on_branch( await car.save(db=db) diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) - enriched_diff = await 
diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) diff_person = enriched_diff.get_node(node_uuid=person.id) assert diff_person.action is DiffAction.UPDATED diff_car = enriched_diff.get_node(node_uuid=car.id) diff --git a/backend/tests/unit/graphql/test_diff_tree_query.py b/backend/tests/unit/graphql/test_diff_tree_query.py index dbe4da0a4c..f65f7e7399 100644 --- a/backend/tests/unit/graphql/test_diff_tree_query.py +++ b/backend/tests/unit/graphql/test_diff_tree_query.py @@ -214,7 +214,9 @@ async def test_diff_tree_no_changes( diff_coordinator: DiffCoordinator, diff_branch: Branch, ): - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) from_time = datetime.fromisoformat(diff_branch.branched_from) to_time = datetime.fromisoformat(enriched_diff.to_time.to_string()) @@ -297,7 +299,9 @@ async def test_diff_tree_one_attr_change( await branch_crit.save(db=db) after_change_datetime = datetime.now(tz=UTC) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) enriched_conflict_map = enriched_diff.get_all_conflicts() enriched_conflict = list(enriched_conflict_map.values())[0] await diff_repository.update_conflict_by_id( @@ -425,10 +429,10 @@ async def test_diff_tree_one_relationship_change( john_label = await person_john_main.render_display_label(db=db) jane_label = await person_jane_main.render_display_label(db=db) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) - params = await 
prepare_graphql_params( - db=db, include_mutation=False, include_subscription=False, branch=default_branch + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch ) + params = prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) result = await graphql( schema=params.schema, source=DIFF_TREE_QUERY, @@ -732,7 +736,9 @@ async def test_diff_tree_summary_no_changes( diff_coordinator: DiffCoordinator, diff_branch: Branch, ): - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) from_time = datetime.fromisoformat(diff_branch.branched_from) to_time = datetime.fromisoformat(enriched_diff.to_time.to_string()) @@ -818,10 +824,10 @@ async def test_diff_summary_filters( # ---------------------------- component_registry = get_component_registry() diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=diff_branch) - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) - params = await prepare_graphql_params( - db=db, include_mutation=False, include_subscription=False, branch=default_branch + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch ) + params = prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) result = await graphql( schema=params.schema, From 2030ea65633292a3ba78766ce022e6b4ae537885 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Mon, 16 Dec 2024 11:05:47 -0800 Subject: [PATCH 02/18] missing unit test update --- backend/tests/unit/core/diff/query/test_read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/backend/tests/unit/core/diff/query/test_read.py b/backend/tests/unit/core/diff/query/test_read.py index cd25b17e09..f0c549c069 100644 --- a/backend/tests/unit/core/diff/query/test_read.py +++ b/backend/tests/unit/core/diff/query/test_read.py @@ -163,7 +163,7 @@ async def load_data(self, db: InfrahubDatabase, default_branch: Branch, hierarch diff_coordinator.data_check_synchronizer = AsyncMock(spec=DiffDataCheckSynchronizer) diff_coordinator.data_check_synchronizer.synchronize.return_value = [] - enriched_diff = await diff_coordinator.update_branch_diff( + enriched_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch, ) From e8d9e3a8f256f92c6b0b08db5ad0a9f7f5903549 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 19 Dec 2024 18:35:52 -0800 Subject: [PATCH 03/18] refactor to improve performance if no changes --- backend/infrahub/core/diff/coordinator.py | 182 +++++++++++++----- backend/infrahub/core/diff/model/path.py | 38 +++- .../infrahub/core/diff/query/empty_roots.py | 19 +- .../core/diff/repository/deserializer.py | 8 +- .../core/diff/repository/repository.py | 59 +++++- 5 files changed, 245 insertions(+), 61 deletions(-) diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 75f5ec0e8f..74210f2f7e 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Iterable +from uuid import uuid4 from infrahub import lock from infrahub.core import registry @@ -11,7 +12,9 @@ from .model.path import ( BranchTrackingId, EnrichedDiffRoot, + EnrichedDiffRootEmpty, EnrichedDiffs, + EnrichedDiffsEmpty, NameTrackingId, NodeFieldSpecifier, TrackingId, @@ -237,12 +240,12 @@ async def recalculate( return enriched_diffs.diff_branch_diff def _get_ordered_diff_pairs( - self, diff_pairs: Iterable[EnrichedDiffs], allow_overlap: bool = 
False - ) -> list[EnrichedDiffs]: + self, diff_pairs: Iterable[EnrichedDiffsEmpty], allow_overlap: bool = False + ) -> list[EnrichedDiffsEmpty]: ordered_diffs = sorted(diff_pairs, key=lambda d: d.diff_branch_diff.from_time) if allow_overlap: return ordered_diffs - ordered_diffs_no_overlaps: list[EnrichedDiffs] = [] + ordered_diffs_no_overlaps: list[EnrichedDiffsEmpty] = [] for candidate_diff_pair in ordered_diffs: if not ordered_diffs_no_overlaps: ordered_diffs_no_overlaps.append(candidate_diff_pair) @@ -260,6 +263,30 @@ def _get_ordered_diff_pairs( ordered_diffs_no_overlaps[-1] = candidate_diff_pair return ordered_diffs_no_overlaps + def _build_empty_enriched_diffs(self, diff_request: EnrichedDiffRequest) -> EnrichedDiffs: + base_uuid = str(uuid4()) + branch_uuid = str(uuid4()) + return EnrichedDiffs( + base_branch_name=diff_request.base_branch.name, + diff_branch_name=diff_request.diff_branch.name, + base_branch_diff=EnrichedDiffRoot( + base_branch_name=diff_request.base_branch.name, + diff_branch_name=diff_request.base_branch.name, + from_time=diff_request.from_time, + to_time=diff_request.to_time, + uuid=base_uuid, + partner_uuid=branch_uuid, + ), + diff_branch_diff=EnrichedDiffRoot( + base_branch_name=diff_request.base_branch.name, + diff_branch_name=diff_request.diff_branch.name, + from_time=diff_request.from_time, + to_time=diff_request.to_time, + uuid=branch_uuid, + partner_uuid=base_uuid, + ), + ) + async def _update_diffs( self, base_branch: Branch, @@ -269,33 +296,31 @@ async def _update_diffs( tracking_id: TrackingId | None = None, force_branch_refresh: bool = False, ) -> EnrichedDiffs | None: - if not force_branch_refresh and not await self._any_changes_after_last_diff( - base_branch=base_branch, diff_branch=diff_branch, to_time=to_time - ): - return None - diff_uuids_to_delete = [] - retrieved_enriched_diffs = await self.diff_repo.get_pairs( - base_branch_name=base_branch.name, - diff_branch_name=diff_branch.name, + # start with empty diffs b/c we 
only care about their metadata for now, hydrate them with data as needed + empty_diff_pairs = await self.diff_repo.get_empty_diff_pairs( + base_branch_names=[base_branch.name], + diff_branch_names=[diff_branch.name], from_time=from_time, to_time=to_time, ) - for enriched_diffs in retrieved_enriched_diffs: - if tracking_id: - if enriched_diffs.base_branch_diff.tracking_id: - diff_uuids_to_delete.append(enriched_diffs.base_branch_diff.uuid) - if enriched_diffs.diff_branch_diff.tracking_id: - diff_uuids_to_delete.append(enriched_diffs.diff_branch_diff.uuid) - aggregated_enriched_diffs = await self._get_aggregated_enriched_diffs( + if tracking_id: + for diff_pair in empty_diff_pairs: + if diff_pair.base_branch_diff.tracking_id: + diff_uuids_to_delete.append(diff_pair.base_branch_diff.uuid) + if diff_pair.diff_branch_diff.tracking_id: + diff_uuids_to_delete.append(diff_pair.diff_branch_diff.uuid) + aggregated_enriched_diffs = await self._aggregate_enriched_diffs( diff_request=EnrichedDiffRequest( base_branch=base_branch, diff_branch=diff_branch, from_time=from_time, to_time=to_time, ), - partial_enriched_diffs=retrieved_enriched_diffs if not force_branch_refresh else [], + partial_enriched_diffs=empty_diff_pairs if not force_branch_refresh else [], ) + if not aggregated_enriched_diffs: + return None await self.conflicts_enricher.add_conflicts_to_branch_diff( base_diff_root=aggregated_enriched_diffs.base_branch_diff, @@ -312,57 +337,114 @@ async def _update_diffs( await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) return aggregated_enriched_diffs - async def _get_aggregated_enriched_diffs( - self, diff_request: EnrichedDiffRequest, partial_enriched_diffs: list[EnrichedDiffs] - ) -> EnrichedDiffs: + async def _aggregate_enriched_diffs( + self, diff_request: EnrichedDiffRequest, partial_enriched_diffs: list[EnrichedDiffsEmpty] + ) -> EnrichedDiffs | None: if not partial_enriched_diffs: - return await 
self._get_enriched_diff(diff_request=diff_request, is_incremental_diff=False) + return await self._calculate_enriched_diff(diff_request=diff_request, is_incremental_diff=False) ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False) ordered_diff_reprs = [repr(d) for d in ordered_diffs] - log.debug(f"Ordered diffs for aggregation: {ordered_diff_reprs}") + log.info(f"Ordered diffs for aggregation: {ordered_diff_reprs}") + incremental_diffs_and_requests: list[EnrichedDiffsEmpty | EnrichedDiffRequest | None] = [] current_time = diff_request.from_time - previous_diffs: EnrichedDiffs | None = None while current_time < diff_request.to_time: + # the next diff to include has already been calculated if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time: - current_diffs = ordered_diffs.pop(0) + current_diff = ordered_diffs.pop(0) + incremental_diffs_and_requests.append(current_diff) + current_time = current_diff.diff_branch_diff.to_time + continue + # set the end time to the start of the next calculated diff or the end of the time range + if ordered_diffs: + end_time = ordered_diffs[0].diff_branch_diff.from_time else: - if ordered_diffs: - end_time = ordered_diffs[0].diff_branch_diff.from_time - else: - end_time = diff_request.to_time - if previous_diffs is None: - node_field_specifiers = set() - else: - node_field_specifiers = self._get_node_field_specifiers( - enriched_diff=previous_diffs.diff_branch_diff - ) - inner_diff_request = EnrichedDiffRequest( + end_time = diff_request.to_time + # if there are no changes on either branch in this time range, then there cannot be a diff + num_changes_by_branch = await self.diff_repo.get_num_changes_in_time_range_by_branch( + branch_names=[diff_request.base_branch.name, diff_request.diff_branch.name], + from_time=current_time, + to_time=end_time, + ) + might_have_changes_in_time_range = any(num_changes_by_branch.values()) + if not might_have_changes_in_time_range: 
+ incremental_diffs_and_requests.append(None) + current_time = end_time + continue + + incremental_diffs_and_requests.append( + EnrichedDiffRequest( base_branch=diff_request.base_branch, diff_branch=diff_request.diff_branch, from_time=current_time, to_time=end_time, - node_field_specifiers=node_field_specifiers, - ) - is_incremental_diff = current_time != diff_request.from_time - current_diffs = await self._get_enriched_diff( - diff_request=inner_diff_request, is_incremental_diff=is_incremental_diff ) + ) + current_time = end_time - if previous_diffs: - current_diffs = await self.diff_combiner.combine( - earlier_diffs=previous_diffs, later_diffs=current_diffs - ) + aggregated_enriched_diffs = await self._concatenate_diffs_and_requests( + diff_or_request_list=incremental_diffs_and_requests, full_diff_request=diff_request + ) - previous_diffs = current_diffs - current_time = current_diffs.diff_branch_diff.to_time + if aggregated_enriched_diffs: + aggregated_enriched_diffs.base_branch_diff.from_time = diff_request.from_time + aggregated_enriched_diffs.diff_branch_diff.from_time = diff_request.from_time + aggregated_enriched_diffs.base_branch_diff.to_time = diff_request.to_time + aggregated_enriched_diffs.diff_branch_diff.to_time = diff_request.to_time + return aggregated_enriched_diffs + return self._build_empty_enriched_diffs(diff_request=diff_request) - return current_diffs + async def _concatenate_diffs_and_requests( + self, + diff_or_request_list: list[EnrichedDiffsEmpty | EnrichedDiffRequest | None], + full_diff_request: EnrichedDiffRequest, + ) -> EnrichedDiffs | None: + calculations_required = False + existing_diff_count = 0 + for diff_or_request in diff_or_request_list: + # a diff needs to be calculated + if isinstance(diff_or_request, EnrichedDiffRequest): + calculations_required = True + break + if isinstance(diff_or_request, EnrichedDiffsEmpty): + existing_diff_count += 1 + # multiple existing diffs need to be added together + if existing_diff_count > 1: + 
calculations_required = True + break + if not calculations_required: + return None + + complete_enriched_diffs: None | EnrichedDiffs = None + for diff_or_request in diff_or_request_list: + single_enriched_diffs: EnrichedDiffs | None = None + if isinstance(diff_or_request, EnrichedDiffRootEmpty): + single_enriched_diffs = await self.diff_repo.hydrate_diff_pair(enriched_diffs=diff_or_request) + elif isinstance(diff_or_request, EnrichedDiffRequest): + if complete_enriched_diffs: + diff_or_request.node_field_specifiers = self._get_node_field_specifiers( + enriched_diff=complete_enriched_diffs.diff_branch_diff + ) + is_incremental_diff = diff_or_request.from_time != full_diff_request.from_time + single_enriched_diffs = await self._calculate_enriched_diff( + diff_request=diff_or_request, is_incremental_diff=is_incremental_diff + ) + if not single_enriched_diffs: + continue + if complete_enriched_diffs: + complete_enriched_diffs = await self.diff_combiner.combine( + earlier_diffs=complete_enriched_diffs, later_diffs=single_enriched_diffs + ) + else: + complete_enriched_diffs = single_enriched_diffs + return complete_enriched_diffs async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: return await self.data_check_synchronizer.synchronize(enriched_diff=enriched_diff) - async def _get_enriched_diff(self, diff_request: EnrichedDiffRequest, is_incremental_diff: bool) -> EnrichedDiffs: + async def _calculate_enriched_diff( + self, diff_request: EnrichedDiffRequest, is_incremental_diff: bool + ) -> EnrichedDiffs: calculated_diff_pair = await self.diff_calculator.calculate_diff( base_branch=diff_request.base_branch, diff_branch=diff_request.diff_branch, diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index 8cb5e5cc50..b90eb190f4 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses 
import dataclass, field, replace +from dataclasses import asdict, dataclass, field, replace from enum import Enum from typing import TYPE_CHECKING, Any, Optional @@ -403,7 +403,7 @@ def from_calculated_node(cls, calculated_node: DiffNode) -> EnrichedDiffNode: @dataclass -class EnrichedDiffRoot(BaseSummary): +class EnrichedDiffRootEmpty(BaseSummary): base_branch_name: str diff_branch_name: str from_time: Timestamp @@ -411,6 +411,17 @@ class EnrichedDiffRoot(BaseSummary): uuid: str partner_uuid: str tracking_id: TrackingId | None = field(default=None, kw_only=True) + + def __hash__(self) -> int: + return hash(self.uuid) + + @property + def time_range(self) -> Interval: + return self.to_time.obj - self.from_time.obj + + +@dataclass +class EnrichedDiffRoot(EnrichedDiffRootEmpty): nodes: set[EnrichedDiffNode] = field(default_factory=set) def __hash__(self) -> int: @@ -446,6 +457,10 @@ def get_all_conflicts(self) -> dict[str, EnrichedDiffConflict]: all_conflicts.update(node.get_all_conflicts()) return all_conflicts + @classmethod + def from_empty_root(cls, empty_root: EnrichedDiffRootEmpty) -> EnrichedDiffRoot: + return EnrichedDiffRoot(**asdict(empty_root)) + @classmethod def from_calculated_diff( cls, calculated_diff: DiffRoot, base_branch_name: str, partner_uuid: str @@ -503,9 +518,26 @@ def add_parent( @dataclass -class EnrichedDiffs: +class EnrichedDiffsEmpty: base_branch_name: str diff_branch_name: str + base_branch_diff: EnrichedDiffRootEmpty + diff_branch_diff: EnrichedDiffRootEmpty + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + "branch_uuid={self.diff_branch_diff}," + "base_uuid={self.base_branch_diff.uuid}," + "branch_name={self.diff_branch_name}," + "base_name={self.base_branch_name}," + "from_time={self.diff_branch_diff.from_time}," + "to_time={self.diff_branch_diff.to_time})" + ) + + +@dataclass +class EnrichedDiffs(EnrichedDiffsEmpty): base_branch_diff: EnrichedDiffRoot diff_branch_diff: EnrichedDiffRoot diff --git 
a/backend/infrahub/core/diff/query/empty_roots.py b/backend/infrahub/core/diff/query/empty_roots.py index 64ccbfa38a..688fc2fb03 100644 --- a/backend/infrahub/core/diff/query/empty_roots.py +++ b/backend/infrahub/core/diff/query/empty_roots.py @@ -3,6 +3,7 @@ from neo4j.graph import Node as Neo4jNode from infrahub.core.query import Query, QueryType +from infrahub.core.timestamp import Timestamp from infrahub.database import InfrahubDatabase @@ -11,19 +12,33 @@ class EnrichedDiffEmptyRootsQuery(Query): type = QueryType.READ def __init__( - self, diff_branch_names: list[str] | None = None, base_branch_names: list[str] | None = None, **kwargs: Any + self, + diff_branch_names: list[str] | None = None, + base_branch_names: list[str] | None = None, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.diff_branch_names = diff_branch_names self.base_branch_names = base_branch_names + self.from_time = from_time + self.to_time = to_time async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: - self.params = {"diff_branch_names": self.diff_branch_names, "base_branch_names": self.base_branch_names} + self.params = { + "diff_branch_names": self.diff_branch_names, + "base_branch_names": self.base_branch_names, + "from_time": self.from_time.to_string() if self.from_time else None, + "to_time": self.to_time.to_string() if self.to_time else None, + } query = """ MATCH (diff_root:DiffRoot) WHERE ($diff_branch_names IS NULL OR diff_root.diff_branch IN $diff_branch_names) AND ($base_branch_names IS NULL OR diff_root.base_branch IN $base_branch_names) + AND ($from_time IS NULL OR diff_root.from_time >= $from_time) + AND ($to_time IS NULL OR diff_root.to_time <= $to_time) """ self.return_labels = ["diff_root"] self.add_to_query(query=query) diff --git a/backend/infrahub/core/diff/repository/deserializer.py b/backend/infrahub/core/diff/repository/deserializer.py index 41ccc72f88..75e70e210c 
100644 --- a/backend/infrahub/core/diff/repository/deserializer.py +++ b/backend/infrahub/core/diff/repository/deserializer.py @@ -16,6 +16,7 @@ EnrichedDiffProperty, EnrichedDiffRelationship, EnrichedDiffRoot, + EnrichedDiffRootEmpty, EnrichedDiffSingleRelationship, deserialize_tracking_id, ) @@ -149,19 +150,20 @@ def _deserialize_diff_root(self, root_node: Neo4jNode) -> EnrichedDiffRoot: root_uuid = str(root_node.get("uuid")) if root_uuid in self._diff_root_map: return self._diff_root_map[root_uuid] - enriched_root = self.build_diff_root(root_node=root_node) + root_empty = self.build_diff_root_empty(root_node=root_node) + enriched_root = EnrichedDiffRoot.from_empty_root(empty_root=root_empty) self._diff_root_map[root_uuid] = enriched_root return enriched_root @classmethod - def build_diff_root(cls, root_node: Neo4jNode) -> EnrichedDiffRoot: + def build_diff_root_empty(cls, root_node: Neo4jNode) -> EnrichedDiffRootEmpty: from_time = Timestamp(str(root_node.get("from_time"))) to_time = Timestamp(str(root_node.get("to_time"))) tracking_id_str = cls._get_str_or_none_property_value(node=root_node, property_name="tracking_id") tracking_id = None if tracking_id_str: tracking_id = deserialize_tracking_id(tracking_id_str=tracking_id_str) - return EnrichedDiffRoot( + return EnrichedDiffRootEmpty( base_branch_name=str(root_node.get("base_branch")), diff_branch_name=str(root_node.get("diff_branch")), from_time=from_time, diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index 963595fcd1..8fcd125db9 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -12,7 +12,9 @@ ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot, + EnrichedDiffRootEmpty, EnrichedDiffs, + EnrichedDiffsEmpty, EnrichedNodeCreateRequest, NodeDiffFieldSummary, TimeRange, @@ -119,6 +121,20 @@ async def get_pairs( for dbr in diff_branch_roots ] + async def 
hydrate_diff_pair(self, enriched_diffs: EnrichedDiffsEmpty) -> EnrichedDiffs: + hydrated_base_diff = await self.get_one( + diff_branch_name=enriched_diffs.base_branch_name, diff_id=enriched_diffs.base_branch_diff.uuid + ) + hydrated_branch_diff = await self.get_one( + diff_branch_name=enriched_diffs.diff_branch_name, diff_id=enriched_diffs.diff_branch_diff.uuid + ) + return EnrichedDiffs( + base_branch_name=enriched_diffs.base_branch_name, + diff_branch_name=enriched_diffs.diff_branch_name, + base_branch_diff=hydrated_base_diff, + diff_branch_diff=hydrated_branch_diff, + ) + async def get_one( self, diff_branch_name: str, @@ -212,18 +228,55 @@ async def get_time_ranges( await query.execute(db=self.db) return await query.get_time_ranges() + async def get_empty_diff_pairs( + self, + diff_branch_names: list[str] | None = None, + base_branch_names: list[str] | None = None, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + ) -> list[EnrichedDiffsEmpty]: + if diff_branch_names and base_branch_names: + diff_branch_names += base_branch_names + empty_roots = await self.get_empty_roots( + diff_branch_names=diff_branch_names, + base_branch_names=base_branch_names, + from_time=from_time, + to_time=to_time, + ) + roots_by_id = {root.uuid: root for root in empty_roots} + pairs: list[EnrichedDiffsEmpty] = [] + for branch_root in empty_roots: + if branch_root.base_branch_name == branch_root.diff_branch_name: + continue + base_root = roots_by_id[branch_root.partner_uuid] + pairs.append( + EnrichedDiffsEmpty( + base_branch_name=branch_root.base_branch_name, + diff_branch_name=branch_root.diff_branch_name, + base_branch_diff=base_root, + diff_branch_diff=branch_root, + ) + ) + return pairs + async def get_empty_roots( self, diff_branch_names: list[str] | None = None, base_branch_names: list[str] | None = None, - ) -> list[EnrichedDiffRoot]: + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + ) -> list[EnrichedDiffRootEmpty]: query = 
await EnrichedDiffEmptyRootsQuery.init( - db=self.db, diff_branch_names=diff_branch_names, base_branch_names=base_branch_names + db=self.db, + diff_branch_names=diff_branch_names, + base_branch_names=base_branch_names, + from_time=from_time, + to_time=to_time, ) await query.execute(db=self.db) diff_roots = [] for neo4j_node in query.get_empty_root_nodes(): - diff_roots.append(self.deserializer.build_diff_root(root_node=neo4j_node)) + diff_roots.append(self.deserializer.build_diff_root_empty(root_node=neo4j_node)) return diff_roots async def diff_has_conflicts( From c74c4bcffa929c0af0805387280e619c5de99d38 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 19 Dec 2024 21:31:50 -0800 Subject: [PATCH 04/18] typos --- backend/infrahub/core/diff/coordinator.py | 2 +- backend/infrahub/core/diff/model/path.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 74210f2f7e..83b04731f3 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -418,7 +418,7 @@ async def _concatenate_diffs_and_requests( complete_enriched_diffs: None | EnrichedDiffs = None for diff_or_request in diff_or_request_list: single_enriched_diffs: EnrichedDiffs | None = None - if isinstance(diff_or_request, EnrichedDiffRootEmpty): + if isinstance(diff_or_request, EnrichedDiffsEmpty): single_enriched_diffs = await self.diff_repo.hydrate_diff_pair(enriched_diffs=diff_or_request) elif isinstance(diff_or_request, EnrichedDiffRequest): if complete_enriched_diffs: diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index b90eb190f4..04699f75d6 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -527,12 +527,12 @@ class EnrichedDiffsEmpty: def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" - "branch_uuid={self.diff_branch_diff}," 
- "base_uuid={self.base_branch_diff.uuid}," - "branch_name={self.diff_branch_name}," - "base_name={self.base_branch_name}," - "from_time={self.diff_branch_diff.from_time}," - "to_time={self.diff_branch_diff.to_time})" + f"branch_uuid={self.diff_branch_diff.uuid}," + f"base_uuid={self.base_branch_diff.uuid}," + f"branch_name={self.diff_branch_name}," + f"base_name={self.base_branch_name}," + f"from_time={self.diff_branch_diff.from_time}," + f"to_time={self.diff_branch_diff.to_time})" ) From 6918255f01b6b5c1e94b1e4ba6d484c2b248ea65 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 19 Dec 2024 22:10:25 -0800 Subject: [PATCH 05/18] pylint --- backend/infrahub/core/diff/coordinator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 83b04731f3..078fe89fe0 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -12,7 +12,6 @@ from .model.path import ( BranchTrackingId, EnrichedDiffRoot, - EnrichedDiffRootEmpty, EnrichedDiffs, EnrichedDiffsEmpty, NameTrackingId, From 8f1312593db50259e796a2c3ce69784836cad79e Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Tue, 24 Dec 2024 15:36:43 -0800 Subject: [PATCH 06/18] add relationship identifier to saved diff --- backend/infrahub/core/diff/combiner.py | 1 + backend/infrahub/core/diff/coordinator.py | 4 ++-- backend/infrahub/core/diff/enricher/hierarchy.py | 2 ++ backend/infrahub/core/diff/model/path.py | 5 +++++ backend/infrahub/core/diff/query/save.py | 1 + backend/infrahub/core/diff/query_parser.py | 1 + backend/infrahub/core/diff/repository/deserializer.py | 2 ++ 7 files changed, 14 insertions(+), 2 deletions(-) diff --git a/backend/infrahub/core/diff/combiner.py b/backend/infrahub/core/diff/combiner.py index 602762e742..880b8fb170 100644 --- a/backend/infrahub/core/diff/combiner.py +++ b/backend/infrahub/core/diff/combiner.py @@ -320,6 +320,7 @@ def _combine_relationships( 
combined_relationship = EnrichedDiffRelationship( name=later_relationship.name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=later_relationship.cardinality, changed_at=later_relationship.changed_at or earlier_relationship.changed_at, action=combined_action, diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 078fe89fe0..2c175ea0d8 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -305,9 +305,9 @@ async def _update_diffs( ) if tracking_id: for diff_pair in empty_diff_pairs: - if diff_pair.base_branch_diff.tracking_id: + if diff_pair.base_branch_diff.tracking_id == tracking_id: diff_uuids_to_delete.append(diff_pair.base_branch_diff.uuid) - if diff_pair.diff_branch_diff.tracking_id: + if diff_pair.diff_branch_diff.tracking_id == tracking_id: diff_uuids_to_delete.append(diff_pair.diff_branch_diff.uuid) aggregated_enriched_diffs = await self._aggregate_enriched_diffs( diff_request=EnrichedDiffRequest( diff --git a/backend/infrahub/core/diff/enricher/hierarchy.py b/backend/infrahub/core/diff/enricher/hierarchy.py index 79c6ba44c0..f4b47730cd 100644 --- a/backend/infrahub/core/diff/enricher/hierarchy.py +++ b/backend/infrahub/core/diff/enricher/hierarchy.py @@ -92,6 +92,7 @@ async def _enrich_hierarchical_nodes( parent_kind=ancestor.kind, parent_label="", parent_rel_name=parent_rel.name, + parent_rel_identifier=parent_rel.get_identifier(), parent_rel_cardinality=parent_rel.cardinality, parent_rel_label=parent_rel.label or "", ) @@ -150,6 +151,7 @@ async def _enrich_nodes_with_parent( parent_kind=peer_parent.peer_kind, parent_label="", parent_rel_name=parent_rel.name, + parent_rel_identifier=parent_rel.get_identifier(), parent_rel_cardinality=parent_rel.cardinality, parent_rel_label=parent_rel.label or "", ) diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index 
04699f75d6..91dcd466fd 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -239,6 +239,7 @@ def from_calculated_element(cls, calculated_element: DiffSingleRelationship) -> @dataclass class EnrichedDiffRelationship(BaseSummary): name: str + identifier: str label: str cardinality: RelationshipCardinality path_identifier: str = field(default="", kw_only=True) @@ -270,6 +271,7 @@ def include_in_response(self) -> bool: def from_calculated_relationship(cls, calculated_relationship: DiffRelationship) -> EnrichedDiffRelationship: return EnrichedDiffRelationship( name=calculated_relationship.name, + identifier=calculated_relationship.identifier, label="", cardinality=calculated_relationship.cardinality, changed_at=calculated_relationship.changed_at, @@ -482,6 +484,7 @@ def add_parent( parent_kind: str, parent_label: str, parent_rel_name: str, + parent_rel_identifier: str, parent_rel_cardinality: RelationshipCardinality, parent_rel_label: str = "", ) -> EnrichedDiffNode: @@ -506,6 +509,7 @@ def add_parent( node.relationships.add( EnrichedDiffRelationship( name=parent_rel_name, + identifier=parent_rel_identifier, label=parent_rel_label, cardinality=parent_rel_cardinality, changed_at=None, @@ -609,6 +613,7 @@ class DiffSingleRelationship: @dataclass class DiffRelationship: name: str + identifier: str cardinality: RelationshipCardinality changed_at: Timestamp action: DiffAction diff --git a/backend/infrahub/core/diff/query/save.py b/backend/infrahub/core/diff/query/save.py index ab09736faf..d5bd01a7ed 100644 --- a/backend/infrahub/core/diff/query/save.py +++ b/backend/infrahub/core/diff/query/save.py @@ -261,6 +261,7 @@ def _build_diff_relationship_params(self, enriched_relationship: EnrichedDiffRel return { "node_properties": { "name": enriched_relationship.name, + "identifier": enriched_relationship.identifier, "label": enriched_relationship.label, "cardinality": enriched_relationship.cardinality.value, "changed_at": 
enriched_relationship.changed_at.to_string() diff --git a/backend/infrahub/core/diff/query_parser.py b/backend/infrahub/core/diff/query_parser.py index afb18716e7..74122b77d4 100644 --- a/backend/infrahub/core/diff/query_parser.py +++ b/backend/infrahub/core/diff/query_parser.py @@ -374,6 +374,7 @@ def to_diff_relationship(self, include_unchanged: bool) -> DiffRelationship: action = actions.pop() return DiffRelationship( name=self.name, + identifier=self.identifier, changed_at=last_changed_at, action=action, relationships=single_relationships, diff --git a/backend/infrahub/core/diff/repository/deserializer.py b/backend/infrahub/core/diff/repository/deserializer.py index 75e70e210c..14520e89e0 100644 --- a/backend/infrahub/core/diff/repository/deserializer.py +++ b/backend/infrahub/core/diff/repository/deserializer.py @@ -136,6 +136,7 @@ def _deserialize_parents(self, result: QueryResult, enriched_root: EnrichedDiffR parent_kind=parent.get("kind"), parent_label=parent.get("label"), parent_rel_name=rel.get("name"), + parent_rel_identifier=rel.get("identifier"), parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")), parent_rel_label=rel.get("label"), ) @@ -236,6 +237,7 @@ def _deserialize_diff_relationship_group( timestamp_str = relationship_group_node.get("changed_at") enriched_relationship = EnrichedDiffRelationship( name=relationship_group_node.get("name"), + identifier=relationship_group_node.get("identifier"), label=relationship_group_node.get("label"), cardinality=RelationshipCardinality(relationship_group_node.get("cardinality")), changed_at=Timestamp(timestamp_str) if timestamp_str else None, From 40651351013d5f9a538bf01f6e3b78a00446725c Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Tue, 31 Dec 2024 07:26:38 -0800 Subject: [PATCH 07/18] unit test update --- backend/tests/integration/diff/test_diff_merge.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/tests/integration/diff/test_diff_merge.py 
b/backend/tests/integration/diff/test_diff_merge.py index f5e96578ae..089a7bc19c 100644 --- a/backend/tests/integration/diff/test_diff_merge.py +++ b/backend/tests/integration/diff/test_diff_merge.py @@ -171,7 +171,9 @@ async def test_node_delete_conflict( await new_car.save(db=db) # check that the expected node-level conflict exists - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) conflicts_map = enriched_diff.get_all_conflicts() assert set(conflicts_map.keys()) == {f"data/{person_updated.id}"} conflict = conflicts_map[f"data/{person_updated.id}"] @@ -192,7 +194,9 @@ async def test_node_delete_conflict( await car_branch.save(db=db) # check that the conflict is gone - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=diff_branch + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 0 From b76e3c78810520a4851011dc6e348c8fa5e9ec08 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Wed, 1 Jan 2025 14:48:18 -0800 Subject: [PATCH 08/18] refactor coordinator to allow skipping diff loading and calculation --- backend/infrahub/core/diff/coordinator.py | 310 +++++++++++------- backend/infrahub/core/diff/model/path.py | 59 +++- .../core/diff/query/field_specifiers.py | 33 ++ .../core/diff/repository/deserializer.py | 6 +- .../core/diff/repository/repository.py | 37 ++- .../diff/test_diff_incremental_addition.py | 2 +- .../unit/core/diff/test_coordinator_lock.py | 55 ++-- .../unit/core/diff/test_diff_and_merge.py | 4 +- .../unit/core/diff/test_diff_combiner.py | 6 + 9 files changed, 341 insertions(+), 171 deletions(-) create mode 100644 
backend/infrahub/core/diff/query/field_specifiers.py diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 2c175ea0d8..15af7252e8 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -1,19 +1,19 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Iterable, Literal, Sequence, overload from uuid import uuid4 from infrahub import lock -from infrahub.core import registry from infrahub.core.timestamp import Timestamp from infrahub.log import get_logger from .model.path import ( BranchTrackingId, EnrichedDiffRoot, + EnrichedDiffRootMetadata, EnrichedDiffs, - EnrichedDiffsEmpty, + EnrichedDiffsMetadata, NameTrackingId, NodeFieldSpecifier, TrackingId, @@ -43,6 +43,7 @@ class EnrichedDiffRequest: diff_branch: Branch from_time: Timestamp to_time: Timestamp + tracking_id: TrackingId | None = field(default=None) node_field_specifiers: set[NodeFieldSpecifier] = field(default_factory=set) @@ -109,13 +110,25 @@ def _get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_increm async def update_branch_diff_and_return(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot: enriched_diff = await self.update_branch_diff(base_branch=base_branch, diff_branch=diff_branch) - if enriched_diff: + if isinstance(enriched_diff, EnrichedDiffRoot): return enriched_diff - return await self.diff_repo.get_one( - diff_branch_name=diff_branch.name, tracking_id=BranchTrackingId(name=diff_branch.name) + return await self._finalize_diff_root_metadata(diff_root_metadata=enriched_diff) + + async def _finalize_diff_root_metadata(self, diff_root_metadata: EnrichedDiffRootMetadata) -> EnrichedDiffRoot: + # if this is EnrichedDiffMetadata, we need to retrieve the full diff and set its metadata to match + full_enriched_diff = await self.diff_repo.get_one( + 
diff_branch_name=diff_root_metadata.diff_branch_name, diff_id=diff_root_metadata.uuid + ) + full_enriched_diff.update_metadata( + from_time=diff_root_metadata.from_time, + to_time=diff_root_metadata.to_time, + tracking_id=diff_root_metadata.tracking_id, ) + return full_enriched_diff - async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRoot | None: + async def update_branch_diff( + self, base_branch: Branch, diff_branch: Branch + ) -> EnrichedDiffRoot | EnrichedDiffRootMetadata: log.info(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}") incremental_lock_name = self._get_lock_name( base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=True @@ -147,9 +160,11 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> from_time=from_time, to_time=to_time, tracking_id=tracking_id, + force_branch_refresh=False, ) - if not enriched_diffs: - return None + if not isinstance(enriched_diffs, EnrichedDiffs): + return enriched_diffs.diff_branch_diff + await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) await self.diff_repo.save(enriched_diffs=enriched_diffs) @@ -179,9 +194,10 @@ async def create_or_update_arbitrary_timeframe_diff( from_time=from_time, to_time=to_time, tracking_id=tracking_id, + force_branch_refresh=False, ) - if not enriched_diffs: - return await self.diff_repo.get_one(diff_branch_name=diff_branch.name, tracking_id=tracking_id) + if not isinstance(enriched_diffs, EnrichedDiffs): + return await self._finalize_diff_root_metadata(diff_root_metadata=enriched_diffs.diff_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.diff_branch_diff) @@ -221,11 +237,6 @@ async def 
recalculate( tracking_id=current_branch_diff.tracking_id, force_branch_refresh=True, ) - if not enriched_diffs: - return await self.diff_repo.get_one( - diff_branch_name=diff_branch.name, tracking_id=current_branch_diff.tracking_id - ) - if current_branch_diff: await self.conflict_transferer.transfer( earlier=current_branch_diff, later=enriched_diffs.diff_branch_diff @@ -239,12 +250,12 @@ async def recalculate( return enriched_diffs.diff_branch_diff def _get_ordered_diff_pairs( - self, diff_pairs: Iterable[EnrichedDiffsEmpty], allow_overlap: bool = False - ) -> list[EnrichedDiffsEmpty]: + self, diff_pairs: Iterable[EnrichedDiffsMetadata], allow_overlap: bool = False + ) -> list[EnrichedDiffsMetadata]: ordered_diffs = sorted(diff_pairs, key=lambda d: d.diff_branch_diff.from_time) if allow_overlap: return ordered_diffs - ordered_diffs_no_overlaps: list[EnrichedDiffsEmpty] = [] + ordered_diffs_no_overlaps: list[EnrichedDiffsMetadata] = [] for candidate_diff_pair in ordered_diffs: if not ordered_diffs_no_overlaps: ordered_diffs_no_overlaps.append(candidate_diff_pair) @@ -262,7 +273,7 @@ def _get_ordered_diff_pairs( ordered_diffs_no_overlaps[-1] = candidate_diff_pair return ordered_diffs_no_overlaps - def _build_empty_enriched_diffs(self, diff_request: EnrichedDiffRequest) -> EnrichedDiffs: + def _build_enriched_diffs_with_no_nodes(self, diff_request: EnrichedDiffRequest) -> EnrichedDiffs: base_uuid = str(uuid4()) branch_uuid = str(uuid4()) return EnrichedDiffs( @@ -273,6 +284,7 @@ def _build_empty_enriched_diffs(self, diff_request: EnrichedDiffRequest) -> Enri diff_branch_name=diff_request.base_branch.name, from_time=diff_request.from_time, to_time=diff_request.to_time, + tracking_id=diff_request.tracking_id, uuid=base_uuid, partner_uuid=branch_uuid, ), @@ -281,11 +293,34 @@ def _build_empty_enriched_diffs(self, diff_request: EnrichedDiffRequest) -> Enri diff_branch_name=diff_request.diff_branch.name, from_time=diff_request.from_time, to_time=diff_request.to_time, + 
tracking_id=diff_request.tracking_id, uuid=branch_uuid, partner_uuid=base_uuid, ), ) + @overload + async def _update_diffs( + self, + base_branch: Branch, + diff_branch: Branch, + from_time: Timestamp, + to_time: Timestamp, + tracking_id: TrackingId | None = None, + force_branch_refresh: Literal[True] = ..., + ) -> EnrichedDiffs: ... + + @overload + async def _update_diffs( + self, + base_branch: Branch, + diff_branch: Branch, + from_time: Timestamp, + to_time: Timestamp, + tracking_id: TrackingId | None = None, + force_branch_refresh: Literal[False] = ..., + ) -> EnrichedDiffs | EnrichedDiffsMetadata: ... + async def _update_diffs( self, base_branch: Branch, @@ -294,7 +329,7 @@ async def _update_diffs( to_time: Timestamp, tracking_id: TrackingId | None = None, force_branch_refresh: bool = False, - ) -> EnrichedDiffs | None: + ) -> EnrichedDiffs | EnrichedDiffsMetadata: diff_uuids_to_delete = [] # start with empty diffs b/c we only care about their metadata for now, hydrate them with data as needed empty_diff_pairs = await self.diff_repo.get_empty_diff_pairs( @@ -315,11 +350,18 @@ async def _update_diffs( diff_branch=diff_branch, from_time=from_time, to_time=to_time, + tracking_id=tracking_id, ), - partial_enriched_diffs=empty_diff_pairs if not force_branch_refresh else [], + partial_enriched_diffs=empty_diff_pairs if not force_branch_refresh else None, ) - if not aggregated_enriched_diffs: - return None + + if diff_uuids_to_delete: + await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) + + # this is an EnrichedDiffsMetadata, so there are no nodes to enrich + if not isinstance(aggregated_enriched_diffs, EnrichedDiffs): + aggregated_enriched_diffs.update_metadata(from_time=from_time, to_time=to_time, tracking_id=tracking_id) + return aggregated_enriched_diffs await self.conflicts_enricher.add_conflicts_to_branch_diff( base_diff_root=aggregated_enriched_diffs.base_branch_diff, @@ -329,114 +371,153 @@ async def _update_diffs( 
enriched_diff_root=aggregated_enriched_diffs.diff_branch_diff, conflicts_only=True ) - if tracking_id: - aggregated_enriched_diffs.base_branch_diff.tracking_id = tracking_id - aggregated_enriched_diffs.diff_branch_diff.tracking_id = tracking_id - if diff_uuids_to_delete: - await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) return aggregated_enriched_diffs + @overload async def _aggregate_enriched_diffs( - self, diff_request: EnrichedDiffRequest, partial_enriched_diffs: list[EnrichedDiffsEmpty] - ) -> EnrichedDiffs | None: + self, + diff_request: EnrichedDiffRequest, + partial_enriched_diffs: list[EnrichedDiffsMetadata], + ) -> EnrichedDiffs | EnrichedDiffsMetadata: ... + + @overload + async def _aggregate_enriched_diffs( + self, + diff_request: EnrichedDiffRequest, + partial_enriched_diffs: None, + ) -> EnrichedDiffs: ... + + async def _aggregate_enriched_diffs( + self, + diff_request: EnrichedDiffRequest, + partial_enriched_diffs: list[EnrichedDiffsMetadata] | None, + ) -> EnrichedDiffs | EnrichedDiffsMetadata: + """ + If return is an EnrichedDiffsMetadata, it acts as a pointer to a diff in the database that has all the + necessary data for this diff_request. 
Might have a different time range and/or tracking_id + """ + aggregated_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata | None = None if not partial_enriched_diffs: - return await self._calculate_enriched_diff(diff_request=diff_request, is_incremental_diff=False) - - ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False) - ordered_diff_reprs = [repr(d) for d in ordered_diffs] - log.info(f"Ordered diffs for aggregation: {ordered_diff_reprs}") - incremental_diffs_and_requests: list[EnrichedDiffsEmpty | EnrichedDiffRequest | None] = [] - current_time = diff_request.from_time - while current_time < diff_request.to_time: - # the next diff to include has already been calculated - if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time: - current_diff = ordered_diffs.pop(0) - incremental_diffs_and_requests.append(current_diff) - current_time = current_diff.diff_branch_diff.to_time - continue - # set the end time to the start of the next calculated diff or the end of the time range - if ordered_diffs: - end_time = ordered_diffs[0].diff_branch_diff.from_time - else: - end_time = diff_request.to_time - # if there are no changes on either branch in this time range, then there cannot be a diff - num_changes_by_branch = await self.diff_repo.get_num_changes_in_time_range_by_branch( - branch_names=[diff_request.base_branch.name, diff_request.diff_branch.name], - from_time=current_time, - to_time=end_time, + aggregated_enriched_diffs = await self._calculate_enriched_diff( + diff_request=diff_request, is_incremental_diff=False ) - might_have_changes_in_time_range = any(num_changes_by_branch.values()) - if not might_have_changes_in_time_range: - incremental_diffs_and_requests.append(None) - current_time = end_time - continue - incremental_diffs_and_requests.append( - EnrichedDiffRequest( - base_branch=diff_request.base_branch, - diff_branch=diff_request.diff_branch, + if partial_enriched_diffs is not None 
and not aggregated_enriched_diffs: + ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False) + ordered_diff_reprs = [repr(d) for d in ordered_diffs] + log.info(f"Ordered diffs for aggregation: {ordered_diff_reprs}") + incremental_diffs_and_requests: list[EnrichedDiffsMetadata | EnrichedDiffRequest | None] = [] + current_time = diff_request.from_time + while current_time < diff_request.to_time: + # the next diff to include has already been calculated + if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time: + current_diff = ordered_diffs.pop(0) + incremental_diffs_and_requests.append(current_diff) + current_time = current_diff.diff_branch_diff.to_time + continue + # set the end time to the start of the next calculated diff or the end of the time range + if ordered_diffs: + end_time = ordered_diffs[0].diff_branch_diff.from_time + else: + end_time = diff_request.to_time + # if there are no changes on either branch in this time range, then there cannot be a diff + num_changes_by_branch = await self.diff_repo.get_num_changes_in_time_range_by_branch( + branch_names=[diff_request.base_branch.name, diff_request.diff_branch.name], from_time=current_time, to_time=end_time, ) + might_have_changes_in_time_range = any(num_changes_by_branch.values()) + if not might_have_changes_in_time_range: + incremental_diffs_and_requests.append(None) + current_time = end_time + continue + + incremental_diffs_and_requests.append( + EnrichedDiffRequest( + base_branch=diff_request.base_branch, + diff_branch=diff_request.diff_branch, + from_time=current_time, + to_time=end_time, + ) + ) + current_time = end_time + + aggregated_enriched_diffs = await self._concatenate_diffs_and_requests( + diff_or_request_list=incremental_diffs_and_requests, full_diff_request=diff_request ) - current_time = end_time - aggregated_enriched_diffs = await self._concatenate_diffs_and_requests( - diff_or_request_list=incremental_diffs_and_requests, 
full_diff_request=diff_request - ) + # no changes during this time period, so generate an EnrichedDiffs with no nodes + if not aggregated_enriched_diffs: + return self._build_enriched_diffs_with_no_nodes(diff_request=diff_request) - if aggregated_enriched_diffs: - aggregated_enriched_diffs.base_branch_diff.from_time = diff_request.from_time - aggregated_enriched_diffs.diff_branch_diff.from_time = diff_request.from_time - aggregated_enriched_diffs.base_branch_diff.to_time = diff_request.to_time - aggregated_enriched_diffs.diff_branch_diff.to_time = diff_request.to_time + # a diff exists in the database that covers at least part of this time period, + # but it might need to have its start or end time extended to cover time ranges + # with no changes + if not isinstance(aggregated_enriched_diffs, EnrichedDiffs): return aggregated_enriched_diffs - return self._build_empty_enriched_diffs(diff_request=diff_request) + + # a new diff (with nodes) covering the time period + aggregated_enriched_diffs.update_metadata( + from_time=diff_request.from_time, to_time=diff_request.to_time, tracking_id=diff_request.tracking_id + ) + aggregated_enriched_diffs.set_fresh_uuids() + return aggregated_enriched_diffs async def _concatenate_diffs_and_requests( self, - diff_or_request_list: list[EnrichedDiffsEmpty | EnrichedDiffRequest | None], + diff_or_request_list: Sequence[EnrichedDiffsMetadata | EnrichedDiffRequest | None], full_diff_request: EnrichedDiffRequest, - ) -> EnrichedDiffs | None: - calculations_required = False - existing_diff_count = 0 + ) -> EnrichedDiffs | EnrichedDiffsMetadata | None: + """ + Returns None if diff_or_request_list is empty or all Nones + meaning there are no changes for the diff during this time period + Returns EnrichedDiffsMetadata if diff_or_request_list includes one EnrichedDiffsMetadata and no EnrichedDiffRequests + meaning no diffs needed to be hydrated and combined + Otherwise, returns EnrichedDiffs + meaning multiple diffs (some that may have been 
freshly calculated) were combined + """ + previous_diff_pair: EnrichedDiffs | EnrichedDiffsMetadata | None = None for diff_or_request in diff_or_request_list: - # a diff needs to be calculated if isinstance(diff_or_request, EnrichedDiffRequest): - calculations_required = True - break - if isinstance(diff_or_request, EnrichedDiffsEmpty): - existing_diff_count += 1 - # multiple existing diffs need to be added together - if existing_diff_count > 1: - calculations_required = True - break - if not calculations_required: - return None - - complete_enriched_diffs: None | EnrichedDiffs = None - for diff_or_request in diff_or_request_list: - single_enriched_diffs: EnrichedDiffs | None = None - if isinstance(diff_or_request, EnrichedDiffsEmpty): - single_enriched_diffs = await self.diff_repo.hydrate_diff_pair(enriched_diffs=diff_or_request) - elif isinstance(diff_or_request, EnrichedDiffRequest): - if complete_enriched_diffs: - diff_or_request.node_field_specifiers = self._get_node_field_specifiers( - enriched_diff=complete_enriched_diffs.diff_branch_diff + if previous_diff_pair: + node_field_specifiers = await self.diff_repo.get_node_field_specifiers( + diff_id=previous_diff_pair.diff_branch_diff.uuid, ) + diff_or_request.node_field_specifiers = node_field_specifiers is_incremental_diff = diff_or_request.from_time != full_diff_request.from_time - single_enriched_diffs = await self._calculate_enriched_diff( + single_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata = await self._calculate_enriched_diff( diff_request=diff_or_request, is_incremental_diff=is_incremental_diff ) - if not single_enriched_diffs: + + elif isinstance(diff_or_request, EnrichedDiffsMetadata): + single_enriched_diffs = diff_or_request + else: + continue + + if previous_diff_pair is None: + previous_diff_pair = single_enriched_diffs + continue + + if isinstance(single_enriched_diffs, EnrichedDiffs) and single_enriched_diffs.is_empty: + if previous_diff_pair: + 
previous_diff_pair.base_branch_diff.to_time = single_enriched_diffs.base_branch_diff.to_time + previous_diff_pair.diff_branch_diff.to_time = single_enriched_diffs.diff_branch_diff.to_time + else: + previous_diff_pair = single_enriched_diffs continue - if complete_enriched_diffs: - complete_enriched_diffs = await self.diff_combiner.combine( - earlier_diffs=complete_enriched_diffs, later_diffs=single_enriched_diffs + + if not isinstance(previous_diff_pair, EnrichedDiffs): + previous_diff_pair = await self.diff_repo.hydrate_diff_pair(enriched_diffs_empty=previous_diff_pair) + if not isinstance(single_enriched_diffs, EnrichedDiffs): + single_enriched_diffs = await self.diff_repo.hydrate_diff_pair( + enriched_diffs_empty=single_enriched_diffs ) - else: - complete_enriched_diffs = single_enriched_diffs - return complete_enriched_diffs + + previous_diff_pair = await self.diff_combiner.combine( + earlier_diffs=previous_diff_pair, later_diffs=single_enriched_diffs + ) + + return previous_diff_pair async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: return await self.data_check_synchronizer.synchronize(enriched_diff=enriched_diff) @@ -454,18 +535,3 @@ async def _calculate_enriched_diff( ) enriched_diff_pair = await self.diff_enricher.enrich(calculated_diffs=calculated_diff_pair) return enriched_diff_pair - - def _get_node_field_specifiers(self, enriched_diff: EnrichedDiffRoot) -> set[NodeFieldSpecifier]: - specifiers: set[NodeFieldSpecifier] = set() - schema_branch = registry.schema.get_schema_branch(name=enriched_diff.diff_branch_name) - for node in enriched_diff.nodes: - specifiers.update( - NodeFieldSpecifier(node_uuid=node.uuid, field_name=attribute.name) for attribute in node.attributes - ) - if not node.relationships: - continue - node_schema = schema_branch.get_node(name=node.kind, duplicate=False) - for relationship in node.relationships: - relationship_schema = node_schema.get_relationship(name=relationship.name) - 
specifiers.add(NodeFieldSpecifier(node_uuid=node.uuid, field_name=relationship_schema.get_identifier())) - return specifiers diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index 91dcd466fd..22bb5259bd 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -3,6 +3,7 @@ from dataclasses import asdict, dataclass, field, replace from enum import Enum from typing import TYPE_CHECKING, Any, Optional +from uuid import uuid4 from infrahub.core.constants import ( BranchSupportType, @@ -405,7 +406,7 @@ def from_calculated_node(cls, calculated_node: DiffNode) -> EnrichedDiffNode: @dataclass -class EnrichedDiffRootEmpty(BaseSummary): +class EnrichedDiffRootMetadata(BaseSummary): base_branch_name: str diff_branch_name: str from_time: Timestamp @@ -421,9 +422,27 @@ def __hash__(self) -> int: def time_range(self) -> Interval: return self.to_time.obj - self.from_time.obj + def update_metadata( + self, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + tracking_id: TrackingId | None = None, + ) -> bool: + is_changed = False + if from_time and self.from_time != from_time: + self.from_time = from_time + is_changed = True + if to_time and self.to_time != to_time: + self.to_time = to_time + is_changed = True + if self.tracking_id != tracking_id: + self.tracking_id = tracking_id + is_changed = True + return is_changed + @dataclass -class EnrichedDiffRoot(EnrichedDiffRootEmpty): +class EnrichedDiffRoot(EnrichedDiffRootMetadata): nodes: set[EnrichedDiffNode] = field(default_factory=set) def __hash__(self) -> int: @@ -460,7 +479,7 @@ def get_all_conflicts(self) -> dict[str, EnrichedDiffConflict]: return all_conflicts @classmethod - def from_empty_root(cls, empty_root: EnrichedDiffRootEmpty) -> EnrichedDiffRoot: + def from_empty_root(cls, empty_root: EnrichedDiffRootMetadata) -> EnrichedDiffRoot: return EnrichedDiffRoot(**asdict(empty_root)) @classmethod @@ -522,11 
+541,11 @@ def add_parent( @dataclass -class EnrichedDiffsEmpty: +class EnrichedDiffsMetadata: base_branch_name: str diff_branch_name: str - base_branch_diff: EnrichedDiffRootEmpty - diff_branch_diff: EnrichedDiffRootEmpty + base_branch_diff: EnrichedDiffRootMetadata + diff_branch_diff: EnrichedDiffRootMetadata def __repr__(self) -> str: return ( @@ -539,9 +558,31 @@ def __repr__(self) -> str: f"to_time={self.diff_branch_diff.to_time})" ) + def update_metadata( + self, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + tracking_id: TrackingId | None = None, + ) -> bool: + is_changed = self.base_branch_diff.update_metadata( + from_time=from_time, to_time=to_time, tracking_id=tracking_id + ) + is_changed |= self.diff_branch_diff.update_metadata( + from_time=from_time, to_time=to_time, tracking_id=tracking_id + ) + return is_changed + + def set_fresh_uuids(self) -> None: + base_uuid = str(uuid4()) + branch_uuid = str(uuid4()) + self.base_branch_diff.uuid = base_uuid + self.base_branch_diff.partner_uuid = branch_uuid + self.diff_branch_diff.uuid = branch_uuid + self.diff_branch_diff.partner_uuid = base_uuid + @dataclass -class EnrichedDiffs(EnrichedDiffsEmpty): +class EnrichedDiffs(EnrichedDiffsMetadata): base_branch_diff: EnrichedDiffRoot diff_branch_diff: EnrichedDiffRoot @@ -575,6 +616,10 @@ def from_calculated_diffs(cls, calculated_diffs: CalculatedDiffs) -> EnrichedDif diff_branch_diff=diff_branch_diff, ) + @property + def is_empty(self) -> bool: + return len(self.base_branch_diff.nodes) == 0 and len(self.diff_branch_diff.nodes) == 0 + @dataclass class CalculatedDiffs: diff --git a/backend/infrahub/core/diff/query/field_specifiers.py b/backend/infrahub/core/diff/query/field_specifiers.py new file mode 100644 index 0000000000..1a4c312539 --- /dev/null +++ b/backend/infrahub/core/diff/query/field_specifiers.py @@ -0,0 +1,33 @@ +from typing import Any, Generator + +from infrahub.core.query import Query, QueryType +from infrahub.database 
import InfrahubDatabase + + +class EnrichedDiffFieldSpecifiersQuery(Query): + name = "enriched_diff_field_specifiers" + type = QueryType.READ + insert_return = False + + def __init__(self, diff_id: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.diff_id = diff_id + + async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: + self.params = {"diff_id": self.diff_id} + query = """ +MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(attr:DiffAttribute) +RETURN node.uuid AS node_uuid, attr.name AS field_name +UNION +MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(rel:DiffRelationship) +RETURN node.uuid AS node_uuid, rel.identifier AS field_name +""" + self.return_labels = ["node_uuid", "field_name"] + self.add_to_query(query=query) + + def get_node_field_specifier_tuples(self) -> Generator[tuple[str, str], None, None]: + for result in self.get_results(): + node_uuid = result.get_as_str("node_uuid") + field_name = result.get_as_str("field_name") + if node_uuid and field_name: + yield (node_uuid, field_name) diff --git a/backend/infrahub/core/diff/repository/deserializer.py b/backend/infrahub/core/diff/repository/deserializer.py index 14520e89e0..59df0cabe9 100644 --- a/backend/infrahub/core/diff/repository/deserializer.py +++ b/backend/infrahub/core/diff/repository/deserializer.py @@ -16,7 +16,7 @@ EnrichedDiffProperty, EnrichedDiffRelationship, EnrichedDiffRoot, - EnrichedDiffRootEmpty, + EnrichedDiffRootMetadata, EnrichedDiffSingleRelationship, deserialize_tracking_id, ) @@ -157,14 +157,14 @@ def _deserialize_diff_root(self, root_node: Neo4jNode) -> EnrichedDiffRoot: return enriched_root @classmethod - def build_diff_root_empty(cls, root_node: Neo4jNode) -> EnrichedDiffRootEmpty: + def build_diff_root_empty(cls, root_node: Neo4jNode) -> EnrichedDiffRootMetadata: from_time = Timestamp(str(root_node.get("from_time"))) to_time = 
Timestamp(str(root_node.get("to_time"))) tracking_id_str = cls._get_str_or_none_property_value(node=root_node, property_name="tracking_id") tracking_id = None if tracking_id_str: tracking_id = deserialize_tracking_id(tracking_id_str=tracking_id_str) - return EnrichedDiffRootEmpty( + return EnrichedDiffRootMetadata( base_branch_name=str(root_node.get("base_branch")), diff_branch_name=str(root_node.get("diff_branch")), from_time=from_time, diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index 8fcd125db9..a7739097e0 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -12,11 +12,12 @@ ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot, - EnrichedDiffRootEmpty, + EnrichedDiffRootMetadata, EnrichedDiffs, - EnrichedDiffsEmpty, + EnrichedDiffsMetadata, EnrichedNodeCreateRequest, NodeDiffFieldSummary, + NodeFieldSpecifier, TimeRange, TrackingId, ) @@ -25,6 +26,7 @@ from ..query.diff_summary import DiffSummaryCounters, DiffSummaryQuery from ..query.drop_tracking_id import EnrichedDiffDropTrackingIdQuery from ..query.empty_roots import EnrichedDiffEmptyRootsQuery +from ..query.field_specifiers import EnrichedDiffFieldSpecifiersQuery from ..query.filters import EnrichedDiffQueryFilters from ..query.get_conflict_query import EnrichedDiffConflictQuery from ..query.has_conflicts_query import EnrichedDiffHasConflictQuery @@ -121,16 +123,16 @@ async def get_pairs( for dbr in diff_branch_roots ] - async def hydrate_diff_pair(self, enriched_diffs: EnrichedDiffsEmpty) -> EnrichedDiffs: + async def hydrate_diff_pair(self, enriched_diffs_empty: EnrichedDiffsMetadata) -> EnrichedDiffs: hydrated_base_diff = await self.get_one( - diff_branch_name=enriched_diffs.base_branch_name, diff_id=enriched_diffs.base_branch_diff.uuid + diff_branch_name=enriched_diffs_empty.base_branch_name, diff_id=enriched_diffs_empty.base_branch_diff.uuid ) 
hydrated_branch_diff = await self.get_one( - diff_branch_name=enriched_diffs.diff_branch_name, diff_id=enriched_diffs.diff_branch_diff.uuid + diff_branch_name=enriched_diffs_empty.diff_branch_name, diff_id=enriched_diffs_empty.diff_branch_diff.uuid ) return EnrichedDiffs( - base_branch_name=enriched_diffs.base_branch_name, - diff_branch_name=enriched_diffs.diff_branch_name, + base_branch_name=enriched_diffs_empty.base_branch_name, + diff_branch_name=enriched_diffs_empty.diff_branch_name, base_branch_diff=hydrated_base_diff, diff_branch_diff=hydrated_branch_diff, ) @@ -234,7 +236,7 @@ async def get_empty_diff_pairs( base_branch_names: list[str] | None = None, from_time: Timestamp | None = None, to_time: Timestamp | None = None, - ) -> list[EnrichedDiffsEmpty]: + ) -> list[EnrichedDiffsMetadata]: if diff_branch_names and base_branch_names: diff_branch_names += base_branch_names empty_roots = await self.get_empty_roots( @@ -244,13 +246,13 @@ async def get_empty_diff_pairs( to_time=to_time, ) roots_by_id = {root.uuid: root for root in empty_roots} - pairs: list[EnrichedDiffsEmpty] = [] + pairs: list[EnrichedDiffsMetadata] = [] for branch_root in empty_roots: if branch_root.base_branch_name == branch_root.diff_branch_name: continue base_root = roots_by_id[branch_root.partner_uuid] pairs.append( - EnrichedDiffsEmpty( + EnrichedDiffsMetadata( base_branch_name=branch_root.base_branch_name, diff_branch_name=branch_root.diff_branch_name, base_branch_diff=base_root, @@ -265,7 +267,7 @@ async def get_empty_roots( base_branch_names: list[str] | None = None, from_time: Timestamp | None = None, to_time: Timestamp | None = None, - ) -> list[EnrichedDiffRootEmpty]: + ) -> list[EnrichedDiffRootMetadata]: query = await EnrichedDiffEmptyRootsQuery.init( db=self.db, diff_branch_names=diff_branch_names, @@ -328,3 +330,16 @@ async def get_num_changes_in_time_range_by_branch( query = await DiffCountChanges.init(db=self.db, branch_names=branch_names, diff_from=from_time, diff_to=to_time) 
await query.execute(db=self.db) return query.get_num_changes_by_branch() + + async def get_node_field_specifiers(self, diff_id: str) -> set[NodeFieldSpecifier]: + query = await EnrichedDiffFieldSpecifiersQuery.init(db=self.db, diff_id=diff_id) + await query.execute(db=self.db) + specifiers: set[NodeFieldSpecifier] = set() + specifiers.update( + NodeFieldSpecifier( + node_uuid=field_specifier_tuple[0], + field_name=field_specifier_tuple[1], + ) + for field_specifier_tuple in query.get_node_field_specifier_tuples() + ) + return specifiers diff --git a/backend/tests/integration/diff/test_diff_incremental_addition.py b/backend/tests/integration/diff/test_diff_incremental_addition.py index 40909ede9b..71abada137 100644 --- a/backend/tests/integration/diff/test_diff_incremental_addition.py +++ b/backend/tests/integration/diff/test_diff_incremental_addition.py @@ -273,7 +273,7 @@ async def test_update_previous_owner_on_branch( base_branch=default_branch, diff_branch=diff_branch, from_time=incremental_diff.from_time, - to_time=incremental_diff.to_time, + to_time=Timestamp(), name=str(uuid4()), ) await self.validate_diff_data_02(db=db, enriched_diff=full_diff, initial_dataset=initial_dataset) diff --git a/backend/tests/unit/core/diff/test_coordinator_lock.py b/backend/tests/unit/core/diff/test_coordinator_lock.py index 045fb74cf7..74f748859c 100644 --- a/backend/tests/unit/core/diff/test_coordinator_lock.py +++ b/backend/tests/unit/core/diff/test_coordinator_lock.py @@ -1,4 +1,5 @@ import asyncio +from datetime import timedelta from unittest.mock import AsyncMock from uuid import uuid4 @@ -78,16 +79,16 @@ async def test_arbitrary_diff_locks_queue_up( ) assert len(results) == 2 assert results[0].to_time != results[1].to_time - assert results[0].uuid != results[1].uuid - assert results[0].partner_uuid != results[1].partner_uuid + assert results[0].uuid == results[1].uuid + assert results[0].partner_uuid == results[1].partner_uuid results[0].to_time = results[1].to_time - 
results[0].uuid = results[1].uuid - results[0].partner_uuid = results[1].partner_uuid assert results[0] == results[1] - # called once to calculate diff on main and once to calculate diff on the branch - assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 2 - # not called because diffs are calculated both times - diff_coordinator.diff_repo.get_one.assert_not_awaited() + # second diff uses first diff for its data and is not calculated + assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 1 + # confirm that we retrieve the first diff to use when calculating the second, overlapping diff + diff_coordinator.diff_repo.get_one.assert_called_once_with( + diff_branch_name=diff_branch.name, diff_id=results[0].uuid + ) async def test_arbitrary_diff_blocks_incremental_diff( self, db: InfrahubDatabase, default_branch: Branch, branch_with_data: Branch @@ -106,18 +107,18 @@ async def test_arbitrary_diff_blocks_incremental_diff( ) assert len(results) == 2 assert results[0].to_time != results[1].to_time - assert results[0].uuid != results[1].uuid - assert results[0].partner_uuid != results[1].partner_uuid + assert results[0].uuid == results[1].uuid + assert results[0].partner_uuid == results[1].partner_uuid assert results[0].tracking_id != results[1].tracking_id results[0].to_time = results[1].to_time - results[0].uuid = results[1].uuid - results[0].partner_uuid = results[1].partner_uuid results[0].tracking_id = results[1].tracking_id assert results[0] == results[1] - # called once to calculate diff on main and once to calculate diff on the branch - assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 2 - # not called because diffs are calculated both times - diff_coordinator.diff_repo.get_one.assert_not_awaited() + # second diff uses first diff for its data and is not calculated + assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 1 + # confirm that we retrieve the first diff to 
use when calculating the second, overlapping diff + diff_coordinator.diff_repo.get_one.assert_called_once_with( + diff_branch_name=diff_branch.name, diff_id=results[0].uuid + ) async def test_incremental_diff_blocks_arbitrary_diff( self, db: InfrahubDatabase, default_branch: Branch, branch_with_data: Branch @@ -125,26 +126,28 @@ async def test_incremental_diff_blocks_arbitrary_diff( diff_branch = branch_with_data diff_coordinator = await self.get_diff_coordinator(db=db, diff_branch=diff_branch) + arbitrary_to_time = Timestamp() + arbitrary_to_time.obj += timedelta(seconds=5) results = await asyncio.gather( - diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=diff_branch), + diff_coordinator.update_branch_diff_and_return(base_branch=default_branch, diff_branch=diff_branch), diff_coordinator.create_or_update_arbitrary_timeframe_diff( base_branch=default_branch, diff_branch=diff_branch, from_time=Timestamp(branch_with_data.branched_from), - to_time=Timestamp(), + to_time=arbitrary_to_time, ), ) assert len(results) == 2 assert results[0].to_time != results[1].to_time - assert results[0].uuid != results[1].uuid - assert results[0].partner_uuid != results[1].partner_uuid + assert results[0].uuid == results[1].uuid + assert results[0].partner_uuid == results[1].partner_uuid assert results[0].tracking_id != results[1].tracking_id results[0].to_time = results[1].to_time - results[0].uuid = results[1].uuid - results[0].partner_uuid = results[1].partner_uuid results[0].tracking_id = results[1].tracking_id assert results[0] == results[1] - # called once to calculate diff on main and once to calculate diff on the branch - assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) == 2 - # not called because diffs are calculated both times - diff_coordinator.diff_repo.get_one.assert_not_awaited() + # second diff uses first diff for its data and is not calculated + assert len(diff_coordinator.diff_calculator.calculate_diff.call_args_list) 
== 1 + # confirm that we retrieve the first diff to use when calculating the second, overlapping diff + diff_coordinator.diff_repo.get_one.assert_called_once_with( + diff_branch_name=diff_branch.name, diff_id=results[0].uuid + ) diff --git a/backend/tests/unit/core/diff/test_diff_and_merge.py b/backend/tests/unit/core/diff/test_diff_and_merge.py index 4b203d1d5b..19f2f1febd 100644 --- a/backend/tests/unit/core/diff/test_diff_and_merge.py +++ b/backend/tests/unit/core/diff/test_diff_and_merge.py @@ -576,7 +576,9 @@ async def test_branch_delete_with_added_base_relationship( await car_main.save(db=db) # check that the conflict is removed - enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff_and_return( + base_branch=default_branch, diff_branch=branch2 + ) conflicts_map = enriched_diff.get_all_conflicts() assert len(conflicts_map) == 0 diff --git a/backend/tests/unit/core/diff/test_diff_combiner.py b/backend/tests/unit/core/diff/test_diff_combiner.py index 9e2bd7d709..2d4088c168 100644 --- a/backend/tests/unit/core/diff/test_diff_combiner.py +++ b/backend/tests/unit/core/diff/test_diff_combiner.py @@ -450,6 +450,7 @@ async def test_relationship_one_combined(self, with_schema_manager): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=RelationshipCardinality.ONE, changed_at=later_relationship.changed_at, action=DiffAction.ADDED, @@ -618,6 +619,7 @@ async def test_relationship_many_combined(self, with_schema_manager): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=relationship_group_2.label, + identifier=relationship_group_2.identifier, cardinality=RelationshipCardinality.MANY, changed_at=relationship_group_2.changed_at, action=DiffAction.UPDATED, @@ -682,6 +684,7 @@ async def test_relationship_with_only_nodes(self, 
with_schema_manager): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=RelationshipCardinality.MANY, changed_at=later_relationship.changed_at, action=DiffAction.ADDED, @@ -851,6 +854,7 @@ async def test_unchanged_parents_correctly_updated(self): expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=parent_rel_2.label, + identifier=parent_rel_2.identifier, changed_at=parent_rel_2.changed_at, cardinality=RelationshipCardinality.ONE, path_identifier=parent_rel_2.path_identifier, @@ -923,6 +927,7 @@ async def test_updated_parents_correctly_updated(self): expected_child_rel = EnrichedDiffRelationship( name=relationship_name, label=child_rel_2.label, + identifier=child_rel_2.identifier, changed_at=child_rel_2.changed_at, cardinality=RelationshipCardinality.ONE, path_identifier=child_rel_2.path_identifier, @@ -1083,6 +1088,7 @@ async def test_resetting_relationship_one_makes_it_unchanged(self, with_schema_m expected_relationship = EnrichedDiffRelationship( name=relationship_name, label=later_relationship.label, + identifier=later_relationship.identifier, cardinality=RelationshipCardinality.ONE, changed_at=later_relationship.changed_at, action=DiffAction.UPDATED, From bd7e561da3e7f3d31d9ef0a0dbbfa7b6221471d0 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 2 Jan 2025 15:11:49 -0800 Subject: [PATCH 09/18] deal with CoreDataCheck syncing --- backend/infrahub/core/diff/coordinator.py | 28 +++++++++------ .../core/diff/data_check_synchronizer.py | 34 ++++++++++++++++--- .../object_conflict/conflict_recorder.py | 7 +++- .../builder/diff/data_check_synchronizer.py | 2 ++ 4 files changed, 56 insertions(+), 15 deletions(-) diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 15af7252e8..117d8aa9f1 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ 
b/backend/infrahub/core/diff/coordinator.py @@ -163,6 +163,7 @@ async def update_branch_diff( force_branch_refresh=False, ) if not isinstance(enriched_diffs, EnrichedDiffs): + await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) return enriched_diffs.diff_branch_diff await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) @@ -197,6 +198,7 @@ async def create_or_update_arbitrary_timeframe_diff( force_branch_refresh=False, ) if not isinstance(enriched_diffs, EnrichedDiffs): + await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) return await self._finalize_diff_root_metadata(diff_root_metadata=enriched_diffs.diff_branch_diff) await self.summary_counts_enricher.enrich(enriched_diff_root=enriched_diffs.base_branch_diff) @@ -330,7 +332,6 @@ async def _update_diffs( tracking_id: TrackingId | None = None, force_branch_refresh: bool = False, ) -> EnrichedDiffs | EnrichedDiffsMetadata: - diff_uuids_to_delete = [] # start with empty diffs b/c we only care about their metadata for now, hydrate them with data as needed empty_diff_pairs = await self.diff_repo.get_empty_diff_pairs( base_branch_names=[base_branch.name], @@ -338,12 +339,6 @@ async def _update_diffs( from_time=from_time, to_time=to_time, ) - if tracking_id: - for diff_pair in empty_diff_pairs: - if diff_pair.base_branch_diff.tracking_id == tracking_id: - diff_uuids_to_delete.append(diff_pair.base_branch_diff.uuid) - if diff_pair.diff_branch_diff.tracking_id == tracking_id: - diff_uuids_to_delete.append(diff_pair.diff_branch_diff.uuid) aggregated_enriched_diffs = await self._aggregate_enriched_diffs( diff_request=EnrichedDiffRequest( base_branch=base_branch, @@ -354,9 +349,22 @@ async def _update_diffs( ), partial_enriched_diffs=empty_diff_pairs if not force_branch_refresh else None, ) + if tracking_id: + diff_uuids_to_delete: list[str] = [] + for diff_pair in empty_diff_pairs: + if ( + 
diff_pair.base_branch_diff.tracking_id == tracking_id + and diff_pair.base_branch_diff.uuid != aggregated_enriched_diffs.base_branch_diff.uuid + ): + diff_uuids_to_delete.append(diff_pair.base_branch_diff.uuid) + if ( + diff_pair.diff_branch_diff.tracking_id == tracking_id + and diff_pair.diff_branch_diff.uuid != aggregated_enriched_diffs.diff_branch_diff.uuid + ): + diff_uuids_to_delete.append(diff_pair.diff_branch_diff.uuid) - if diff_uuids_to_delete: - await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) + if diff_uuids_to_delete: + await self.diff_repo.delete_diff_roots(diff_root_uuids=diff_uuids_to_delete) # this is an EnrichedDiffsMetadata, so there are no nodes to enrich if not isinstance(aggregated_enriched_diffs, EnrichedDiffs): @@ -519,7 +527,7 @@ async def _concatenate_diffs_and_requests( return previous_diff_pair - async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: + async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMetadata) -> list[Node]: return await self.data_check_synchronizer.synchronize(enriched_diff=enriched_diff) async def _calculate_enriched_diff( diff --git a/backend/infrahub/core/diff/data_check_synchronizer.py b/backend/infrahub/core/diff/data_check_synchronizer.py index 624fe31597..3b3f3c0436 100644 --- a/backend/infrahub/core/diff/data_check_synchronizer.py +++ b/backend/infrahub/core/diff/data_check_synchronizer.py @@ -9,7 +9,9 @@ from infrahub.proposed_change.constants import ProposedChangeState from .conflicts_extractor import DiffConflictsExtractor -from .model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot +from .model.diff import DataConflict +from .model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot, EnrichedDiffRootMetadata +from .repository.repository import DiffRepository class DiffDataCheckSynchronizer: @@ -18,12 +20,28 @@ def __init__( db: InfrahubDatabase, conflicts_extractor: 
DiffConflictsExtractor, conflict_recorder: ObjectConflictValidatorRecorder, + diff_repository: DiffRepository, ): self.db = db self.conflicts_extractor = conflicts_extractor self.conflict_recorder = conflict_recorder + self.diff_repository = diff_repository + self._enriched_conflicts_map: dict[str, EnrichedDiffConflict] | None = None + self._data_conflicts: list[DataConflict] | None = None - async def synchronize(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: + def _get_enriched_conflicts_map(self, enriched_diff: EnrichedDiffRoot) -> dict[str, EnrichedDiffConflict]: + if self._enriched_conflicts_map is None: + self._enriched_conflicts_map = enriched_diff.get_all_conflicts() + return self._enriched_conflicts_map + + async def _get_data_conflicts(self, enriched_diff: EnrichedDiffRoot) -> list[DataConflict]: + if self._data_conflicts is None: + self._data_conflicts = await self.conflicts_extractor.get_data_conflicts(enriched_diff_root=enriched_diff) + return self._data_conflicts + + async def synchronize(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMetadata) -> list[Node]: + self._enriched_conflicts_map = None + self._data_conflicts = None try: proposed_changes = await NodeManager.query( db=self.db, @@ -35,10 +53,18 @@ async def synchronize(self, enriched_diff: EnrichedDiffRoot) -> list[Node]: proposed_changes = [] if not proposed_changes: return [] - enriched_conflicts_map = enriched_diff.get_all_conflicts() - data_conflicts = await self.conflicts_extractor.get_data_conflicts(enriched_diff_root=enriched_diff) all_data_checks = [] for pc in proposed_changes: + if not isinstance(enriched_diff, EnrichedDiffRoot): + has_validator = bool(await self.conflict_recorder.get_validator(proposed_change=pc)) + if has_validator: + continue + enriched_diff = await self.diff_repository.get_one( + diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid + ) + + data_conflicts = await self._get_data_conflicts(enriched_diff=enriched_diff) + 
enriched_conflicts_map = self._get_enriched_conflicts_map(enriched_diff=enriched_diff) core_data_checks = await self.conflict_recorder.record_conflicts( proposed_change_id=pc.get_id(), conflicts=data_conflicts ) diff --git a/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py b/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py index eaf3d71eb0..889854825b 100644 --- a/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py +++ b/backend/infrahub/core/integrity/object_conflict/conflict_recorder.py @@ -94,13 +94,18 @@ async def record_conflicts(self, proposed_change_id: str, conflicts: Sequence[Ob await self.finalize_validator(validator, is_success) return current_checks - async def get_or_create_validator(self, proposed_change: CoreProposedChange) -> Node: + async def get_validator(self, proposed_change: CoreProposedChange) -> Node | None: validations = await proposed_change.validations.get_peers(db=self.db, branch_agnostic=True) for validation in validations.values(): if validation.get_kind() == self.validator_kind: return validation + return None + async def get_or_create_validator(self, proposed_change: CoreProposedChange) -> Node: + validator_obj = await self.get_validator(proposed_change=proposed_change) + if validator_obj: + return validator_obj validator_obj = await Node.init(db=self.db, schema=self.validator_kind) await validator_obj.new( db=self.db, diff --git a/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py b/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py index 3250724e0a..427e5e22a8 100644 --- a/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py +++ b/backend/infrahub/dependencies/builder/diff/data_check_synchronizer.py @@ -3,6 +3,7 @@ from .conflicts_extractor import DiffConflictsExtractorDependency from .data_check_conflict_recorder import DataCheckConflictRecorderDependency +from .repository import DiffRepositoryDependency class 
DiffDataCheckSynchronizerDependency(DependencyBuilder[DiffDataCheckSynchronizer]): @@ -12,4 +13,5 @@ def build(cls, context: DependencyBuilderContext) -> DiffDataCheckSynchronizer: db=context.db, conflicts_extractor=DiffConflictsExtractorDependency.build(context=context), conflict_recorder=DataCheckConflictRecorderDependency.build(context=context), + diff_repository=DiffRepositoryDependency.build(context=context), ) From 37457980102cfb2703d28d3dee693a740e9af349 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 2 Jan 2025 15:46:21 -0800 Subject: [PATCH 10/18] update some method and variable names --- backend/infrahub/core/diff/coordinator.py | 6 +++--- backend/infrahub/core/diff/repository/repository.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 117d8aa9f1..e058943ed6 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -333,7 +333,7 @@ async def _update_diffs( force_branch_refresh: bool = False, ) -> EnrichedDiffs | EnrichedDiffsMetadata: # start with empty diffs b/c we only care about their metadata for now, hydrate them with data as needed - empty_diff_pairs = await self.diff_repo.get_empty_diff_pairs( + diff_pairs_metadata = await self.diff_repo.get_diff_pairs_metadata( base_branch_names=[base_branch.name], diff_branch_names=[diff_branch.name], from_time=from_time, @@ -347,11 +347,11 @@ async def _update_diffs( to_time=to_time, tracking_id=tracking_id, ), - partial_enriched_diffs=empty_diff_pairs if not force_branch_refresh else None, + partial_enriched_diffs=diff_pairs_metadata if not force_branch_refresh else None, ) if tracking_id: diff_uuids_to_delete: list[str] = [] - for diff_pair in empty_diff_pairs: + for diff_pair in diff_pairs_metadata: if ( diff_pair.base_branch_diff.tracking_id == tracking_id and diff_pair.base_branch_diff.uuid != 
aggregated_enriched_diffs.base_branch_diff.uuid diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index a7739097e0..5077dedf03 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -230,7 +230,7 @@ async def get_time_ranges( await query.execute(db=self.db) return await query.get_time_ranges() - async def get_empty_diff_pairs( + async def get_diff_pairs_metadata( self, diff_branch_names: list[str] | None = None, base_branch_names: list[str] | None = None, From 75176d212f6e3d76069d1b732b95ea2430f1b93e Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Thu, 2 Jan 2025 18:26:48 -0800 Subject: [PATCH 11/18] unit tests, little more refactoring --- backend/infrahub/core/diff/coordinator.py | 39 +++--- .../core/diff/repository/repository.py | 12 +- .../test_propose_change_repository.py | 2 - .../tests/unit/core/diff/test_coordinator.py | 122 +++++++++++++++++- 4 files changed, 149 insertions(+), 26 deletions(-) diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index e058943ed6..97de1d84a8 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -506,27 +506,30 @@ async def _concatenate_diffs_and_requests( previous_diff_pair = single_enriched_diffs continue - if isinstance(single_enriched_diffs, EnrichedDiffs) and single_enriched_diffs.is_empty: - if previous_diff_pair: - previous_diff_pair.base_branch_diff.to_time = single_enriched_diffs.base_branch_diff.to_time - previous_diff_pair.diff_branch_diff.to_time = single_enriched_diffs.diff_branch_diff.to_time - else: - previous_diff_pair = single_enriched_diffs - continue - - if not isinstance(previous_diff_pair, EnrichedDiffs): - previous_diff_pair = await self.diff_repo.hydrate_diff_pair(enriched_diffs_empty=previous_diff_pair) - if not isinstance(single_enriched_diffs, EnrichedDiffs): - 
single_enriched_diffs = await self.diff_repo.hydrate_diff_pair( - enriched_diffs_empty=single_enriched_diffs - ) - - previous_diff_pair = await self.diff_combiner.combine( - earlier_diffs=previous_diff_pair, later_diffs=single_enriched_diffs - ) + previous_diff_pair = await self._combine_diffs(earlier=previous_diff_pair, later=single_enriched_diffs) return previous_diff_pair + async def _combine_diffs( + self, earlier: EnrichedDiffs | EnrichedDiffsMetadata, later: EnrichedDiffs | EnrichedDiffsMetadata + ) -> EnrichedDiffs | EnrichedDiffsMetadata: + # if one of the diffs is hydrated and has no data, we can combine them without hydrating the other + if isinstance(earlier, EnrichedDiffs) and earlier.is_empty: + later.base_branch_diff.from_time = earlier.base_branch_diff.from_time + later.diff_branch_diff.from_time = earlier.diff_branch_diff.from_time + return later + if isinstance(later, EnrichedDiffs) and later.is_empty: + earlier.base_branch_diff.to_time = later.base_branch_diff.to_time + earlier.diff_branch_diff.to_time = later.diff_branch_diff.to_time + return earlier + + if not isinstance(earlier, EnrichedDiffs): + earlier = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=earlier) + if not isinstance(later, EnrichedDiffs): + later = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=later) + + return await self.diff_combiner.combine(earlier_diffs=earlier, later_diffs=later) + async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMetadata) -> list[Node]: return await self.data_check_synchronizer.synchronize(enriched_diff=enriched_diff) diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index 5077dedf03..b87d211205 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -123,16 +123,18 @@ async def get_pairs( for dbr in diff_branch_roots ] - async def 
hydrate_diff_pair(self, enriched_diffs_empty: EnrichedDiffsMetadata) -> EnrichedDiffs: + async def hydrate_diff_pair(self, enriched_diffs_metadata: EnrichedDiffsMetadata) -> EnrichedDiffs: hydrated_base_diff = await self.get_one( - diff_branch_name=enriched_diffs_empty.base_branch_name, diff_id=enriched_diffs_empty.base_branch_diff.uuid + diff_branch_name=enriched_diffs_metadata.base_branch_name, + diff_id=enriched_diffs_metadata.base_branch_diff.uuid, ) hydrated_branch_diff = await self.get_one( - diff_branch_name=enriched_diffs_empty.diff_branch_name, diff_id=enriched_diffs_empty.diff_branch_diff.uuid + diff_branch_name=enriched_diffs_metadata.diff_branch_name, + diff_id=enriched_diffs_metadata.diff_branch_diff.uuid, ) return EnrichedDiffs( - base_branch_name=enriched_diffs_empty.base_branch_name, - diff_branch_name=enriched_diffs_empty.diff_branch_name, + base_branch_name=enriched_diffs_metadata.base_branch_name, + diff_branch_name=enriched_diffs_metadata.diff_branch_name, base_branch_diff=hydrated_base_diff, diff_branch_diff=hydrated_branch_diff, ) diff --git a/backend/tests/integration_docker/test_propose_change_repository.py b/backend/tests/integration_docker/test_propose_change_repository.py index 36984e2c20..96d88eddbf 100644 --- a/backend/tests/integration_docker/test_propose_change_repository.py +++ b/backend/tests/integration_docker/test_propose_change_repository.py @@ -84,5 +84,3 @@ async def test_create_propose_change(self, client: InfrahubClient, default_branc kind=CoreProposedChange, name="pc1", source_branch=branch.name, destination_branch=default_branch ) await pc.save() - - # breakpoint() diff --git a/backend/tests/unit/core/diff/test_coordinator.py b/backend/tests/unit/core/diff/test_coordinator.py index 98a129ad7c..7bee68ca06 100644 --- a/backend/tests/unit/core/diff/test_coordinator.py +++ b/backend/tests/unit/core/diff/test_coordinator.py @@ -1,8 +1,13 @@ +from typing import Any +from unittest.mock import AsyncMock, call + from 
infrahub.core.branch import Branch from infrahub.core.constants import DiffAction from infrahub.core.constants.database import DatabaseEdgeType +from infrahub.core.diff.calculator import DiffCalculator +from infrahub.core.diff.combiner import DiffCombiner from infrahub.core.diff.coordinator import DiffCoordinator -from infrahub.core.diff.model.path import BranchTrackingId +from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRootMetadata, NodeFieldSpecifier from infrahub.core.diff.repository.repository import DiffRepository from infrahub.core.initialization import create_branch from infrahub.core.manager import NodeManager @@ -13,6 +18,27 @@ class TestDiffCoordinator: + async def get_wrapped_diff_coordinator( + self, + db: InfrahubDatabase, + branch: Branch, + ) -> DiffCoordinator: + component_registry = get_component_registry() + diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch) + real_repository = await component_registry.get_component(DiffRepository, db=db, branch=branch) + real_calculator = await component_registry.get_component(DiffCalculator, db=db, branch=branch) + real_combiner = await component_registry.get_component(DiffCombiner, db=db, branch=branch) + diff_coordinator.diff_repo = AsyncMock(wraps=real_repository) + diff_coordinator.diff_calculator = AsyncMock(wraps=real_calculator) + diff_coordinator.diff_combiner = AsyncMock(wraps=real_combiner) + return diff_coordinator + + def reset_mocks(self, reset_it: Any) -> None: + for attr_name in dir(reset_it): + attr = getattr(reset_it, attr_name) + if isinstance(attr, AsyncMock): + attr.reset_mock() + async def test_node_deleted_after_branching( self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node ): @@ -109,3 +135,97 @@ async def test_overlapping_diffs(self, db: InfrahubDatabase, default_branch: Bra assert diff_property.action is DiffAction.UPDATED assert diff_property.previous_value == str(original_height) assert 
diff_property.new_value == "3" + + async def test_no_changes_skips_expensive_operations( + self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node + ): + branch = await create_branch(db=db, branch_name="branch") + wrapped_diff_coordinator = await self.get_wrapped_diff_coordinator(db=db, branch=branch) + + time1 = Timestamp() + person_john_branch = await NodeManager.get_one(db=db, branch=branch, id=person_john_main.id) + person_john_branch.height.value += 1 + await person_john_branch.save(db=db) + time2 = Timestamp() + + # calculate this diff in the middle of change timeframe + diff_with_data = await wrapped_diff_coordinator.create_or_update_arbitrary_timeframe_diff( + base_branch=default_branch, diff_branch=branch, from_time=time1, to_time=time2 + ) + self.reset_mocks(wrapped_diff_coordinator) + + # get the whole diff with no-change time periods before and after the calculated diff + no_changes_diff = await wrapped_diff_coordinator.update_branch_diff( + base_branch=default_branch, diff_branch=branch + ) + assert type(no_changes_diff) is EnrichedDiffRootMetadata + assert no_changes_diff.uuid == diff_with_data.uuid + assert no_changes_diff.from_time == Timestamp(branch.get_branched_from()) + assert no_changes_diff.from_time < diff_with_data.from_time + assert no_changes_diff.to_time > diff_with_data.to_time + wrapped_diff_coordinator.diff_calculator.calculate_diff.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.get_one.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.save.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.hydrate_diff_pair.assert_not_awaited() + + async def test_unrelated_changes_skip_some_expensive_operations( + self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node + ): + branch = await create_branch(db=db, branch_name="branch") + wrapped_diff_coordinator = await self.get_wrapped_diff_coordinator(db=db, branch=branch) + + # unrelated change on main before + john_main = await 
NodeManager.get_one(db=db, id=person_john_main.id) + john_main.name.value = "Before John" + await john_main.save(db=db) + + # change on branch for the diff + time1 = Timestamp() + person_john_branch = await NodeManager.get_one(db=db, branch=branch, id=person_john_main.id) + person_john_branch.height.value += 1 + await person_john_branch.save(db=db) + time2 = Timestamp() + + # unrelated change on main after + john_main = await NodeManager.get_one(db=db, id=person_john_main.id) + john_main.name.value = "After John" + await john_main.save(db=db) + + # calculate this diff in the middle of change timeframe + diff_with_data = await wrapped_diff_coordinator.create_or_update_arbitrary_timeframe_diff( + base_branch=default_branch, diff_branch=branch, from_time=time1, to_time=time2 + ) + self.reset_mocks(wrapped_diff_coordinator) + + # get the whole diff with no-change time periods before and after the calculated diff + no_changes_diff = await wrapped_diff_coordinator.update_branch_diff( + base_branch=default_branch, diff_branch=branch + ) + assert type(no_changes_diff) is EnrichedDiffRootMetadata + assert no_changes_diff.uuid == diff_with_data.uuid + assert no_changes_diff.from_time == Timestamp(branch.get_branched_from()) + assert no_changes_diff.from_time < diff_with_data.from_time + assert no_changes_diff.to_time > diff_with_data.to_time + wrapped_diff_coordinator.diff_calculator.calculate_diff.assert_has_awaits( + [ + call( + base_branch=default_branch, + diff_branch=branch, + from_time=Timestamp(branch.get_branched_from()), + to_time=diff_with_data.from_time, + include_unchanged=False, + previous_node_specifiers=set(), + ), + call( + base_branch=default_branch, + diff_branch=branch, + from_time=diff_with_data.to_time, + to_time=no_changes_diff.to_time, + include_unchanged=True, + previous_node_specifiers={NodeFieldSpecifier(node_uuid=person_john_branch.id, field_name="height")}, + ), + ] + ) + wrapped_diff_coordinator.diff_repo.get_one.assert_not_awaited() + 
wrapped_diff_coordinator.diff_repo.save.assert_not_awaited() + wrapped_diff_coordinator.diff_repo.hydrate_diff_pair.assert_not_awaited() From 8a7d8394c6ad11de6c4c07a3ec9630efdb237a88 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Fri, 3 Jan 2025 09:02:15 -0800 Subject: [PATCH 12/18] some more comments and naming cleanup --- backend/infrahub/core/diff/coordinator.py | 9 ++-- .../core/diff/data_check_synchronizer.py | 3 ++ backend/infrahub/core/diff/merger/merger.py | 2 +- backend/infrahub/core/diff/model/path.py | 2 +- .../infrahub/core/diff/query/empty_roots.py | 48 ------------------- .../core/diff/repository/deserializer.py | 6 +-- .../core/diff/repository/repository.py | 12 ++--- backend/infrahub/core/diff/tasks.py | 2 +- .../graph/m015_diff_format_update.py | 2 +- .../graph/m016_diff_delete_bug_fix.py | 2 +- .../message_bus/operations/event/branch.py | 2 +- .../tests/unit/core/diff/test_diff_merger.py | 8 ++-- backend/tests/unit/core/ipam/conftest.py | 2 +- 13 files changed, 29 insertions(+), 71 deletions(-) delete mode 100644 backend/infrahub/core/diff/query/empty_roots.py diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 97de1d84a8..a6e5062491 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -197,6 +197,7 @@ async def create_or_update_arbitrary_timeframe_diff( tracking_id=tracking_id, force_branch_refresh=False, ) + # metadata-only diff, so no nodes to enrich if not isinstance(enriched_diffs, EnrichedDiffs): await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff) return await self._finalize_diff_root_metadata(diff_root_metadata=enriched_diffs.diff_branch_diff) @@ -406,6 +407,7 @@ async def _aggregate_enriched_diffs( """ aggregated_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata | None = None if not partial_enriched_diffs: + # no existing diffs to use in calculating this diff, so calculate the whole thing and 
return it aggregated_enriched_diffs = await self._calculate_enriched_diff( diff_request=diff_request, is_incremental_diff=False ) @@ -458,9 +460,9 @@ async def _aggregate_enriched_diffs( if not aggregated_enriched_diffs: return self._build_enriched_diffs_with_no_nodes(diff_request=diff_request) - # a diff exists in the database that covers at least part of this time period, - # but it might need to have its start or end time extended to cover time ranges - # with no changes + # metadata-only diff, means that a diff exists in the database that covers at least + # part of this time period, but it might need to have its start or end time extended + # to cover time ranges with no changes if not isinstance(aggregated_enriched_diffs, EnrichedDiffs): return aggregated_enriched_diffs @@ -523,6 +525,7 @@ async def _combine_diffs( earlier.diff_branch_diff.to_time = later.diff_branch_diff.to_time return earlier + # hydrate the diffs to combine, if necessary if not isinstance(earlier, EnrichedDiffs): earlier = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=earlier) if not isinstance(later, EnrichedDiffs): diff --git a/backend/infrahub/core/diff/data_check_synchronizer.py b/backend/infrahub/core/diff/data_check_synchronizer.py index 3b3f3c0436..178f437b8f 100644 --- a/backend/infrahub/core/diff/data_check_synchronizer.py +++ b/backend/infrahub/core/diff/data_check_synchronizer.py @@ -55,10 +55,13 @@ async def synchronize(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMe return [] all_data_checks = [] for pc in proposed_changes: + # if the enriched_diff is EnrichedDiffRootMetadata, then it has no new data in it if not isinstance(enriched_diff, EnrichedDiffRoot): has_validator = bool(await self.conflict_recorder.get_validator(proposed_change=pc)) + # if this pc does not have a validator, then it is a new ProposedChange if has_validator: continue + # if this is a new ProposedChange, we need to hydrate the EnrichedDiffRoot so that we can get the conflicts
from it enriched_diff = await self.diff_repository.get_one( diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid ) diff --git a/backend/infrahub/core/diff/merger/merger.py b/backend/infrahub/core/diff/merger/merger.py index 02afb698da..fc858cd744 100644 --- a/backend/infrahub/core/diff/merger/merger.py +++ b/backend/infrahub/core/diff/merger/merger.py @@ -31,7 +31,7 @@ def __init__( self.serializer = serializer async def merge_graph(self, at: Timestamp) -> None: - enriched_diffs = await self.diff_repository.get_empty_roots( + enriched_diffs = await self.diff_repository.get_roots_metadata( diff_branch_names=[self.source_branch.name], base_branch_names=[self.destination_branch.name] ) latest_diff = None diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index 22bb5259bd..eb83a10c83 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -479,7 +479,7 @@ def get_all_conflicts(self) -> dict[str, EnrichedDiffConflict]: return all_conflicts @classmethod - def from_empty_root(cls, empty_root: EnrichedDiffRootMetadata) -> EnrichedDiffRoot: + def from_root_metadata(cls, empty_root: EnrichedDiffRootMetadata) -> EnrichedDiffRoot: return EnrichedDiffRoot(**asdict(empty_root)) @classmethod diff --git a/backend/infrahub/core/diff/query/empty_roots.py b/backend/infrahub/core/diff/query/empty_roots.py deleted file mode 100644 index 688fc2fb03..0000000000 --- a/backend/infrahub/core/diff/query/empty_roots.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Any, Generator - -from neo4j.graph import Node as Neo4jNode - -from infrahub.core.query import Query, QueryType -from infrahub.core.timestamp import Timestamp -from infrahub.database import InfrahubDatabase - - -class EnrichedDiffEmptyRootsQuery(Query): - name = "enriched_diff_empty_roots" - type = QueryType.READ - - def __init__( - self, - diff_branch_names: list[str] | None = None, - base_branch_names: 
list[str] | None = None, - from_time: Timestamp | None = None, - to_time: Timestamp | None = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.diff_branch_names = diff_branch_names - self.base_branch_names = base_branch_names - self.from_time = from_time - self.to_time = to_time - - async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: - self.params = { - "diff_branch_names": self.diff_branch_names, - "base_branch_names": self.base_branch_names, - "from_time": self.from_time.to_string() if self.from_time else None, - "to_time": self.to_time.to_string() if self.to_time else None, - } - - query = """ - MATCH (diff_root:DiffRoot) - WHERE ($diff_branch_names IS NULL OR diff_root.diff_branch IN $diff_branch_names) - AND ($base_branch_names IS NULL OR diff_root.base_branch IN $base_branch_names) - AND ($from_time IS NULL OR diff_root.from_time >= $from_time) - AND ($to_time IS NULL OR diff_root.to_time <= $to_time) - """ - self.return_labels = ["diff_root"] - self.add_to_query(query=query) - - def get_empty_root_nodes(self) -> Generator[Neo4jNode, None, None]: - for result in self.get_results(): - yield result.get_node("diff_root") diff --git a/backend/infrahub/core/diff/repository/deserializer.py b/backend/infrahub/core/diff/repository/deserializer.py index 59df0cabe9..246375dec4 100644 --- a/backend/infrahub/core/diff/repository/deserializer.py +++ b/backend/infrahub/core/diff/repository/deserializer.py @@ -151,13 +151,13 @@ def _deserialize_diff_root(self, root_node: Neo4jNode) -> EnrichedDiffRoot: root_uuid = str(root_node.get("uuid")) if root_uuid in self._diff_root_map: return self._diff_root_map[root_uuid] - root_empty = self.build_diff_root_empty(root_node=root_node) - enriched_root = EnrichedDiffRoot.from_empty_root(empty_root=root_empty) + root_empty = self.build_diff_root_metadata(root_node=root_node) + enriched_root = EnrichedDiffRoot.from_root_metadata(empty_root=root_empty) self._diff_root_map[root_uuid] = 
enriched_root return enriched_root @classmethod - def build_diff_root_empty(cls, root_node: Neo4jNode) -> EnrichedDiffRootMetadata: + def build_diff_root_metadata(cls, root_node: Neo4jNode) -> EnrichedDiffRootMetadata: from_time = Timestamp(str(root_node.get("from_time"))) to_time = Timestamp(str(root_node.get("to_time"))) tracking_id_str = cls._get_str_or_none_property_value(node=root_node, property_name="tracking_id") diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index b87d211205..b6c83c064d 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -25,11 +25,11 @@ from ..query.diff_get import EnrichedDiffGetQuery from ..query.diff_summary import DiffSummaryCounters, DiffSummaryQuery from ..query.drop_tracking_id import EnrichedDiffDropTrackingIdQuery -from ..query.empty_roots import EnrichedDiffEmptyRootsQuery from ..query.field_specifiers import EnrichedDiffFieldSpecifiersQuery from ..query.filters import EnrichedDiffQueryFilters from ..query.get_conflict_query import EnrichedDiffConflictQuery from ..query.has_conflicts_query import EnrichedDiffHasConflictQuery +from ..query.roots_metadata import EnrichedDiffRootsMetadataQuery from ..query.save import EnrichedDiffRootsCreateQuery, EnrichedNodeBatchCreateQuery, EnrichedNodesLinkQuery from ..query.time_range_query import EnrichedDiffTimeRangeQuery from ..query.update_conflict_query import EnrichedDiffConflictUpdateQuery @@ -241,7 +241,7 @@ async def get_diff_pairs_metadata( ) -> list[EnrichedDiffsMetadata]: if diff_branch_names and base_branch_names: diff_branch_names += base_branch_names - empty_roots = await self.get_empty_roots( + empty_roots = await self.get_roots_metadata( diff_branch_names=diff_branch_names, base_branch_names=base_branch_names, from_time=from_time, @@ -263,14 +263,14 @@ async def get_diff_pairs_metadata( ) return pairs - async def get_empty_roots( + 
async def get_roots_metadata( self, diff_branch_names: list[str] | None = None, base_branch_names: list[str] | None = None, from_time: Timestamp | None = None, to_time: Timestamp | None = None, ) -> list[EnrichedDiffRootMetadata]: - query = await EnrichedDiffEmptyRootsQuery.init( + query = await EnrichedDiffRootsMetadataQuery.init( db=self.db, diff_branch_names=diff_branch_names, base_branch_names=base_branch_names, @@ -279,8 +279,8 @@ async def get_empty_roots( ) await query.execute(db=self.db) diff_roots = [] - for neo4j_node in query.get_empty_root_nodes(): - diff_roots.append(self.deserializer.build_diff_root_empty(root_node=neo4j_node)) + for neo4j_node in query.get_root_nodes_metadata(): + diff_roots.append(self.deserializer.build_diff_root_metadata(root_node=neo4j_node)) return diff_roots async def diff_has_conflicts( diff --git a/backend/infrahub/core/diff/tasks.py b/backend/infrahub/core/diff/tasks.py index 7dba986e7c..9301241837 100644 --- a/backend/infrahub/core/diff/tasks.py +++ b/backend/infrahub/core/diff/tasks.py @@ -57,7 +57,7 @@ async def refresh_diff_all(branch_name: str) -> None: component_registry = get_component_registry() default_branch = registry.get_branch_from_registry() diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) - diff_roots_to_refresh = await diff_repository.get_empty_roots(diff_branch_names=[branch_name]) + diff_roots_to_refresh = await diff_repository.get_roots_metadata(diff_branch_names=[branch_name]) for diff_root in diff_roots_to_refresh: if diff_root.base_branch_name != diff_root.diff_branch_name: diff --git a/backend/infrahub/core/migrations/graph/m015_diff_format_update.py b/backend/infrahub/core/migrations/graph/m015_diff_format_update.py index 2fa2e772ea..ba8d3d1b2e 100644 --- a/backend/infrahub/core/migrations/graph/m015_diff_format_update.py +++ b/backend/infrahub/core/migrations/graph/m015_diff_format_update.py @@ -31,6 +31,6 @@ async def execute(self, db: 
InfrahubDatabase) -> MigrationResult: component_registry = get_component_registry() diff_repo = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) - diff_roots = await diff_repo.get_empty_roots() + diff_roots = await diff_repo.get_roots_metadata() await diff_repo.delete_diff_roots(diff_root_uuids=[d.uuid for d in diff_roots]) return MigrationResult() diff --git a/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py b/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py index 57a1bd0550..484ca214e8 100644 --- a/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py +++ b/backend/infrahub/core/migrations/graph/m016_diff_delete_bug_fix.py @@ -31,6 +31,6 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult: component_registry = get_component_registry() diff_repo = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) - diff_roots = await diff_repo.get_empty_roots() + diff_roots = await diff_repo.get_roots_metadata() await diff_repo.delete_diff_roots(diff_root_uuids=[d.uuid for d in diff_roots]) return MigrationResult() diff --git a/backend/infrahub/message_bus/operations/event/branch.py b/backend/infrahub/message_bus/operations/event/branch.py index 32c5d63657..a501da69a7 100644 --- a/backend/infrahub/message_bus/operations/event/branch.py +++ b/backend/infrahub/message_bus/operations/event/branch.py @@ -29,7 +29,7 @@ async def merge(message: messages.EventBranchMerge, service: InfrahubServices) - default_branch = registry.get_branch_from_registry() diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=default_branch) # send diff update requests for every branch-tracking diff - branch_diff_roots = await diff_repository.get_empty_roots(base_branch_names=[message.target_branch]) + branch_diff_roots = await diff_repository.get_roots_metadata(base_branch_names=[message.target_branch]) await 
service.workflow.submit_workflow( workflow=TRIGGER_ARTIFACT_DEFINITION_GENERATE, diff --git a/backend/tests/unit/core/diff/test_diff_merger.py b/backend/tests/unit/core/diff/test_diff_merger.py index 89e7b534c8..5d7b060026 100644 --- a/backend/tests/unit/core/diff/test_diff_merger.py +++ b/backend/tests/unit/core/diff/test_diff_merger.py @@ -326,7 +326,7 @@ async def test_merge_node_added( check_idempotent: bool, ): empty_diff_root.nodes = {added_person_node_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() @@ -380,7 +380,7 @@ async def test_merge_node_deleted( person_branch = await NodeManager.get_one(db=db, branch=source_branch, id=person_node_main.id) await person_branch.delete(db=db) empty_diff_root.nodes = {deleted_person_node_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() @@ -433,7 +433,7 @@ async def test_merge_node_deleted_with_conflict( ) deleted_node_diff.conflict = node_conflict empty_diff_root.nodes = {deleted_node_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() @@ -654,7 +654,7 @@ async def test_merge_node_updated( await car_branch.save(db=db) empty_diff_root.nodes = {updated_person_node_diff, updated_car_diff} - mock_diff_repository.get_empty_roots.return_value = [empty_diff_root] + mock_diff_repository.get_roots_metadata.return_value = [empty_diff_root] mock_diff_repository.get_one.return_value = empty_diff_root at = Timestamp() diff --git a/backend/tests/unit/core/ipam/conftest.py 
b/backend/tests/unit/core/ipam/conftest.py index 9a2776b17d..4efceedb8c 100644 --- a/backend/tests/unit/core/ipam/conftest.py +++ b/backend/tests/unit/core/ipam/conftest.py @@ -191,7 +191,7 @@ async def ip_dataset_01( ): yield ip_dataset_01_load - all_diff_roots = await diff_repository.get_empty_roots() + all_diff_roots = await diff_repository.get_roots_metadata() root_uuids_to_delete = [] for diff_root in all_diff_roots: if start_time <= diff_root.from_time: From 84f4904a1ee0643721ccbefe1b80f5601eb5515a Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Fri, 3 Jan 2025 09:04:29 -0800 Subject: [PATCH 13/18] add changelog --- changelog/+incremental-diff-performance.fixed.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/+incremental-diff-performance.fixed.md diff --git a/changelog/+incremental-diff-performance.fixed.md b/changelog/+incremental-diff-performance.fixed.md new file mode 100644 index 0000000000..005a73a2a5 --- /dev/null +++ b/changelog/+incremental-diff-performance.fixed.md @@ -0,0 +1 @@ +Update how we calculate an incremental diff to skip potentially expensive operations if at all possible \ No newline at end of file From 95a4b32373af8773b557b945e68b8e30b8b2f468 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Fri, 3 Jan 2025 09:07:17 -0800 Subject: [PATCH 14/18] add moved file --- .../core/diff/query/roots_metadata.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 backend/infrahub/core/diff/query/roots_metadata.py diff --git a/backend/infrahub/core/diff/query/roots_metadata.py b/backend/infrahub/core/diff/query/roots_metadata.py new file mode 100644 index 0000000000..0d0d321af7 --- /dev/null +++ b/backend/infrahub/core/diff/query/roots_metadata.py @@ -0,0 +1,48 @@ +from typing import Any, Generator + +from neo4j.graph import Node as Neo4jNode + +from infrahub.core.query import Query, QueryType +from infrahub.core.timestamp import Timestamp +from infrahub.database import InfrahubDatabase + + +class 
EnrichedDiffRootsMetadataQuery(Query): + name = "enriched_diff_roots_metadata" + type = QueryType.READ + + def __init__( + self, + diff_branch_names: list[str] | None = None, + base_branch_names: list[str] | None = None, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.diff_branch_names = diff_branch_names + self.base_branch_names = base_branch_names + self.from_time = from_time + self.to_time = to_time + + async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: + self.params = { + "diff_branch_names": self.diff_branch_names, + "base_branch_names": self.base_branch_names, + "from_time": self.from_time.to_string() if self.from_time else None, + "to_time": self.to_time.to_string() if self.to_time else None, + } + + query = """ + MATCH (diff_root:DiffRoot) + WHERE ($diff_branch_names IS NULL OR diff_root.diff_branch IN $diff_branch_names) + AND ($base_branch_names IS NULL OR diff_root.base_branch IN $base_branch_names) + AND ($from_time IS NULL OR diff_root.from_time >= $from_time) + AND ($to_time IS NULL OR diff_root.to_time <= $to_time) + """ + self.return_labels = ["diff_root"] + self.add_to_query(query=query) + + def get_root_nodes_metadata(self) -> Generator[Neo4jNode, None, None]: + for result in self.get_results(): + yield result.get_node("diff_root") From 6c8557f4d4182767ff1f14cda1b520628970f445 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Fri, 3 Jan 2025 09:32:27 -0800 Subject: [PATCH 15/18] fix mock function call name --- .../tests/unit/message_bus/operations/event/test_branch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/tests/unit/message_bus/operations/event/test_branch.py b/backend/tests/unit/message_bus/operations/event/test_branch.py index a4f3bcc868..52317b109c 100644 --- a/backend/tests/unit/message_bus/operations/event/test_branch.py +++ b/backend/tests/unit/message_bus/operations/event/test_branch.py 
@@ -71,7 +71,7 @@ async def test_merged(default_branch: Branch, init_service: InfrahubServices, pr for _ in range(2) ] diff_repo = AsyncMock(spec=DiffRepository) - diff_repo.get_empty_roots.return_value = untracked_diff_roots + tracked_diff_roots + diff_repo.get_roots_metadata.return_value = untracked_diff_roots + tracked_diff_roots mock_component_registry = Mock(spec=ComponentDependencyRegistry) mock_get_component_registry = MagicMock(return_value=mock_component_registry) mock_component_registry.get_component.return_value = diff_repo @@ -104,7 +104,7 @@ async def test_merged(default_branch: Branch, init_service: InfrahubServices, pr # Use `db=ANY` as a new InfrahubDatabase object is created as we use a new session mock_component_registry.get_component.assert_awaited_once_with(DiffRepository, db=ANY, branch=default_branch) - diff_repo.get_empty_roots.assert_awaited_once_with(base_branch_names=[target_branch_name]) + diff_repo.get_roots_metadata.assert_awaited_once_with(base_branch_names=[target_branch_name]) assert len(service.message_bus.messages) == 1 assert service.message_bus.messages[0] == messages.RefreshRegistryBranches() From ffa586b8ff8cb9c7f4b2f4244dd9a6fa061db25f Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Wed, 8 Jan 2025 14:09:47 -0800 Subject: [PATCH 16/18] add more diff-related logging, get node field specifiers in batches --- backend/infrahub/core/branch/tasks.py | 1 + backend/infrahub/core/diff/coordinator.py | 15 +++++++++ .../core/diff/query/field_specifiers.py | 20 ++++++------ .../core/diff/repository/repository.py | 31 +++++++++++++------ 4 files changed, 49 insertions(+), 18 deletions(-) diff --git a/backend/infrahub/core/branch/tasks.py b/backend/infrahub/core/branch/tasks.py index 542f332bab..ff5661e61d 100644 --- a/backend/infrahub/core/branch/tasks.py +++ b/backend/infrahub/core/branch/tasks.py @@ -185,6 +185,7 @@ async def merge_branch(branch: str) -> None: try: await merger.merge() except Exception as exc: + log.exception("Merge 
failed, beginning rollback") await merger.rollback() raise MergeFailedError(branch_name=branch) from exc await merger.update_schema() diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index a6e5062491..76119b5d0d 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -46,6 +46,14 @@ class EnrichedDiffRequest: tracking_id: TrackingId | None = field(default=None) node_field_specifiers: set[NodeFieldSpecifier] = field(default_factory=set) + def __repr__(self) -> str: + return ( + f"EnrichedDiffRequest(base_branch_name={self.base_branch.name}, diff_branch_name={self.diff_branch.name}," + f" from_time={self.from_time.to_string()}, to_time={self.to_time.to_string()}," + f" tracking_id={self.tracking_id.serialize() if self.tracking_id else None})," + f" num_node_field_specifiers={len(self.node_field_specifiers)}" + ) + class DiffCoordinator: lock_namespace = "diff-update" @@ -431,11 +439,13 @@ async def _aggregate_enriched_diffs( else: end_time = diff_request.to_time # if there are no changes on either branch in this time range, then there cannot be a diff + log.info(f"Checking number of changes on branches for {diff_request!r}") num_changes_by_branch = await self.diff_repo.get_num_changes_in_time_range_by_branch( branch_names=[diff_request.base_branch.name, diff_request.diff_branch.name], from_time=current_time, to_time=end_time, ) + log.info(f"Number of changes: {num_changes_by_branch}") might_have_changes_in_time_range = any(num_changes_by_branch.values()) if not might_have_changes_in_time_range: incremental_diffs_and_requests.append(None) @@ -490,9 +500,11 @@ async def _concatenate_diffs_and_requests( for diff_or_request in diff_or_request_list: if isinstance(diff_or_request, EnrichedDiffRequest): if previous_diff_pair: + log.info(f"Getting node field specifiers diff uuid={previous_diff_pair.diff_branch_diff.uuid}") node_field_specifiers = await 
self.diff_repo.get_node_field_specifiers( diff_id=previous_diff_pair.diff_branch_diff.uuid, ) + log.info(f"Number node field specifiers: {len(node_field_specifiers)}") diff_or_request.node_field_specifiers = node_field_specifiers is_incremental_diff = diff_or_request.from_time != full_diff_request.from_time single_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata = await self._calculate_enriched_diff( @@ -539,6 +551,7 @@ async def _update_core_data_checks(self, enriched_diff: EnrichedDiffRoot | Enric async def _calculate_enriched_diff( self, diff_request: EnrichedDiffRequest, is_incremental_diff: bool ) -> EnrichedDiffs: + log.info(f"Calculating diff for {diff_request!r}, include_unchanged={is_incremental_diff}") calculated_diff_pair = await self.diff_calculator.calculate_diff( base_branch=diff_request.base_branch, diff_branch=diff_request.diff_branch, @@ -547,5 +560,7 @@ async def _calculate_enriched_diff( include_unchanged=is_incremental_diff, previous_node_specifiers=diff_request.node_field_specifiers, ) + log.info("Calculation complete. 
Enriching diff...") enriched_diff_pair = await self.diff_enricher.enrich(calculated_diffs=calculated_diff_pair) + log.info("Enrichment complete") return enriched_diff_pair diff --git a/backend/infrahub/core/diff/query/field_specifiers.py b/backend/infrahub/core/diff/query/field_specifiers.py index 1a4c312539..2325d6d58f 100644 --- a/backend/infrahub/core/diff/query/field_specifiers.py +++ b/backend/infrahub/core/diff/query/field_specifiers.py @@ -7,23 +7,25 @@ class EnrichedDiffFieldSpecifiersQuery(Query): name = "enriched_diff_field_specifiers" type = QueryType.READ - insert_return = False def __init__(self, diff_id: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.diff_id = diff_id async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: - self.params = {"diff_id": self.diff_id} + self.params["diff_id"] = self.diff_id query = """ -MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(attr:DiffAttribute) -RETURN node.uuid AS node_uuid, attr.name AS field_name -UNION -MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(rel:DiffRelationship) -RETURN node.uuid AS node_uuid, rel.identifier AS field_name -""" - self.return_labels = ["node_uuid", "field_name"] +CALL { + MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(attr:DiffAttribute) + RETURN node.uuid AS node_uuid, attr.name AS field_name + UNION + MATCH (root:DiffRoot {uuid: $diff_id})-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(rel:DiffRelationship) + RETURN node.uuid AS node_uuid, rel.identifier AS field_name +} + """ self.add_to_query(query=query) + self.return_labels = ["node_uuid", "field_name"] + self.order_by = ["node_uuid", "field_name"] def get_node_field_specifier_tuples(self) -> Generator[tuple[str, str], None, None]: for result in self.get_results(): diff --git a/backend/infrahub/core/diff/repository/repository.py 
b/backend/infrahub/core/diff/repository/repository.py index b6c83c064d..82b751cbf0 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -7,6 +7,7 @@ from infrahub.core.timestamp import Timestamp from infrahub.database import InfrahubDatabase, retry_db_transaction from infrahub.exceptions import ResourceNotFoundError +from infrahub.log import get_logger from ..model.path import ( ConflictSelection, @@ -35,6 +36,8 @@ from ..query.update_conflict_query import EnrichedDiffConflictUpdateQuery from .deserializer import EnrichedDiffDeserializer +log = get_logger() + class DiffRepository: MAX_SAVE_BATCH_SIZE: int = 100 @@ -182,6 +185,7 @@ def _get_node_create_request_batch( @retry_db_transaction(name="enriched_diff_save") async def save(self, enriched_diffs: EnrichedDiffs) -> None: + log.info("Saving diff...") root_query = await EnrichedDiffRootsCreateQuery.init(db=self.db, enriched_diffs=enriched_diffs) await root_query.execute(db=self.db) for node_create_batch in self._get_node_create_request_batch(enriched_diffs=enriched_diffs): @@ -189,6 +193,7 @@ async def save(self, enriched_diffs: EnrichedDiffs) -> None: await node_query.execute(db=self.db) link_query = await EnrichedNodesLinkQuery.init(db=self.db, enriched_diffs=enriched_diffs) await link_query.execute(db=self.db) + log.info("Diff saved.") async def summary( self, @@ -334,14 +339,22 @@ async def get_num_changes_in_time_range_by_branch( return query.get_num_changes_by_branch() async def get_node_field_specifiers(self, diff_id: str) -> set[NodeFieldSpecifier]: - query = await EnrichedDiffFieldSpecifiersQuery.init(db=self.db, diff_id=diff_id) - await query.execute(db=self.db) + limit = 5000 + offset = 0 specifiers: set[NodeFieldSpecifier] = set() - specifiers.update( - NodeFieldSpecifier( - node_uuid=field_specifier_tuple[0], - field_name=field_specifier_tuple[1], - ) - for field_specifier_tuple in query.get_node_field_specifier_tuples() - ) + 
while True: + query = await EnrichedDiffFieldSpecifiersQuery.init(db=self.db, diff_id=diff_id, offset=offset, limit=limit) + await query.execute(db=self.db) + + new_specifiers = { + NodeFieldSpecifier( + node_uuid=field_specifier_tuple[0], + field_name=field_specifier_tuple[1], + ) + for field_specifier_tuple in query.get_node_field_specifier_tuples() + } + if not new_specifiers: + break + specifiers |= new_specifiers + offset += limit return specifiers From 92f4bc1f7cd4590f1bc8b6b674f95966e2f510ec Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Wed, 8 Jan 2025 14:29:12 -0800 Subject: [PATCH 17/18] add missing await in unit tests --- backend/tests/unit/graphql/test_diff_tree_query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/tests/unit/graphql/test_diff_tree_query.py b/backend/tests/unit/graphql/test_diff_tree_query.py index f65f7e7399..0c6a551cd6 100644 --- a/backend/tests/unit/graphql/test_diff_tree_query.py +++ b/backend/tests/unit/graphql/test_diff_tree_query.py @@ -432,7 +432,7 @@ async def test_diff_tree_one_relationship_change( enriched_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) - params = prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) + params = await prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) result = await graphql( schema=params.schema, source=DIFF_TREE_QUERY, @@ -827,7 +827,7 @@ async def test_diff_summary_filters( enriched_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) - params = prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) + params = await prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) result = await graphql( schema=params.schema, From 
e10dc77c6f5a33af3085594521531192d2438453 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Wed, 8 Jan 2025 14:33:13 -0800 Subject: [PATCH 18/18] formatting --- backend/tests/unit/graphql/test_diff_tree_query.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/tests/unit/graphql/test_diff_tree_query.py b/backend/tests/unit/graphql/test_diff_tree_query.py index 0c6a551cd6..999a2d851f 100644 --- a/backend/tests/unit/graphql/test_diff_tree_query.py +++ b/backend/tests/unit/graphql/test_diff_tree_query.py @@ -432,7 +432,9 @@ async def test_diff_tree_one_relationship_change( enriched_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) - params = await prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) + params = await prepare_graphql_params( + db=db, include_mutation=False, include_subscription=False, branch=default_branch + ) result = await graphql( schema=params.schema, source=DIFF_TREE_QUERY, @@ -827,7 +829,9 @@ async def test_diff_summary_filters( enriched_diff = await diff_coordinator.update_branch_diff_and_return( base_branch=default_branch, diff_branch=diff_branch ) - params = await prepare_graphql_params(db=db, include_mutation=False, include_subscription=False, branch=default_branch) + params = await prepare_graphql_params( + db=db, include_mutation=False, include_subscription=False, branch=default_branch + ) result = await graphql( schema=params.schema,