Skip to content

Commit

Permalink
Send node mutation events with rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
ogenstad committed Mar 4, 2025
1 parent 9bcb57c commit 382a43b
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 25 deletions.
71 changes: 58 additions & 13 deletions backend/infrahub/core/branch/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any
from uuid import uuid4

import pydantic
from prefect import flow, get_run_logger
Expand All @@ -11,16 +12,17 @@
from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
from infrahub.core import registry
from infrahub.core.branch import Branch
from infrahub.core.changelog.diff import DiffChangelogCollector
from infrahub.core.changelog.diff import DiffChangelogCollector, MigrationTracker
from infrahub.core.constants import MutationAction
from infrahub.core.diff.coordinator import DiffCoordinator
from infrahub.core.diff.ipam_diff_parser import IpamDiffParser
from infrahub.core.diff.merger.merger import DiffMerger
from infrahub.core.diff.model.path import BranchTrackingId
from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot, EnrichedDiffRootMetadata
from infrahub.core.diff.repository.repository import DiffRepository
from infrahub.core.merge import BranchMerger
from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
from infrahub.core.migrations.schema.tasks import schema_apply_migrations
from infrahub.core.timestamp import Timestamp
from infrahub.core.validators.determiner import ConstraintValidatorDeterminer
from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData
from infrahub.core.validators.tasks import schema_validate_migrations
Expand Down Expand Up @@ -54,6 +56,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj)
diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj)
initial_from_time = Timestamp(obj.get_branched_from())
merger = BranchMerger(
db=db,
diff_coordinator=diff_coordinator,
Expand All @@ -62,7 +65,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
source_branch=obj,
service=service,
)
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)

