Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add test and changes for a relationship identifier migration #5652

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions backend/infrahub/core/diff/enricher/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,20 @@ def _update_relationship_labels(self, enriched_diff: EnrichedDiffRoot) -> None:
for node in enriched_diff.nodes:
if not node.relationships:
continue
node_schema = self.db.schema.get(name=node.kind, branch=self.diff_branch_name, duplicate=False)

node_schema = self.db.schema.get(name=node.kind, branch=enriched_diff.diff_branch_name, duplicate=False)
alternate_node_schema = None
if enriched_diff.diff_branch_name != enriched_diff.base_branch_name and self.db.schema.has(
name=node.kind, branch=enriched_diff.base_branch_name
):
alternate_node_schema = self.db.schema.get(
name=node.kind, branch=enriched_diff.base_branch_name, duplicate=False
)
for relationship_diff in node.relationships:
relationship_schema = node_schema.get_relationship(name=relationship_diff.name)
relationship_diff.label = relationship_schema.label or ""
relationship_schema = node_schema.get_relationship_or_none(name=relationship_diff.name)
if not relationship_schema and alternate_node_schema:
relationship_schema = alternate_node_schema.get_relationship_or_none(name=relationship_diff.name)
relationship_diff.label = relationship_schema.label or "" if relationship_schema else ""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these changes are required to handle the case when a relationship name is updated


