diff --git a/backend/infrahub/core/query/diff.py b/backend/infrahub/core/query/diff.py index 8a26feb13c..5b71330e44 100644 --- a/backend/infrahub/core/query/diff.py +++ b/backend/infrahub/core/query/diff.py @@ -528,14 +528,12 @@ def __init__( self, base_branch: Branch, diff_branch_from_time: Timestamp, - branch_support: list[BranchSupportType] | None = None, current_node_field_specifiers: list[tuple[str, str]] | None = None, new_node_field_specifiers: list[tuple[str, str]] | None = None, **kwargs: Any, ): self.base_branch = base_branch self.diff_branch_from_time = diff_branch_from_time - self.branch_support = branch_support or [BranchSupportType.AWARE] self.current_node_field_specifiers = current_node_field_specifiers self.new_node_field_specifiers = new_node_field_specifiers @@ -551,7 +549,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: "branch_from_time": self.diff_branch_from_time.to_string(), "from_time": from_str, "to_time": self.diff_to.to_string(), - "branch_support": [item.value for item in self.branch_support], + "branch_local": BranchSupportType.LOCAL.value, + "branch_aware": BranchSupportType.AWARE.value, + "branch_agnostic": BranchSupportType.AGNOSTIC.value, "new_node_field_specifiers": self.new_node_field_specifiers, "current_node_field_specifiers": self.current_node_field_specifiers, } @@ -577,7 +577,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: WHERE (node_ids_list IS NULL OR p.uuid IN node_ids_list) AND (from_time <= diff_rel.from < $to_time) AND (diff_rel.to IS NULL OR (from_time <= diff_rel.to < $to_time)) - AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support) + AND p.branch_support = $branch_aware WITH p, q, diff_rel, from_time // ------------------------------------- // Exclude nodes added then removed on branch within timeframe @@ -603,6 +603,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: WHERE %(id_func)s(diff_rel) = %(id_func)s(top_diff_rel) AND type(r_node) IN ["HAS_ATTRIBUTE", "IS_RELATED"] AND any(l in labels(node) WHERE l in ["Attribute", "Relationship"]) + AND node.branch_support IN [$branch_aware, $branch_agnostic] AND type(r_prop) IN ["IS_VISIBLE", "IS_PROTECTED", "HAS_SOURCE", "HAS_OWNER", "HAS_VALUE", "IS_RELATED"] AND any(l in labels(prop) WHERE l in ["Boolean", "Node", "AttributeValue"]) AND ALL( @@ -655,9 +656,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: // exclude attributes and relationships under added/removed nodes b/c they are covered above WHERE (node_field_specifiers_list IS NULL OR [p.uuid, q.name] IN node_field_specifiers_list) AND r_root.branch IN [$branch_name, $base_branch_name, $global_branch_name] - AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support) + AND q.branch_support = $branch_aware // if p has a different type of branch support and was addded within our timeframe - AND (r_root.from < from_time OR NOT (p.branch_support IN $branch_support)) + AND (r_root.from < from_time OR p.branch_support = $branch_agnostic) AND r_root.status = "active" // get attributes and relationships added on the branch during the timeframe AND (from_time <= diff_rel.from < $to_time) @@ -671,9 +672,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: // exclude attributes and relationships under added/removed nodes b/c they are covered above WHERE (node_field_specifiers_list IS NULL OR [p.uuid, q.name] IN node_field_specifiers_list) AND r_root.branch IN [$branch_name, $base_branch_name, $global_branch_name] - AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support) + AND q.branch_support = $branch_aware // if p has a different type of branch support and was addded within our timeframe - AND (r_root.from < from_time OR NOT (p.branch_support IN $branch_support)) + AND (r_root.from < from_time OR p.branch_support = $branch_agnostic) // get attributes and relationships added on the branch during the timeframe AND (from_time <= diff_rel.from < $to_time) AND (diff_rel.to IS NULL OR (from_time <= diff_rel.to < $to_time)) @@ -773,7 +774,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: r in [r_root, r_node] WHERE r.from <= from_time AND r.branch IN [$branch_name, $base_branch_name] ) - AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support) + AND p.branch_support = $branch_aware AND any(l in labels(p) WHERE l in ["Attribute", "Relationship"]) AND type(diff_rel) IN ["IS_VISIBLE", "IS_PROTECTED", "HAS_SOURCE", "HAS_OWNER", "HAS_VALUE"] AND any(l in labels(q) WHERE l in ["Boolean", "Node", "AttributeValue"]) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 3b86ed5b6e..b05d138e15 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -22,7 +22,7 @@ from infrahub.config import load_and_exit from infrahub.core import registry from infrahub.core.branch import Branch -from infrahub.core.constants import BranchSupportType, InfrahubKind +from infrahub.core.constants import BranchSupportType, InfrahubKind, RelationshipCardinality, RelationshipDirection from infrahub.core.initialization import ( create_default_branch, create_global_branch, @@ -31,8 +31,11 @@ ) from infrahub.core.node import Node from infrahub.core.schema import SchemaRoot, core_models, internal_schema +from infrahub.core.schema.attribute_schema import AttributeSchema from infrahub.core.schema.definitions.core import core_profile_schema_definition from infrahub.core.schema.manager import SchemaManager +from infrahub.core.schema.node_schema import NodeSchema +from infrahub.core.schema.relationship_schema import RelationshipSchema from infrahub.core.schema.schema_branch import SchemaBranch from infrahub.core.utils import delete_all_nodes from infrahub.database import InfrahubDatabase, get_db @@ -513,6 +516,63 @@ async def car_person_schema( return registry.schema.register_schema(schema=car_person_schema_unregistered, branch=default_branch.name) +@pytest.fixture +async def car_person_schema_branch_local_root(db: InfrahubDatabase, default_branch: Branch) -> SchemaRoot: + schema = SchemaRoot( + nodes=[ + NodeSchema( + name="Car", + namespace="Test", + default_filter="name__value", + display_labels=["name__value", "color__value"], + uniqueness_constraints=[["name__value"]], + branch=BranchSupportType.LOCAL, + attributes=[ + AttributeSchema(name="name", kind="Text", unique=True), + AttributeSchema(name="color", kind="Text", default_value="#444444", optional=True), + ], + relationships=[ + RelationshipSchema( + name="owner", + peer="TestPerson", + optional=False, + cardinality=RelationshipCardinality.ONE, + direction=RelationshipDirection.OUTBOUND, + ), + ], + ), + NodeSchema( + name="Person", + namespace="Test", + default_filter="name__value", + display_labels=["name__value"], + branch=BranchSupportType.AWARE, + uniqueness_constraints=[["name__value"]], + attributes=[ + AttributeSchema(name="name", kind="Text", unique=True), + AttributeSchema(name="height", kind="Number", optional=True), + ], + relationships=[ + RelationshipSchema( + name="cars", + peer="TestCar", + cardinality=RelationshipCardinality.MANY, + direction=RelationshipDirection.INBOUND, + ) + ], + ), + ], + ) + return schema + + +@pytest.fixture +async def car_person_schema_branch_local( + db: InfrahubDatabase, default_branch: Branch, car_person_schema_branch_local_root +) -> SchemaBranch: + return registry.schema.register_schema(schema=car_person_schema_branch_local_root, branch=default_branch.name) + + @pytest.fixture async def animal_person_schema_unregistered(db: InfrahubDatabase, node_group_schema, data_schema) -> SchemaRoot: schema: dict[str, Any] = { diff --git a/backend/tests/unit/core/diff/test_diff_and_merge.py b/backend/tests/unit/core/diff/test_diff_and_merge.py index c45eaa01c7..df3429c21a 100644 --- a/backend/tests/unit/core/diff/test_diff_and_merge.py +++ b/backend/tests/unit/core/diff/test_diff_and_merge.py @@ -1,3 +1,4 @@ +from typing import Literal from unittest.mock import AsyncMock import pytest @@ -56,7 +57,11 @@ async def test_diff_and_merge_with_list_attribute( assert updated_node.mylist.value == ["c", "d", 3, 4] async def test_diff_and_merge_schema_with_default_values( - self, db: InfrahubDatabase, default_branch: Branch, register_core_models_schema, car_person_schema: SchemaBranch + self, + db: InfrahubDatabase, + default_branch: Branch, + register_core_models_schema: SchemaBranch, + car_person_schema: SchemaBranch, ): schema_main = registry.schema.get_schema_branch(name=default_branch.name) await registry.schema.update_schema_branch( @@ -98,12 +103,12 @@ async def test_diff_and_merge_with_attribute_value_conflict( db: InfrahubDatabase, default_branch: Branch, diff_repository: DiffRepository, - person_john_main, - person_jane_main, - person_alfred_main, - car_accord_main, - conflict_selection, - expected_value, + person_john_main: Node, + person_jane_main: Node, + person_alfred_main: Node, + car_accord_main: Node, + conflict_selection: ConflictSelection, + expected_value: Literal["John-main", "John-branch"], ): branch2 = await create_branch(db=db, branch_name="branch2") john_main = await NodeManager.get_one(db=db, id=person_john_main.id) @@ -134,12 +139,12 @@ async def test_diff_and_merge_with_relationship_conflict( db: InfrahubDatabase, default_branch: Branch, diff_repository: DiffRepository, - person_john_main, - person_jane_main, - person_alfred_main, - car_accord_main, - car_camry_main, - conflict_selection, + person_john_main: Node, + person_jane_main: Node, + person_alfred_main: Node, + car_accord_main: Node, + car_camry_main: Node, + conflict_selection: ConflictSelection, ): branch2 = await create_branch(db=db, branch_name="branch2") car_main = await NodeManager.get_one(db=db, id=car_accord_main.id) @@ -174,11 +179,11 @@ async def test_diff_and_merge_with_attribute_property_conflict( db: InfrahubDatabase, default_branch: Branch, diff_repository: DiffRepository, - person_john_main, - person_jane_main, - person_alfred_main, - car_accord_main, - conflict_selection, + person_john_main: Node, + person_jane_main: Node, + person_alfred_main: Node, + car_accord_main: Node, + conflict_selection: ConflictSelection, ): branch2 = await create_branch(db=db, branch_name="branch2") john_main = await NodeManager.get_one(db=db, id=person_john_main.id) @@ -214,12 +219,12 @@ async def test_diff_and_merge_with_relationship_property_conflict( db: InfrahubDatabase, default_branch: Branch, diff_repository: DiffRepository, - person_john_main, - person_jane_main, - person_alfred_main, - car_accord_main, - car_camry_main, - conflict_selection, + person_john_main: Node, + person_jane_main: Node, + person_alfred_main: Node, + car_accord_main: Node, + car_camry_main: Node, + conflict_selection: ConflictSelection, ): branch2 = await create_branch(db=db, branch_name="branch2") car_main = await NodeManager.get_one(db=db, id=car_accord_main.id) @@ -299,3 +304,142 @@ async def test_relationship_set_to_null(self, db: InfrahubDatabase, default_bran updated_friend = await NodeManager.get_one(db=db, id=friend_main.id) best_friend_rels = await updated_friend.best_friends.get_relationships(db=db) assert len(best_friend_rels) == 0 + + async def test_local_and_aware_nodes_added_on_branch( + self, db: InfrahubDatabase, default_branch: Branch, car_person_schema_branch_local: SchemaBranch + ): + branch2 = await create_branch(db=db, branch_name="branch2") + person = await Node.init(db=db, schema="TestPerson", branch=branch2) + await person.new(db=db, name="Guy", height=180) + await person.save(db=db) + car = await Node.init(db=db, schema="TestCar", branch=branch2) + await car.new(db=db, name="camry", owner=person.id) + await car.save(db=db) + + diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + diff_person = enriched_diff.get_node(node_uuid=person.id) + assert diff_person.action is DiffAction.ADDED + # validate car is not in the diff + with pytest.raises(ValueError, match=rf"No node {car.id}"): + enriched_diff.get_node(node_uuid=car.id) + + diff_merger = await self._get_diff_merger(db=db, branch=branch2) + await diff_merger.merge_graph(at=Timestamp()) + + # validate person update on main + updated_person = await NodeManager.get_one(db=db, id=person.id) + assert updated_person.height.value == 180 + assert updated_person.name.value == "Guy" + # validate car (branch=local) not merged to main + updated_car = await NodeManager.get_one(db=db, id=car.id) + assert updated_car is None + person_schema = registry.schema.get(name="TestPerson", duplicate=False) + cars_rel_schema = person_schema.get_relationship(name="cars") + cars_rels = await NodeManager.query_peers( + db=db, ids=[person.id], source_kind="TestPerson", schema=cars_rel_schema, filters={}, fetch_peers=True + ) + assert len(cars_rels) == 0 + car_schema = registry.schema.get(name="TestCar", duplicate=False) + owner_rel_schema = car_schema.get_relationship(name="owner") + owner_rels = await NodeManager.query_peers( + db=db, ids=[car.id], source_kind="TestCar", schema=owner_rel_schema, filters={}, fetch_peers=True + ) + assert len(owner_rels) == 0 + # validate relationship still exists on branch + cars_rels = await NodeManager.query_peers( + db=db, + branch=branch2, + ids=[person.id], + source_kind="TestPerson", + schema=cars_rel_schema, + filters={}, + fetch_peers=True, + ) + assert len(cars_rels) == 1 + assert cars_rels[0].peer_id == car.id + owner_rels = await NodeManager.query_peers( + db=db, + branch=branch2, + ids=[car.id], + source_kind="TestCar", + schema=owner_rel_schema, + filters={}, + fetch_peers=True, + ) + assert len(owner_rels) == 1 + assert owner_rels[0].peer_id == person.id + + async def test_agnostic_and_aware_nodes_added_on_branch( + self, db: InfrahubDatabase, default_branch: Branch, car_person_schema_global + ): + branch2 = await create_branch(db=db, branch_name="branch2") + person = await Node.init(db=db, schema="TestPerson", branch=branch2) + await person.new(db=db, name="Guy", height=180) + await person.save(db=db) + car = await Node.init(db=db, schema="TestCar", branch=branch2) + await car.new(db=db, name="camry", nbr_seats=3, is_electric=False, owner=person.id) + await car.save(db=db) + + diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) + enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2) + diff_person = enriched_diff.get_node(node_uuid=person.id) + assert diff_person.action is DiffAction.UPDATED + diff_car = enriched_diff.get_node(node_uuid=car.id) + assert diff_car.action is DiffAction.ADDED + + diff_merger = await self._get_diff_merger(db=db, branch=branch2) + await diff_merger.merge_graph(at=Timestamp()) + + # validate person (agnostic) exists on main + updated_person = await NodeManager.get_one(db=db, id=person.id) + assert updated_person.height.value == 180 + assert updated_person.name.value == "Guy" + cars_rels = await updated_person.cars.get(db=db) + assert len(cars_rels) == 1 + assert cars_rels[0].peer_id == car.id + # validate car merged to main + updated_car = await NodeManager.get_one(db=db, id=car.id) + assert updated_car.name.value == "camry" + assert updated_car.nbr_seats.value == 3 + assert updated_car.is_electric.value is False + owner_rel = await updated_car.owner.get(db=db) + assert owner_rel.peer_id == person.id + + person_schema = registry.schema.get(name="TestPerson", duplicate=False) + cars_rel_schema = person_schema.get_relationship(name="cars") + cars_rels = await NodeManager.query_peers( + db=db, ids=[person.id], source_kind="TestPerson", schema=cars_rel_schema, filters={}, fetch_peers=True + ) + assert len(cars_rels) == 1 + assert cars_rels[0].peer_id == car.id + car_schema = registry.schema.get(name="TestCar", duplicate=False) + owner_rel_schema = car_schema.get_relationship(name="owner") + owner_rels = await NodeManager.query_peers( + db=db, ids=[car.id], source_kind="TestCar", schema=owner_rel_schema, filters={}, fetch_peers=True + ) + assert len(owner_rels) == 1 + assert owner_rels[0].peer_id == person.id + # validate relationship still exists on branch + cars_rels = await NodeManager.query_peers( + db=db, + branch=branch2, + ids=[person.id], + source_kind="TestPerson", + schema=cars_rel_schema, + filters={}, + fetch_peers=True, + ) + assert len(cars_rels) == 1 + assert cars_rels[0].peer_id == car.id + owner_rels = await NodeManager.query_peers( + db=db, + branch=branch2, + ids=[car.id], + source_kind="TestCar", + schema=owner_rel_schema, + filters={}, + fetch_peers=True, + ) + assert len(owner_rels) == 1 + assert owner_rels[0].peer_id == person.id diff --git a/backend/tests/unit/core/diff/test_diff_calculator.py b/backend/tests/unit/core/diff/test_diff_calculator.py index 5c04c451d0..4944fefa4a 100644 --- a/backend/tests/unit/core/diff/test_diff_calculator.py +++ b/backend/tests/unit/core/diff/test_diff_calculator.py @@ -2427,3 +2427,83 @@ async def test_hierarchy_with_same_kind_parent_and_child( assert diff_prop.action is DiffAction.ADDED assert diff_prop.previous_value is None assert diff_prop.new_value == new_value + + +async def test_create_local_and_aware_nodes_on_branch( + db: InfrahubDatabase, default_branch: Branch, car_person_schema_branch_local: SchemaBranch +): + branch = await create_branch(db=db, branch_name="branch") + from_time = Timestamp() + person = await Node.init(db=db, schema="TestPerson", branch=branch) + await person.new(db=db, name="Guy", height=180) + await person.save(db=db) + # car is a local node + car = await Node.init(db=db, schema="TestCar", branch=branch) + await car.new(db=db, name="camry", owner=person.id) + await car.save(db=db) + + diff_calculator = DiffCalculator(db=db) + calculated_diffs = await diff_calculator.calculate_diff( + base_branch=default_branch, diff_branch=branch, from_time=from_time, to_time=Timestamp() + ) + + base_branch_diff = calculated_diffs.base_branch_diff + assert len(base_branch_diff.nodes) == 0 + + diff_branch_diff = calculated_diffs.diff_branch_diff + nodes_by_id = {n.uuid: n for n in diff_branch_diff.nodes} + assert set(nodes_by_id.keys()) == {person.id} + node_diff = nodes_by_id[person.id] + assert node_diff.action is DiffAction.ADDED + assert len(node_diff.relationships) == 0 + attrs_by_name = {a.name: a for a in node_diff.attributes} + assert set(attrs_by_name.keys()) == {"name", "height"} + for attr_diff in node_diff.attributes: + assert attr_diff.action is DiffAction.ADDED + + +async def test_create_aware_and_agnostic_nodes_on_branch( + db: InfrahubDatabase, default_branch: Branch, car_person_schema_global +): + branch = await create_branch(db=db, branch_name="branch") + from_time = Timestamp() + # person is an agnostic node + person = await Node.init(db=db, schema="TestPerson", branch=branch) + await person.new(db=db, name="Guy", height=180) + await person.save(db=db) + # nbr_seats is an agnostic attr + car = await Node.init(db=db, schema="TestCar", branch=branch) + await car.new(db=db, name="camry", nbr_seats=3, is_electric=True, owner=person.id) + await car.save(db=db) + + diff_calculator = DiffCalculator(db=db) + calculated_diffs = await diff_calculator.calculate_diff( + base_branch=default_branch, diff_branch=branch, from_time=from_time, to_time=Timestamp() + ) + + base_branch_diff = calculated_diffs.base_branch_diff + assert len(base_branch_diff.nodes) == 0 + + diff_branch_diff = calculated_diffs.diff_branch_diff + nodes_by_id = {n.uuid: n for n in diff_branch_diff.nodes} + assert set(nodes_by_id.keys()) == {car.id, person.id} + # check car attributes and relationship + node_diff = nodes_by_id[car.id] + assert node_diff.action is DiffAction.ADDED + assert len(node_diff.relationships) == 1 + rel_diff = node_diff.relationships.pop() + assert rel_diff.name == "owner" + assert rel_diff.action is DiffAction.ADDED + attrs_by_name = {a.name: a for a in node_diff.attributes} + # nbr_seats is agnostic, so is not included + assert set(attrs_by_name.keys()) == {"name", "color", "is_electric"} + for attr_diff in node_diff.attributes: + assert attr_diff.action is DiffAction.ADDED + # check person relationship + node_diff = nodes_by_id[person.id] + assert node_diff.action is DiffAction.UPDATED + assert len(node_diff.attributes) == 0 + assert len(node_diff.relationships) == 1 + rel_diff = node_diff.relationships.pop() + assert rel_diff.name == "cars" + assert rel_diff.action is DiffAction.UPDATED