enriched_diff_metadata = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj)
async for _ in diff_repository.get_all_conflicts_for_diff(
diff_branch_name=enriched_diff_metadata.diff_branch_name, diff_id=enriched_diff_metadata.uuid
Expand Down Expand Up @@ -97,7 +100,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
raise ValidationError(",\n".join(error_messages))

schema_in_main_before = merger.destination_schema.duplicate()

migrations = []
async with lock.registry.global_graph_lock():
async with db.start_transaction() as dbt:
await obj.rebase(db=dbt)
Expand Down Expand Up @@ -134,6 +137,14 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
for error in errors:
log.error(error)

default_branch_diff = await _get_diff_root(
diff_coordinator=diff_coordinator,
enriched_diff_metadata=enriched_diff_metadata,
diff_repository=diff_repository,
base_branch=base_branch,
target_from=initial_from_time,
)

# -------------------------------------------------------------
# Trigger the reconciliation of IPAM data after the rebase
# -------------------------------------------------------------
Expand All @@ -156,14 +167,26 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
# -------------------------------------------------------------
# Generate an event to indicate that a branch has been rebased
# -------------------------------------------------------------
# TODO Add account information
await service.event.send(
event=BranchRebasedEvent(
branch_name=obj.name,
branch_id=str(obj.uuid),
meta=EventMeta.from_context(context=context, branch=registry.get_global_branch()),
)
rebase_event = BranchRebasedEvent(
branch_name=obj.name, branch_id=str(obj.uuid), meta=EventMeta(branch=obj, context=context)
)
events: list[InfrahubEvent] = [rebase_event]
changelog_collector = DiffChangelogCollector(
diff=default_branch_diff, branch=obj, db=db, migration_tracker=MigrationTracker(migrations=migrations)
)
for action, node_changelog in changelog_collector.collect_changelogs():
mutate_event = NodeMutatedEvent(
kind=node_changelog.node_kind,
node_id=node_changelog.node_id,
data=node_changelog,
action=MutationAction.from_diff_action(diff_action=action),
fields=node_changelog.updated_fields,
meta=EventMeta.from_parent(parent=rebase_event, branch=obj),
)
events.append(mutate_event)

for event in events:
await service.event.send(event)


@flow(name="branch-merge", flow_run_name="Merge branch {branch} into main")
Expand Down Expand Up @@ -258,7 +281,7 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
events: list[InfrahubEvent] = [merge_event]

for action, node_changelog in node_events:
meta = EventMeta.from_parent(parent=merge_event)
meta = EventMeta.from_parent(parent=merge_event, branch=default_branch)
mutate_event = NodeMutatedEvent(
kind=node_changelog.node_kind,
node_id=node_changelog.node_id,
Expand All @@ -267,7 +290,6 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
fields=node_changelog.updated_fields,
meta=meta,
)
mutate_event.set_context_branch(branch=default_branch)
events.append(mutate_event)

for event in events:
Expand Down Expand Up @@ -364,3 +386,26 @@ async def create_branch(model: BranchCreateModel, context: InfrahubContext, serv
context=context,
parameters={"branch": obj.name, "branch_id": str(obj.uuid)},
)


async def _get_diff_root(
diff_coordinator: DiffCoordinator,
enriched_diff_metadata: EnrichedDiffRootMetadata,
diff_repository: DiffRepository,
base_branch: Branch,
target_from: Timestamp,
) -> EnrichedDiffRoot:
default_branch_diff = await diff_coordinator.create_or_update_arbitrary_timeframe_diff(
base_branch=base_branch,
diff_branch=base_branch,
from_time=target_from,
to_time=enriched_diff_metadata.to_time,
name=str(uuid4()),
)
# make sure we have the actual diff with data and not just the metadata
if not isinstance(default_branch_diff, EnrichedDiffRoot):
default_branch_diff = await diff_repository.get_one(
diff_branch_name=base_branch.name, diff_id=default_branch_diff.uuid
)

return default_branch_diff
46 changes: 42 additions & 4 deletions backend/infrahub/core/changelog/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
EnrichedDiffRelationship,
EnrichedDiffRoot,
)
from infrahub.core.models import SchemaUpdateMigrationInfo
from infrahub.core.schema import MainSchemaTypes
from infrahub.database import InfrahubDatabase

Expand All @@ -38,11 +39,18 @@ class NodeInDiff:


class DiffChangelogCollector:
def __init__(self, diff: EnrichedDiffRoot, branch: Branch, db: InfrahubDatabase) -> None:
def __init__(
self,
diff: EnrichedDiffRoot,
branch: Branch,
db: InfrahubDatabase,
migration_tracker: MigrationTracker | None = None,
) -> None:
self._diff = diff
self._branch = branch
self._db = db
self._diff_nodes: dict[str, NodeInDiff]
self.migration = migration_tracker or MigrationTracker()

def _populate_diff_nodes(self) -> None:
self._diff_nodes = {
Expand Down Expand Up @@ -83,14 +91,16 @@ def _process_node_attribute(
# then we don't have access to the attribute kind
attribute_kind = "n/a"

changelog_attribute = AttributeChangelog(name=attribute.name, kind=attribute_kind)
changelog_attribute = AttributeChangelog(
name=self.migration.get_attribute_name(node=node, attribute=attribute), kind=attribute_kind
)
for attr_property in attribute.properties:
match attr_property.property_type:
case DatabaseEdgeType.HAS_VALUE:
# TODO deserialize correct value type from string
if _keep_branch_update(diff_property=attr_property):
changelog_attribute.value = attr_property.new_value
changelog_attribute.value_previous = attr_property.previous_value
changelog_attribute.set_value(value=attr_property.new_value)
changelog_attribute.set_value_previous(value=attr_property.previous_value)
case DatabaseEdgeType.IS_PROTECTED:
if _keep_branch_update(diff_property=attr_property):
changelog_attribute.add_property(
Expand Down Expand Up @@ -243,3 +253,31 @@ def _keep_branch_update(diff_property: EnrichedDiffProperty) -> bool:
if diff_property.conflict and diff_property.conflict.selected_branch == ConflictSelection.BASE_BRANCH:
return False
return True


class MigrationTracker:
"""Keeps track of schema updates that happened as part of a migration"""

def __init__(self, migrations: list[SchemaUpdateMigrationInfo] | None = None) -> None:
# A dictionary of Node kind, previous attribute name and new attribute
# {"TestPerson": {"old_attribute_name": "new_attribute_name"}}
self._migrations_attribute_map: dict[str, dict[str, str]] = {}

migrations = migrations or []
for migration in migrations:
if migration.migration_name == "attribute.name.update":
if migration.path.schema_kind not in self._migrations_attribute_map:
self._migrations_attribute_map[migration.path.schema_kind] = {}
if migration.path.property_name and migration.path.field_name:
self._migrations_attribute_map[migration.path.schema_kind][migration.path.property_name] = (
migration.path.field_name
)

def get_attribute_name(self, node: NodeChangelog, attribute: EnrichedDiffAttribute) -> str:
"""Return the current name of the requested attribute"""
if node.node_kind not in self._migrations_attribute_map:
return attribute.name
if attribute.name not in self._migrations_attribute_map[node.node_kind]:
return attribute.name

return self._migrations_attribute_map[node.node_kind][attribute.name]
12 changes: 12 additions & 0 deletions backend/infrahub/core/changelog/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def has_updates(self) -> bool:
return True
return False

def set_value(self, value: Any) -> None:
if isinstance(value, str) and value == NULL_VALUE:
self.value = None
return
self.value = value

def set_value_previous(self, value: Any) -> None:
if isinstance(value, str) and value == NULL_VALUE:
self.value_previous = None
return
self.value_previous = value

@field_validator("value", "value_previous")
@classmethod
def convert_null_values(cls, value: Any) -> Any:
Expand Down
9 changes: 7 additions & 2 deletions backend/infrahub/events/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,25 @@ def with_dummy_context(cls, branch: Branch) -> EventMeta:
)

@classmethod
def from_parent(cls, parent: InfrahubEvent) -> EventMeta:
def from_parent(cls, parent: InfrahubEvent, branch: Branch | None = None) -> EventMeta:
"""Create the metadata from an existing event
Note that this action will modify the existing event to indicate that children might be attached to the event
"""
parent.meta.has_children = True
context = deepcopy(parent.meta.context)
if branch:
context.branch.name = branch.name
context.branch.id = str(branch.get_uuid())

return cls(
parent=parent.meta.id,
branch=parent.meta.branch,
request_id=parent.meta.request_id,
initiator_id=parent.meta.initiator_id,
account_id=parent.meta.account_id,
level=parent.meta.level + 1,
context=deepcopy(parent.meta.context),
context=context,
ancestors=[ParentEvent(id=parent.get_id(), name=parent.get_name())] + parent.meta.ancestors,
)

Expand Down
5 changes: 0 additions & 5 deletions backend/infrahub/events/node_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pydantic import Field, computed_field

from infrahub.core.branch import Branch
from infrahub.core.changelog.models import NodeChangelog
from infrahub.core.constants import MutationAction
from infrahub.message_bus import InfrahubMessage
Expand Down Expand Up @@ -95,10 +94,6 @@ def get_messages(self) -> list[InfrahubMessage]:
# )
]

def set_context_branch(self, branch: Branch) -> None:
self.meta.context.branch.id = str(branch.get_uuid())
self.meta.context.branch.name = branch.name


class NodeCreatedEvent(NodeMutatedEvent):
action: MutationAction = MutationAction.CREATED
Expand Down
3 changes: 3 additions & 0 deletions backend/infrahub/graphql/types/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class EventTypeFilter(InputObjectType):
branch_merged = Field(
BranchEventTypeFilter, required=False, description="Filters specific to infrahub.branch.merged events"
)
branch_rebased = Field(
BranchEventTypeFilter, required=False, description="Filters specific to infrahub.branch.rebased events"
)


# ---------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions backend/infrahub/task_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ def add_event_type_filter(
if branches:
self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.branch.name": branches}))

if branch_rebased := event_type_filter.get("branch_rebased"):
branches = branch_rebased.get("branches") or []
if "infrahub.branch.created" not in event_type:
event_type.append("infrahub.branch.rebased")
if branches:
self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.branch.name": branches}))

if event_type:
self.event = EventNameFilter(name=event_type)

Expand Down
Loading

0 comments on commit 382a43b

Please sign in to comment.