Skip to content

Commit

Permalink
update repository tests for counts
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtmccarty committed Feb 14, 2025
1 parent 71b1963 commit 84792a6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 32 deletions.
9 changes: 5 additions & 4 deletions backend/infrahub/core/diff/repository/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _get_node_create_request_batch(
yield node_requests

@retry_db_transaction(name="enriched_diff_save")
async def save(self, enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata) -> None:
async def save(self, enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata, do_summary_counts: bool = True) -> None:
log.info("Updating diff metadata...")
root_query = await EnrichedDiffRootsUpsertQuery.init(db=self.db, enriched_diffs=enriched_diffs)
await root_query.execute(db=self.db)
Expand All @@ -240,9 +240,10 @@ async def save(self, enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata) -> N
link_query = await EnrichedNodesLinkQuery.init(db=self.db, enriched_diffs=enriched_diffs)
await link_query.execute(db=self.db)
log.info("Diff saved.")
await self.add_summary_counts(
diff_branch_name=enriched_diffs.diff_branch_name, diff_id=enriched_diffs.diff_branch_diff.uuid
)
if do_summary_counts:
await self.add_summary_counts(
diff_branch_name=enriched_diffs.diff_branch_name, diff_id=enriched_diffs.diff_branch_diff.uuid
)

async def summary(
self,
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/unit/core/diff/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _build_nodes(self, num_nodes: int, num_sub_fields: int) -> set[EnrichedDiffN
return all_nodes

async def _save_single_diff(
self, diff_repository: DiffRepository, enriched_diff: EnrichedDiffRoot
self, diff_repository: DiffRepository, enriched_diff: EnrichedDiffRoot, do_summary_counts: bool = True
) -> EnrichedDiffs:
base_diff = EnrichedRootFactory.build(
base_branch_name=enriched_diff.base_branch_name,
Expand All @@ -106,5 +106,5 @@ async def _save_single_diff(
diff_branch_diff=enriched_diff,
base_branch_diff=base_diff,
)
await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=do_summary_counts)
return enriched_diffs
64 changes: 45 additions & 19 deletions backend/tests/unit/core/diff/repository/test_diff_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ async def test_save_and_retrieve(self, diff_repository: DiffRepository, reset_da
tracking_id=NameTrackingId(name="the-best-diff"),
)

await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)

retrieved = await diff_repository.get(
base_branch_name=self.base_branch_name,
Expand Down Expand Up @@ -113,7 +115,7 @@ async def test_save_and_retrieve_large_diff(self, diff_repository: DiffRepositor
diff_branch_diff=enriched_branch_diff,
)

await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=False)

retrieved = await diff_repository.get_pairs(
base_branch_name=self.base_branch_name,
Expand All @@ -136,7 +138,9 @@ async def test_base_branch_name_filter(self, diff_repository: DiffRepository, re
uuid=root_uuid,
nodes={EnrichedNodeFactory.build(relationships={})},
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)

retrieved = await diff_repository.get(
base_branch_name=self.base_branch_name,
Expand Down Expand Up @@ -166,7 +170,9 @@ async def test_diff_branch_name_filter(self, diff_repository: DiffRepository, re
uuid=root_uuid,
nodes={EnrichedNodeFactory.build(relationships={})},
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)

start_time = DateTime.create(2024, 6, 15, 18, 35, 20, tz=UTC)
end_time = start_time.add(months=1)
Expand Down Expand Up @@ -200,7 +206,9 @@ async def test_filter_time_ranges(self, diff_repository: DiffRepository, reset_d
uuid=root_uuid,
nodes={EnrichedNodeFactory.build(relationships={})},
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)

# both before
retrieved = await diff_repository.get(
Expand Down Expand Up @@ -264,7 +272,9 @@ async def test_filter_root_node_uuids(self, diff_repository: DiffRepository, res
nodes=nodes,
)
enriched_diffs.append(enriched_diff)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)

parent_node = EnrichedNodeFactory.build()
middle_parent_rel = EnrichedRelationshipGroupFactory.build(nodes={parent_node})
Expand All @@ -281,7 +291,7 @@ async def test_filter_root_node_uuids(self, diff_repository: DiffRepository, res
to_time=Timestamp(self.diff_to_time),
nodes=other_nodes | {parent_node, middle_node, leaf_node},
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=this_diff)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=this_diff, do_summary_counts=False)
diff_branch_names = [e.diff_branch_name for e in enriched_diffs] + ["diff"]

# get parent node
Expand Down Expand Up @@ -400,7 +410,9 @@ async def test_save_and_retrieve_many_diffs(self, diff_repository: DiffRepositor
to_time=Timestamp(start_time.add(minutes=(i * 30) + 29)),
nodes=nodes,
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)
diffs_to_retrieve.append(enriched_diff)
for i in range(5):
nodes = self._build_nodes(num_nodes=3, num_sub_fields=2)
Expand All @@ -411,7 +423,9 @@ async def test_save_and_retrieve_many_diffs(self, diff_repository: DiffRepositor
to_time=Timestamp(start_time.add(days=3, minutes=(i * 30) + 29)),
nodes=nodes,
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)

retrieved = await diff_repository.get(
base_branch_name=self.base_branch_name,
Expand All @@ -434,7 +448,9 @@ async def test_delete_diff_by_uuid(self, diff_repository: DiffRepository, reset_
to_time=Timestamp(start_time.add(minutes=(i * 30) + 29)),
nodes=nodes,
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)
diffs.append(enriched_diff)

diff_to_delete = diffs.pop()
Expand Down Expand Up @@ -464,7 +480,9 @@ async def test_get_by_tracking_id(self, diff_repository: DiffRepository, reset_d
to_time=Timestamp(end_time.add(minutes=(i * 30) + 29)),
nodes=nodes,
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=enriched_diff, do_summary_counts=False
)
nodes = self._build_nodes(num_nodes=2, num_sub_fields=2)
branch_tracked_diff = EnrichedRootFactory.build(
base_branch_name=self.base_branch_name,
Expand All @@ -474,7 +492,9 @@ async def test_get_by_tracking_id(self, diff_repository: DiffRepository, reset_d
nodes=nodes,
tracking_id=branch_tracking_id,
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=branch_tracked_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=branch_tracked_diff, do_summary_counts=False
)
name_tracked_diff = EnrichedRootFactory.build(
base_branch_name=self.base_branch_name,
diff_branch_name=self.diff_branch_name,
Expand All @@ -483,7 +503,9 @@ async def test_get_by_tracking_id(self, diff_repository: DiffRepository, reset_d
nodes=nodes,
tracking_id=name_tracking_id,
)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=name_tracked_diff)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=name_tracked_diff, do_summary_counts=False
)

retrieved_branch_diff = await diff_repository.get_one(
tracking_id=branch_tracking_id,
Expand Down Expand Up @@ -518,7 +540,7 @@ async def test_get_node_field_summaries(self, diff_repository: DiffRepository):
diff_nodes.add(same_kind_diff_node)
diff_root = EnrichedRootFactory.build(nodes=diff_nodes)
diff_root.tracking_id = BranchTrackingId(name=diff_root.diff_branch_name)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=diff_root)
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=diff_root, do_summary_counts=False)

expected_map: dict[str, NodeDiffFieldSummary] = {}
for node in diff_root.nodes:
Expand Down Expand Up @@ -550,11 +572,15 @@ async def test_merge_tracking_ids(self, diff_repository: DiffRepository, reset_d
tracking_id_diff_1 = EnrichedRootFactory.build(base_branch_name=base_branch_name)
tracking_id_1 = BranchTrackingId(name=tracking_id_diff_1.diff_branch_name)
tracking_id_diff_1.tracking_id = tracking_id_1
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=tracking_id_diff_1)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=tracking_id_diff_1, do_summary_counts=False
)
tracking_id_diff_2 = EnrichedRootFactory.build(base_branch_name=base_branch_name)
tracking_id_2 = BranchTrackingId(name=tracking_id_diff_2.diff_branch_name)
tracking_id_diff_2.tracking_id = tracking_id_2
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=tracking_id_diff_2)
await self._save_single_diff(
diff_repository=diff_repository, enriched_diff=tracking_id_diff_2, do_summary_counts=False
)

await diff_repository.mark_tracking_ids_merged(tracking_ids=[tracking_id_1])

Expand Down Expand Up @@ -610,7 +636,7 @@ async def test_limit_and_offset(self, diff_repository: DiffRepository, reset_dat
diff_branch_diff=enriched_branch_diff,
)

await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=False)

