Skip to content

Commit

Permalink
add test and changes for a relationship identifier migration (#5652)
Browse files Browse the repository at this point in the history
* add test and changes for a relationship identifier migration

* mypy fix

* fix handling to be unique on name AND identifier

* look before you leap

* use PathType.from_relationship

* reset python_sdk commit
  • Loading branch information
ajtmccarty authored Feb 4, 2025
1 parent 08f4269 commit fb0bdc2
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 37 deletions.
16 changes: 13 additions & 3 deletions backend/infrahub/core/diff/enricher/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,20 @@ def _update_relationship_labels(self, enriched_diff: EnrichedDiffRoot) -> None:
for node in enriched_diff.nodes:
if not node.relationships:
continue
node_schema = self.db.schema.get(name=node.kind, branch=self.diff_branch_name, duplicate=False)

node_schema = self.db.schema.get(name=node.kind, branch=enriched_diff.diff_branch_name, duplicate=False)
alternate_node_schema = None
if enriched_diff.diff_branch_name != enriched_diff.base_branch_name and self.db.schema.has(
name=node.kind, branch=enriched_diff.base_branch_name
):
alternate_node_schema = self.db.schema.get(
name=node.kind, branch=enriched_diff.base_branch_name, duplicate=False
)
for relationship_diff in node.relationships:
relationship_schema = node_schema.get_relationship(name=relationship_diff.name)
relationship_diff.label = relationship_schema.label or ""
relationship_schema = node_schema.get_relationship_or_none(name=relationship_diff.name)
if not relationship_schema and alternate_node_schema:
relationship_schema = alternate_node_schema.get_relationship_or_none(name=relationship_diff.name)
relationship_diff.label = relationship_schema.label or "" if relationship_schema else ""

async def _get_display_label_map(
self, display_label_requests: set[DisplayLabelRequest]
Expand Down
10 changes: 2 additions & 8 deletions backend/infrahub/core/diff/enricher/path_identifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from infrahub.core.constants import PathType, RelationshipCardinality
from infrahub.core.constants import PathType
from infrahub.core.path import DataPath
from infrahub.database import InfrahubDatabase

Expand Down Expand Up @@ -44,14 +44,8 @@ async def enrich(self, enriched_diff_root: EnrichedDiffRoot, calculated_diffs: C
attribute_property.path_identifier = property_path.get_path()
if not node.relationships:
continue
node_schema = self.db.schema.get(name=node.kind, branch=self.diff_branch_name, duplicate=False)
for relationship in node.relationships:
relationship_schema = node_schema.get_relationship(name=relationship.name)
path_type = (
PathType.RELATIONSHIP_ONE
if relationship_schema.cardinality is RelationshipCardinality.ONE
else PathType.RELATIONSHIP_MANY
)
path_type = PathType.from_relationship(relationship.cardinality)
relationship_path = DataPath(
branch=enriched_diff_root.diff_branch_name,
path_type=path_type,
Expand Down
54 changes: 30 additions & 24 deletions backend/infrahub/core/diff/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
if TYPE_CHECKING:
from infrahub.core.branch import Branch
from infrahub.core.query import QueryResult
from infrahub.core.schema import MainSchemaTypes
from infrahub.core.schema.manager import SchemaManager
from infrahub.core.schema.relationship_schema import RelationshipSchema

Expand Down Expand Up @@ -401,7 +400,8 @@ class DiffNodeIntermediate(TrackedStatusUpdates):
from_time: Timestamp
status: RelationshipStatus
attributes_by_name: dict[str, DiffAttributeIntermediate] = field(default_factory=dict)
relationships_by_name: dict[str, DiffRelationshipIntermediate] = field(default_factory=dict)
# {(name, identifier): DiffRelationshipIntermediate}
relationships_by_identifier: dict[tuple[str, str], DiffRelationshipIntermediate] = field(default_factory=dict)

def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNode:
attributes = []
Expand All @@ -411,7 +411,7 @@ def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNod
attributes.append(diff_attr)
action, changed_at = self.get_action_and_timestamp(from_time=from_time)
relationships = []
for rel in self.relationships_by_name.values():
for rel in self.relationships_by_identifier.values():
diff_rel = rel.to_diff_relationship(include_unchanged=include_unchanged)
if include_unchanged or diff_rel.action is not DiffAction.UNCHANGED:
relationships.append(diff_rel)
Expand All @@ -434,7 +434,7 @@ def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNod

@property
def is_empty(self) -> bool:
return len(self.attributes_by_name) == 0 and len(self.relationships_by_name) == 0
return len(self.attributes_by_name) == 0 and len(self.relationships_by_identifier) == 0


@dataclass
Expand Down Expand Up @@ -498,7 +498,7 @@ def get_diff_node_field_specifiers(self) -> dict[str, set[str]]:
for node in diff_root.nodes_by_id.values():
for attribute_name in node.attributes_by_name:
node_field_specifiers_map[node.uuid].add(attribute_name)
for relationship_diff in node.relationships_by_name.values():
for relationship_diff in node.relationships_by_identifier.values():
node_field_specifiers_map[node.uuid].add(relationship_diff.identifier)
return node_field_specifiers_map

Expand Down Expand Up @@ -594,27 +594,29 @@ def _get_diff_node(self, database_path: DatabasePath, diff_root: DiffRootInterme
diff_node.track_database_path(database_path=database_path)
return diff_node

def _get_relationship_schema(
self, database_path: DatabasePath, node_schema: MainSchemaTypes
) -> RelationshipSchema | None:
relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
if len(relationship_schemas) == 1:
return relationship_schemas[0]
possible_path_directions = database_path.possible_relationship_directions
for rel_schema in relationship_schemas:
if rel_schema.direction in possible_path_directions:
return rel_schema
def _get_relationship_schema(self, database_path: DatabasePath) -> RelationshipSchema | None:
branches_to_check = [database_path.deepest_branch]
if database_path.deepest_branch == self.diff_branch_name:
branches_to_check.append(self.base_branch_name)
for schema_branch_name in branches_to_check:
node_schema = self.schema_manager.get(
name=database_path.node_kind, branch=schema_branch_name, duplicate=False
)
relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
if len(relationship_schemas) == 1:
return relationship_schemas[0]
possible_path_directions = database_path.possible_relationship_directions
for rel_schema in relationship_schemas:
if rel_schema.direction in possible_path_directions:
return rel_schema
return None

def _update_attribute_level(self, database_path: DatabasePath, diff_node: DiffNodeIntermediate) -> None:
node_schema = self.schema_manager.get(
name=database_path.node_kind, branch=database_path.deepest_branch, duplicate=False
)
if "Attribute" in database_path.attribute_node.labels:
diff_attribute = self._get_diff_attribute(database_path=database_path, diff_node=diff_node)
self._update_attribute_property(database_path=database_path, diff_attribute=diff_attribute)
return
relationship_schema = self._get_relationship_schema(database_path=database_path, node_schema=node_schema)
relationship_schema = self._get_relationship_schema(database_path=database_path)
if not relationship_schema:
return
diff_relationship = self._get_diff_relationship(
Expand Down Expand Up @@ -668,7 +670,9 @@ def _get_diff_relationship(
relationship_schema: RelationshipSchema,
database_path: DatabasePath,
) -> DiffRelationshipIntermediate:
diff_relationship = diff_node.relationships_by_name.get(relationship_schema.name)
diff_relationship = diff_node.relationships_by_identifier.get(
(relationship_schema.name, relationship_schema.get_identifier())
)
if not diff_relationship:
branch_name = database_path.deepest_branch
from_time = self.from_time
Expand All @@ -682,7 +686,9 @@ def _get_diff_relationship(
identifier=relationship_schema.get_identifier(),
from_time=from_time,
)
diff_node.relationships_by_name[relationship_schema.name] = diff_relationship
diff_node.relationships_by_identifier[relationship_schema.name, relationship_schema.get_identifier()] = (
diff_relationship
)
return diff_relationship

def _apply_base_branch_previous_values(self) -> None:
Expand Down Expand Up @@ -719,8 +725,8 @@ def _apply_attribute_previous_values(
def _apply_relationship_previous_values(
self, diff_node: DiffNodeIntermediate, base_diff_node: DiffNodeIntermediate
) -> None:
for relationship_name, diff_relationship in diff_node.relationships_by_name.items():
base_diff_relationship = base_diff_node.relationships_by_name.get(relationship_name)
for relationship_key, diff_relationship in diff_node.relationships_by_identifier.items():
base_diff_relationship = base_diff_node.relationships_by_identifier.get(relationship_key)
if not base_diff_relationship:
continue
for db_id, property_set in diff_relationship.properties_by_db_id.items():
Expand Down Expand Up @@ -773,7 +779,7 @@ def _remove_empty_base_diff_root(self) -> None:
continue
if ordered_diff_values[-1].changed_at >= self.diff_branched_from_time:
return
for relationship_diff in node_diff.relationships_by_name.values():
for relationship_diff in node_diff.relationships_by_identifier.values():
for diff_relationship_property_list in relationship_diff.properties_by_db_id.values():
for diff_relationship_property in diff_relationship_property_list:
if diff_relationship_property.changed_at >= self.diff_branched_from_time:
Expand Down
Loading

0 comments on commit fb0bdc2

Please sign in to comment.