Skip to content

Commit

Permalink
exclude branch=local relationships from diff (#5011)
Browse files Browse the repository at this point in the history
* test that reproduces failure

* make branch_support static in DiffAllPathsQuery

* more tests
  • Loading branch information
ajtmccarty authored Nov 21, 2024
1 parent de31fdb commit 6d1de2a
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 34 deletions.
19 changes: 10 additions & 9 deletions backend/infrahub/core/query/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,14 +528,12 @@ def __init__(
self,
base_branch: Branch,
diff_branch_from_time: Timestamp,
branch_support: list[BranchSupportType] | None = None,
current_node_field_specifiers: list[tuple[str, str]] | None = None,
new_node_field_specifiers: list[tuple[str, str]] | None = None,
**kwargs: Any,
):
self.base_branch = base_branch
self.diff_branch_from_time = diff_branch_from_time
self.branch_support = branch_support or [BranchSupportType.AWARE]
self.current_node_field_specifiers = current_node_field_specifiers
self.new_node_field_specifiers = new_node_field_specifiers

Expand All @@ -551,7 +549,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
"branch_from_time": self.diff_branch_from_time.to_string(),
"from_time": from_str,
"to_time": self.diff_to.to_string(),
"branch_support": [item.value for item in self.branch_support],
"branch_local": BranchSupportType.LOCAL.value,
"branch_aware": BranchSupportType.AWARE.value,
"branch_agnostic": BranchSupportType.AGNOSTIC.value,
"new_node_field_specifiers": self.new_node_field_specifiers,
"current_node_field_specifiers": self.current_node_field_specifiers,
}
Expand All @@ -577,7 +577,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
WHERE (node_ids_list IS NULL OR p.uuid IN node_ids_list)
AND (from_time <= diff_rel.from < $to_time)
AND (diff_rel.to IS NULL OR (from_time <= diff_rel.to < $to_time))
AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support)
AND p.branch_support = $branch_aware
WITH p, q, diff_rel, from_time
// -------------------------------------
// Exclude nodes added then removed on branch within timeframe
Expand All @@ -603,6 +603,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
WHERE %(id_func)s(diff_rel) = %(id_func)s(top_diff_rel)
AND type(r_node) IN ["HAS_ATTRIBUTE", "IS_RELATED"]
AND any(l in labels(node) WHERE l in ["Attribute", "Relationship"])
AND node.branch_support IN [$branch_aware, $branch_agnostic]
AND type(r_prop) IN ["IS_VISIBLE", "IS_PROTECTED", "HAS_SOURCE", "HAS_OWNER", "HAS_VALUE", "IS_RELATED"]
AND any(l in labels(prop) WHERE l in ["Boolean", "Node", "AttributeValue"])
AND ALL(
Expand Down Expand Up @@ -655,9 +656,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
// exclude attributes and relationships under added/removed nodes b/c they are covered above
WHERE (node_field_specifiers_list IS NULL OR [p.uuid, q.name] IN node_field_specifiers_list)
AND r_root.branch IN [$branch_name, $base_branch_name, $global_branch_name]
AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support)
AND q.branch_support = $branch_aware
// if p has a different type of branch support and was addded within our timeframe
AND (r_root.from < from_time OR NOT (p.branch_support IN $branch_support))
AND (r_root.from < from_time OR p.branch_support = $branch_agnostic)
AND r_root.status = "active"
// get attributes and relationships added on the branch during the timeframe
AND (from_time <= diff_rel.from < $to_time)
Expand All @@ -671,9 +672,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
// exclude attributes and relationships under added/removed nodes b/c they are covered above
WHERE (node_field_specifiers_list IS NULL OR [p.uuid, q.name] IN node_field_specifiers_list)
AND r_root.branch IN [$branch_name, $base_branch_name, $global_branch_name]
AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support)
AND q.branch_support = $branch_aware
// if p has a different type of branch support and was addded within our timeframe
AND (r_root.from < from_time OR NOT (p.branch_support IN $branch_support))
AND (r_root.from < from_time OR p.branch_support = $branch_agnostic)
// get attributes and relationships added on the branch during the timeframe
AND (from_time <= diff_rel.from < $to_time)
AND (diff_rel.to IS NULL OR (from_time <= diff_rel.to < $to_time))
Expand Down Expand Up @@ -773,7 +774,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
r in [r_root, r_node]
WHERE r.from <= from_time AND r.branch IN [$branch_name, $base_branch_name]
)
AND (p.branch_support IN $branch_support OR q.branch_support IN $branch_support)
AND p.branch_support = $branch_aware
AND any(l in labels(p) WHERE l in ["Attribute", "Relationship"])
AND type(diff_rel) IN ["IS_VISIBLE", "IS_PROTECTED", "HAS_SOURCE", "HAS_OWNER", "HAS_VALUE"]
AND any(l in labels(q) WHERE l in ["Boolean", "Node", "AttributeValue"])
Expand Down
62 changes: 61 additions & 1 deletion backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from infrahub.config import load_and_exit
from infrahub.core import registry
from infrahub.core.branch import Branch
from infrahub.core.constants import BranchSupportType, InfrahubKind
from infrahub.core.constants import BranchSupportType, InfrahubKind, RelationshipCardinality, RelationshipDirection
from infrahub.core.initialization import (
create_default_branch,
create_global_branch,
Expand All @@ -31,8 +31,11 @@
)
from infrahub.core.node import Node
from infrahub.core.schema import SchemaRoot, core_models, internal_schema
from infrahub.core.schema.attribute_schema import AttributeSchema
from infrahub.core.schema.definitions.core import core_profile_schema_definition
from infrahub.core.schema.manager import SchemaManager
from infrahub.core.schema.node_schema import NodeSchema
from infrahub.core.schema.relationship_schema import RelationshipSchema
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.core.utils import delete_all_nodes
from infrahub.database import InfrahubDatabase, get_db
Expand Down Expand Up @@ -513,6 +516,63 @@ async def car_person_schema(
return registry.schema.register_schema(schema=car_person_schema_unregistered, branch=default_branch.name)


@pytest.fixture
async def car_person_schema_branch_local_root(db: InfrahubDatabase, default_branch: Branch) -> SchemaRoot:
schema = SchemaRoot(
nodes=[
NodeSchema(
name="Car",
namespace="Test",
default_filter="name__value",
display_labels=["name__value", "color__value"],
uniqueness_constraints=[["name__value"]],
branch=BranchSupportType.LOCAL,
attributes=[
AttributeSchema(name="name", kind="Text", unique=True),
AttributeSchema(name="color", kind="Text", default_value="#444444", optional=True),
],
relationships=[
RelationshipSchema(
name="owner",
peer="TestPerson",
optional=False,
cardinality=RelationshipCardinality.ONE,
direction=RelationshipDirection.OUTBOUND,
),
],
),
NodeSchema(
name="Person",
namespace="Test",
default_filter="name__value",
display_labels=["name__value"],
branch=BranchSupportType.AWARE,
uniqueness_constraints=[["name__value"]],
attributes=[
AttributeSchema(name="name", kind="Text", unique=True),
AttributeSchema(name="height", kind="Number", optional=True),
],
relationships=[
RelationshipSchema(
name="cars",
peer="TestCar",
cardinality=RelationshipCardinality.MANY,
direction=RelationshipDirection.INBOUND,
)
],
),
],
)
return schema


@pytest.fixture
async def car_person_schema_branch_local(
db: InfrahubDatabase, default_branch: Branch, car_person_schema_branch_local_root
) -> SchemaBranch:
return registry.schema.register_schema(schema=car_person_schema_branch_local_root, branch=default_branch.name)


@pytest.fixture
async def animal_person_schema_unregistered(db: InfrahubDatabase, node_group_schema, data_schema) -> SchemaRoot:
schema: dict[str, Any] = {
Expand Down
192 changes: 168 additions & 24 deletions backend/tests/unit/core/diff/test_diff_and_merge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
from unittest.mock import AsyncMock

import pytest
Expand Down Expand Up @@ -56,7 +57,11 @@ async def test_diff_and_merge_with_list_attribute(
assert updated_node.mylist.value == ["c", "d", 3, 4]

async def test_diff_and_merge_schema_with_default_values(
self, db: InfrahubDatabase, default_branch: Branch, register_core_models_schema, car_person_schema: SchemaBranch
self,
db: InfrahubDatabase,
default_branch: Branch,
register_core_models_schema: SchemaBranch,
car_person_schema: SchemaBranch,
):
schema_main = registry.schema.get_schema_branch(name=default_branch.name)
await registry.schema.update_schema_branch(
Expand Down Expand Up @@ -98,12 +103,12 @@ async def test_diff_and_merge_with_attribute_value_conflict(
db: InfrahubDatabase,
default_branch: Branch,
diff_repository: DiffRepository,
person_john_main,
person_jane_main,
person_alfred_main,
car_accord_main,
conflict_selection,
expected_value,
person_john_main: Node,
person_jane_main: Node,
person_alfred_main: Node,
car_accord_main: Node,
conflict_selection: ConflictSelection,
expected_value: Literal["John-main", "John-branch"],
):
branch2 = await create_branch(db=db, branch_name="branch2")
john_main = await NodeManager.get_one(db=db, id=person_john_main.id)
Expand Down Expand Up @@ -134,12 +139,12 @@ async def test_diff_and_merge_with_relationship_conflict(
db: InfrahubDatabase,
default_branch: Branch,
diff_repository: DiffRepository,
person_john_main,
person_jane_main,
person_alfred_main,
car_accord_main,
car_camry_main,
conflict_selection,
person_john_main: Node,
person_jane_main: Node,
person_alfred_main: Node,
car_accord_main: Node,
car_camry_main: Node,
conflict_selection: ConflictSelection,
):
branch2 = await create_branch(db=db, branch_name="branch2")
car_main = await NodeManager.get_one(db=db, id=car_accord_main.id)
Expand Down Expand Up @@ -174,11 +179,11 @@ async def test_diff_and_merge_with_attribute_property_conflict(
db: InfrahubDatabase,
default_branch: Branch,
diff_repository: DiffRepository,
person_john_main,
person_jane_main,
person_alfred_main,
car_accord_main,
conflict_selection,
person_john_main: Node,
person_jane_main: Node,
person_alfred_main: Node,
car_accord_main: Node,
conflict_selection: ConflictSelection,
):
branch2 = await create_branch(db=db, branch_name="branch2")
john_main = await NodeManager.get_one(db=db, id=person_john_main.id)
Expand Down Expand Up @@ -214,12 +219,12 @@ async def test_diff_and_merge_with_relationship_property_conflict(
db: InfrahubDatabase,
default_branch: Branch,
diff_repository: DiffRepository,
person_john_main,
person_jane_main,
person_alfred_main,
car_accord_main,
car_camry_main,
conflict_selection,
person_john_main: Node,
person_jane_main: Node,
person_alfred_main: Node,
car_accord_main: Node,
car_camry_main: Node,
conflict_selection: ConflictSelection,
):
branch2 = await create_branch(db=db, branch_name="branch2")
car_main = await NodeManager.get_one(db=db, id=car_accord_main.id)
Expand Down Expand Up @@ -299,3 +304,142 @@ async def test_relationship_set_to_null(self, db: InfrahubDatabase, default_bran
updated_friend = await NodeManager.get_one(db=db, id=friend_main.id)
best_friend_rels = await updated_friend.best_friends.get_relationships(db=db)
assert len(best_friend_rels) == 0

async def test_local_and_aware_nodes_added_on_branch(
self, db: InfrahubDatabase, default_branch: Branch, car_person_schema_branch_local: SchemaBranch
):
branch2 = await create_branch(db=db, branch_name="branch2")
person = await Node.init(db=db, schema="TestPerson", branch=branch2)
await person.new(db=db, name="Guy", height=180)
await person.save(db=db)
car = await Node.init(db=db, schema="TestCar", branch=branch2)
await car.new(db=db, name="camry", owner=person.id)
await car.save(db=db)

diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2)
enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2)
diff_person = enriched_diff.get_node(node_uuid=person.id)
assert diff_person.action is DiffAction.ADDED
# validate car is not in the diff
with pytest.raises(ValueError, match=rf"No node {car.id}"):
enriched_diff.get_node(node_uuid=car.id)

diff_merger = await self._get_diff_merger(db=db, branch=branch2)
await diff_merger.merge_graph(at=Timestamp())

# validate person update on main
updated_person = await NodeManager.get_one(db=db, id=person.id)
assert updated_person.height.value == 180
assert updated_person.name.value == "Guy"
# validate car (branch=local) not merged to main
updated_car = await NodeManager.get_one(db=db, id=car.id)
assert updated_car is None
person_schema = registry.schema.get(name="TestPerson", duplicate=False)
cars_rel_schema = person_schema.get_relationship(name="cars")
cars_rels = await NodeManager.query_peers(
db=db, ids=[person.id], source_kind="TestPerson", schema=cars_rel_schema, filters={}, fetch_peers=True
)
assert len(cars_rels) == 0
car_schema = registry.schema.get(name="TestCar", duplicate=False)
owner_rel_schema = car_schema.get_relationship(name="owner")
owner_rels = await NodeManager.query_peers(
db=db, ids=[car.id], source_kind="TestCar", schema=owner_rel_schema, filters={}, fetch_peers=True
)
assert len(owner_rels) == 0
# validate relationship still exists on branch
cars_rels = await NodeManager.query_peers(
db=db,
branch=branch2,
ids=[person.id],
source_kind="TestPerson",
schema=cars_rel_schema,
filters={},
fetch_peers=True,
)
assert len(cars_rels) == 1
assert cars_rels[0].peer_id == car.id
owner_rels = await NodeManager.query_peers(
db=db,
branch=branch2,
ids=[car.id],
source_kind="TestCar",
schema=owner_rel_schema,
filters={},
fetch_peers=True,
)
assert len(owner_rels) == 1
assert owner_rels[0].peer_id == person.id

async def test_agnostic_and_aware_nodes_added_on_branch(
self, db: InfrahubDatabase, default_branch: Branch, car_person_schema_global
):
branch2 = await create_branch(db=db, branch_name="branch2")
person = await Node.init(db=db, schema="TestPerson", branch=branch2)
await person.new(db=db, name="Guy", height=180)
await person.save(db=db)
car = await Node.init(db=db, schema="TestCar", branch=branch2)
await car.new(db=db, name="camry", nbr_seats=3, is_electric=False, owner=person.id)
await car.save(db=db)

diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2)
enriched_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch2)
diff_person = enriched_diff.get_node(node_uuid=person.id)
assert diff_person.action is DiffAction.UPDATED
diff_car = enriched_diff.get_node(node_uuid=car.id)
assert diff_car.action is DiffAction.ADDED

diff_merger = await self._get_diff_merger(db=db, branch=branch2)
await diff_merger.merge_graph(at=Timestamp())

# validate person (agnostic) exists on main
updated_person = await NodeManager.get_one(db=db, id=person.id)
assert updated_person.height.value == 180
assert updated_person.name.value == "Guy"
cars_rels = await updated_person.cars.get(db=db)
assert len(cars_rels) == 1
assert cars_rels[0].peer_id == car.id
# validate car merged to main
updated_car = await NodeManager.get_one(db=db, id=car.id)
assert updated_car.name.value == "camry"
assert updated_car.nbr_seats.value == 3
assert updated_car.is_electric.value is False
owner_rel = await updated_car.owner.get(db=db)
assert owner_rel.peer_id == person.id

person_schema = registry.schema.get(name="TestPerson", duplicate=False)
cars_rel_schema = person_schema.get_relationship(name="cars")
cars_rels = await NodeManager.query_peers(
db=db, ids=[person.id], source_kind="TestPerson", schema=cars_rel_schema, filters={}, fetch_peers=True
)
assert len(cars_rels) == 1
assert cars_rels[0].peer_id == car.id
car_schema = registry.schema.get(name="TestCar", duplicate=False)
owner_rel_schema = car_schema.get_relationship(name="owner")
owner_rels = await NodeManager.query_peers(
db=db, ids=[car.id], source_kind="TestCar", schema=owner_rel_schema, filters={}, fetch_peers=True
)
assert len(owner_rels) == 1
assert owner_rels[0].peer_id == person.id
# validate relationship still exists on branch
cars_rels = await NodeManager.query_peers(
db=db,
branch=branch2,
ids=[person.id],
source_kind="TestPerson",
schema=cars_rel_schema,
filters={},
fetch_peers=True,
)
assert len(cars_rels) == 1
assert cars_rels[0].peer_id == car.id
owner_rels = await NodeManager.query_peers(
db=db,
branch=branch2,
ids=[car.id],
source_kind="TestCar",
schema=owner_rel_schema,
filters={},
fetch_peers=True,
)
assert len(owner_rels) == 1
assert owner_rels[0].peer_id == person.id
Loading

0 comments on commit 6d1de2a

Please sign in to comment.