# validate limit
retrieved = await diff_repository.get(
Expand Down Expand Up @@ -712,7 +738,7 @@ async def test_update_existing(self, db: InfrahubDatabase, diff_repository: Diff
diff_branch_diff=enriched_diff,
base_branch_diff=base_diff,
)
await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=False)

# removed node conflict
node_with_removes.conflict = None
Expand Down Expand Up @@ -797,7 +823,7 @@ async def test_update_existing(self, db: InfrahubDatabase, diff_repository: Diff
# update relationship element property conflict
updated_element_property.conflict.diff_branch_value = "DIFF SOMETHING"

await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=False)

retrieved = await diff_repository.get(
base_branch_name=self.base_branch_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ async def __save_and_update_diff(
self, diff_repository: DiffRepository, enriched_diff: EnrichedDiffRoot
) -> EnrichedDiffRoot:
await self._save_single_diff(diff_repository=diff_repository, enriched_diff=enriched_diff)
await diff_repository.add_summary_counts(
diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid
)
return await diff_repository.get_one(
diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid
)
Expand Down Expand Up @@ -182,7 +179,7 @@ async def test_existing_node_with_changes(self, diff_repository: DiffRepository)
self._set_conflicts(diff_node=updated_node, conflict_chance=0.5)
enriched_diffs.diff_branch_diff.nodes = {updated_node}
# set the counts again
await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=False)
await diff_repository.add_summary_counts(
diff_branch_name=diff_root.diff_branch_name, diff_id=diff_root.uuid, node_uuids=[node_to_update.uuid]
)
Expand Down Expand Up @@ -213,8 +210,6 @@ async def test_existing_node_with_changes_and_parents(self, diff_repository: Dif
base_branch_name=self.base_branch_name, diff_branch_name=self.diff_branch_name, nodes=diff_nodes
)
enriched_diffs = await self._save_single_diff(diff_repository=diff_repository, enriched_diff=diff_root)
# set the counts for this diff
await diff_repository.add_summary_counts(diff_branch_name=diff_root.diff_branch_name, diff_id=diff_root.uuid)
# make some changes to nodes with parents
nodes_to_update = set()
action_choices = list(DiffAction)
Expand All @@ -237,7 +232,7 @@ async def test_existing_node_with_changes_and_parents(self, diff_repository: Dif

enriched_diffs.diff_branch_diff.nodes = nodes_to_update
# set the counts again
await diff_repository.save(enriched_diffs=enriched_diffs)
await diff_repository.save(enriched_diffs=enriched_diffs, do_summary_counts=False)
await diff_repository.add_summary_counts(
diff_branch_name=diff_root.diff_branch_name,
diff_id=diff_root.uuid,
Expand Down

0 comments on commit 84792a6

Please sign in to comment.