async def _get_display_label_map(
self, display_label_requests: set[DisplayLabelRequest]
Expand Down
10 changes: 2 additions & 8 deletions backend/infrahub/core/diff/enricher/path_identifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from infrahub.core.constants import PathType, RelationshipCardinality
from infrahub.core.constants import PathType
from infrahub.core.path import DataPath
from infrahub.database import InfrahubDatabase

Expand Down Expand Up @@ -44,14 +44,8 @@ async def enrich(self, enriched_diff_root: EnrichedDiffRoot, calculated_diffs: C
attribute_property.path_identifier = property_path.get_path()
if not node.relationships:
continue
node_schema = self.db.schema.get(name=node.kind, branch=self.diff_branch_name, duplicate=False)
for relationship in node.relationships:
relationship_schema = node_schema.get_relationship(name=relationship.name)
path_type = (
PathType.RELATIONSHIP_ONE
if relationship_schema.cardinality is RelationshipCardinality.ONE
else PathType.RELATIONSHIP_MANY
)
path_type = PathType.from_relationship(relationship.cardinality)
relationship_path = DataPath(
branch=enriched_diff_root.diff_branch_name,
path_type=path_type,
Expand Down
54 changes: 30 additions & 24 deletions backend/infrahub/core/diff/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
if TYPE_CHECKING:
from infrahub.core.branch import Branch
from infrahub.core.query import QueryResult
from infrahub.core.schema import MainSchemaTypes
from infrahub.core.schema.manager import SchemaManager
from infrahub.core.schema.relationship_schema import RelationshipSchema

Expand Down Expand Up @@ -401,7 +400,8 @@ class DiffNodeIntermediate(TrackedStatusUpdates):
from_time: Timestamp
status: RelationshipStatus
attributes_by_name: dict[str, DiffAttributeIntermediate] = field(default_factory=dict)
relationships_by_name: dict[str, DiffRelationshipIntermediate] = field(default_factory=dict)
# {(name, identifier): DiffRelationshipIntermediate}
relationships_by_identifier: dict[tuple[str, str], DiffRelationshipIntermediate] = field(default_factory=dict)

def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNode:
attributes = []
Expand All @@ -411,7 +411,7 @@ def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNod
attributes.append(diff_attr)
action, changed_at = self.get_action_and_timestamp(from_time=from_time)
relationships = []
for rel in self.relationships_by_name.values():
for rel in self.relationships_by_identifier.values():
diff_rel = rel.to_diff_relationship(include_unchanged=include_unchanged)
if include_unchanged or diff_rel.action is not DiffAction.UNCHANGED:
relationships.append(diff_rel)
Expand All @@ -434,7 +434,7 @@ def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNod

@property
def is_empty(self) -> bool:
return len(self.attributes_by_name) == 0 and len(self.relationships_by_name) == 0
return len(self.attributes_by_name) == 0 and len(self.relationships_by_identifier) == 0


@dataclass
Expand Down Expand Up @@ -498,7 +498,7 @@ def get_diff_node_field_specifiers(self) -> dict[str, set[str]]:
for node in diff_root.nodes_by_id.values():
for attribute_name in node.attributes_by_name:
node_field_specifiers_map[node.uuid].add(attribute_name)
for relationship_diff in node.relationships_by_name.values():
for relationship_diff in node.relationships_by_identifier.values():
node_field_specifiers_map[node.uuid].add(relationship_diff.identifier)
return node_field_specifiers_map

Expand Down Expand Up @@ -594,27 +594,29 @@ def _get_diff_node(self, database_path: DatabasePath, diff_root: DiffRootInterme
diff_node.track_database_path(database_path=database_path)
return diff_node

def _get_relationship_schema(
self, database_path: DatabasePath, node_schema: MainSchemaTypes
) -> RelationshipSchema | None:
relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
if len(relationship_schemas) == 1:
return relationship_schemas[0]
possible_path_directions = database_path.possible_relationship_directions
for rel_schema in relationship_schemas:
if rel_schema.direction in possible_path_directions:
return rel_schema
def _get_relationship_schema(self, database_path: DatabasePath) -> RelationshipSchema | None:
branches_to_check = [database_path.deepest_branch]
if database_path.deepest_branch == self.diff_branch_name:
branches_to_check.append(self.base_branch_name)
for schema_branch_name in branches_to_check:
node_schema = self.schema_manager.get(
name=database_path.node_kind, branch=schema_branch_name, duplicate=False
)
relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
if len(relationship_schemas) == 1:
return relationship_schemas[0]
possible_path_directions = database_path.possible_relationship_directions
for rel_schema in relationship_schemas:
if rel_schema.direction in possible_path_directions:
return rel_schema
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have to use both the relationship name and identifier for uniqueness because...

  • an identifier is not unique if the schema has two different relationships for inbound/outbound with the same identifier, like in the case of parent and children in a hierarchy
  • a name is not unique if the relationship has been renamed

return None

def _update_attribute_level(self, database_path: DatabasePath, diff_node: DiffNodeIntermediate) -> None:
node_schema = self.schema_manager.get(
name=database_path.node_kind, branch=database_path.deepest_branch, duplicate=False
)
if "Attribute" in database_path.attribute_node.labels:
diff_attribute = self._get_diff_attribute(database_path=database_path, diff_node=diff_node)
self._update_attribute_property(database_path=database_path, diff_attribute=diff_attribute)
return
relationship_schema = self._get_relationship_schema(database_path=database_path, node_schema=node_schema)
relationship_schema = self._get_relationship_schema(database_path=database_path)
if not relationship_schema:
return
diff_relationship = self._get_diff_relationship(
Expand Down Expand Up @@ -668,7 +670,9 @@ def _get_diff_relationship(
relationship_schema: RelationshipSchema,
database_path: DatabasePath,
) -> DiffRelationshipIntermediate:
diff_relationship = diff_node.relationships_by_name.get(relationship_schema.name)
diff_relationship = diff_node.relationships_by_identifier.get(
(relationship_schema.name, relationship_schema.get_identifier())
)
if not diff_relationship:
branch_name = database_path.deepest_branch
from_time = self.from_time
Expand All @@ -682,7 +686,9 @@ def _get_diff_relationship(
identifier=relationship_schema.get_identifier(),
from_time=from_time,
)
diff_node.relationships_by_name[relationship_schema.name] = diff_relationship
diff_node.relationships_by_identifier[relationship_schema.name, relationship_schema.get_identifier()] = (
diff_relationship
)
return diff_relationship

def _apply_base_branch_previous_values(self) -> None:
Expand Down Expand Up @@ -719,8 +725,8 @@ def _apply_attribute_previous_values(
def _apply_relationship_previous_values(
self, diff_node: DiffNodeIntermediate, base_diff_node: DiffNodeIntermediate
) -> None:
for relationship_name, diff_relationship in diff_node.relationships_by_name.items():
base_diff_relationship = base_diff_node.relationships_by_name.get(relationship_name)
for relationship_key, diff_relationship in diff_node.relationships_by_identifier.items():
base_diff_relationship = base_diff_node.relationships_by_identifier.get(relationship_key)
if not base_diff_relationship:
continue
for db_id, property_set in diff_relationship.properties_by_db_id.items():
Expand Down Expand Up @@ -773,7 +779,7 @@ def _remove_empty_base_diff_root(self) -> None:
continue
if ordered_diff_values[-1].changed_at >= self.diff_branched_from_time:
return
for relationship_diff in node_diff.relationships_by_name.values():
for relationship_diff in node_diff.relationships_by_identifier.values():
for diff_relationship_property_list in relationship_diff.properties_by_db_id.values():
for diff_relationship_property in diff_relationship_property_list:
if diff_relationship_property.changed_at >= self.diff_branched_from_time:
Expand Down
Loading
Loading