Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use elementId() instead of ID() for neo4j queries #4377

Merged
merged 5 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/infrahub/core/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ async def run_query(
def render_list_comprehension(self, items: str, item_name: str) -> str: ...
def render_list_comprehension_with_list(self, items: str, item_names: list[str]) -> str: ...
def render_uuid_generation(self, node_label: str, node_attr: str) -> str: ...
def get_id_function_name(self) -> str: ...
def to_database_id(self, db_id: str | int) -> str | int: ...


@runtime_checkable
Expand Down
30 changes: 18 additions & 12 deletions backend/infrahub/core/query/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from infrahub import config
from infrahub.core.constants import RelationshipStatus
from infrahub.core.query import Query, QueryType
from infrahub.core.utils import element_id_to_id

if TYPE_CHECKING:
from infrahub.database import InfrahubDatabase
Expand All @@ -24,13 +23,15 @@ def __init__(self, node_id: int, **kwargs: Any):
async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
query = """
MATCH (root:Root)
MATCH (d) WHERE ID(d) = $node_id
MATCH (d) WHERE %(id_func)s(d) = $node_id
WITH root,d
CREATE (d)-[r:IS_PART_OF { branch: $branch, branch_level: $branch_level, from: $now, to: null, status: $status }]->(root)
RETURN ID(r)
"""
RETURN %(id_func)s(r)
""" % {
"id_func": db.get_id_function_name(),
}

self.params["node_id"] = element_id_to_id(self.node_id)
self.params["node_id"] = db.to_database_id(self.node_id)
self.params["now"] = self.at.to_string()
self.params["branch"] = self.branch.name
self.params["branch_level"] = self.branch.hierarchy_level
Expand Down Expand Up @@ -100,16 +101,18 @@ def __init__(self, ids: list[str], **kwargs: Any) -> None:
async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
query = """
MATCH ()-[r]->()
WHERE ID(r) IN $ids
WHERE %(id_func)s(r) IN $ids
SET r.from = $at
SET r.conflict = NULL
"""
""" % {
"id_func": db.get_id_function_name(),
}

self.add_to_query(query=query)

self.params["at"] = self.at.to_string()
self.params["ids"] = [element_id_to_id(id) for id in self.ids]
self.return_labels = ["ID(r)"]
self.params["ids"] = [db.to_database_id(id) for id in self.ids]
self.return_labels = [f"{db.get_id_function_name()}(r)"]


class RebaseBranchDeleteRelationshipQuery(Query):
Expand All @@ -126,21 +129,24 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
if config.SETTINGS.database.db_type == config.DatabaseType.MEMGRAPH:
query = """
MATCH p = (s)-[r]-(d)
WHERE ID(r) IN $ids
WHERE %(id_func)s(r) IN $ids
DELETE r
"""
else:
query = """
MATCH p = (s)-[r]-(d)
WHERE ID(r) IN $ids
WHERE %(id_func)s(r) IN $ids
DELETE r
WITH *
UNWIND nodes(p) AS n
MATCH (n)
WHERE NOT exists((n)--())
DELETE n
"""
query %= {
"id_func": db.get_id_function_name(),
}

self.add_to_query(query=query)

self.params["ids"] = [element_id_to_id(id) for id in self.ids]
self.params["ids"] = [db.to_database_id(id) for id in self.ids]
31 changes: 16 additions & 15 deletions backend/infrahub/core/query/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,13 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
OPTIONAL MATCH path = (
(:Root)<-[r_root:IS_PART_OF]-(n:Node)-[r_node]-(inner_p)-[inner_diff_rel]->(inner_q)
)
WHERE ID(inner_p) = ID(p) AND ID(inner_diff_rel) = ID(diff_rel) AND ID(inner_q) = ID(q)
WHERE %(id_func)s(inner_p) = %(id_func)s(p) AND %(id_func)s(inner_diff_rel) = %(id_func)s(diff_rel) AND %(id_func)s(inner_q) = %(id_func)s(q)
AND any(l in labels(inner_p) WHERE l in ["Attribute", "Relationship"])
AND type(inner_diff_rel) IN ["IS_VISIBLE", "IS_PROTECTED", "HAS_SOURCE", "HAS_OWNER", "HAS_VALUE"]
AND any(l in labels(inner_q) WHERE l in ["Boolean", "Node", "AttributeValue"])
AND type(r_node) IN ["HAS_ATTRIBUTE", "IS_RELATED"]
AND %(n_node_where)s
AND [ID(n), type(r_node)] <> [ID(inner_q), type(inner_diff_rel)]
AND [%(id_func)s(n), type(r_node)] <> [%(id_func)s(inner_q), type(inner_diff_rel)]
AND ALL(
r in [r_root, r_node]
WHERE r.from <= $to_time AND r.branch IN $branch_names
Expand All @@ -600,8 +600,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
AND (r_node.status = "deleted" OR r_root.status = "active")
WITH path AS diff_rel_path, diff_rel, r_root, n, r_node, p
ORDER BY
ID(n) DESC,
ID(p) DESC,
%(id_func)s(n) DESC,
%(id_func)s(p) DESC,
r_node.branch = diff_rel.branch DESC,
r_root.branch = diff_rel.branch DESC,
r_node.from DESC,
Expand All @@ -619,9 +619,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
// get base branch version of the diff path, if it exists
WITH diff_rel_path, diff_rel, r_root, n, r_node, p
OPTIONAL MATCH latest_base_path = (:Root)<-[r_root2]-(n2)-[r_node2]-(inner_p2)-[base_diff_rel]->(base_prop)
WHERE ID(r_root2) = ID(r_root) AND ID(n2) = ID(n) AND ID(r_node2) = ID(r_node) AND ID(inner_p2) = ID(p)
WHERE %(id_func)s(r_root2) = %(id_func)s(r_root) AND %(id_func)s(n2) = %(id_func)s(n) AND %(id_func)s(r_node2) = %(id_func)s(r_node) AND %(id_func)s(inner_p2) = %(id_func)s(p)
AND any(r in relationships(diff_rel_path) WHERE r.branch = $branch_name)
AND ID(n2) <> ID(base_prop)
AND %(id_func)s(n2) <> %(id_func)s(base_prop)
AND type(base_diff_rel) = type(diff_rel)
AND all(
r in relationships(latest_base_path)
Expand All @@ -638,9 +638,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
OPTIONAL MATCH base_peer_path = (
(:Root)<-[r_root3]-(n3)-[r_node3]-(inner_p3:Relationship)-[base_r_peer:IS_RELATED]-(base_peer:Node)
)
WHERE ID(r_root3) = ID(r_root) AND ID(n3) = ID(n) AND ID(r_node3) = ID(r_node) AND ID(inner_p3) = ID(p)
WHERE %(id_func)s(r_root3) = %(id_func)s(r_root) AND %(id_func)s(n3) = %(id_func)s(n) AND %(id_func)s(r_node3) = %(id_func)s(r_node) AND %(id_func)s(inner_p3) = %(id_func)s(p)
AND type(diff_rel) <> "IS_RELATED"
AND [ID(n3), type(r_node3)] <> [ID(base_peer), type(base_r_peer)]
AND [%(id_func)s(n3), type(r_node3)] <> [%(id_func)s(base_peer), type(base_r_peer)]
AND base_r_peer.from <= $to_time
AND base_r_peer.branch IN $branch_names
// exclude paths where an active edge is below a deleted edge
Expand All @@ -663,7 +663,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
OPTIONAL MATCH path = (
(:Root)<-[r_root:IS_PART_OF]-(inner_p)-[inner_diff_rel]-(inner_q)-[r_prop]-(prop)
)
WHERE ID(inner_p) = ID(p) AND ID(inner_diff_rel) = ID(diff_rel) AND ID(inner_q) = ID(q)
WHERE %(id_func)s(inner_p) = %(id_func)s(p) AND %(id_func)s(inner_diff_rel) = %(id_func)s(diff_rel) AND %(id_func)s(inner_q) = %(id_func)s(q)
AND "Node" IN labels(inner_p)
AND type(inner_diff_rel) IN ["HAS_ATTRIBUTE", "IS_RELATED"]
AND any(l in labels(inner_q) WHERE l in ["Attribute", "Relationship"])
Expand All @@ -673,13 +673,13 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
r in [r_root, r_prop]
WHERE r.from <= $to_time AND r.branch IN $branch_names
)
AND [ID(inner_p), type(inner_diff_rel)] <> [ID(prop), type(r_prop)]
AND [%(id_func)s(inner_p), type(inner_diff_rel)] <> [%(id_func)s(prop), type(r_prop)]
// exclude paths where an active edge is below a deleted edge
AND (inner_diff_rel.status = "active" OR (r_prop.status = "deleted" AND inner_diff_rel.branch = r_prop.branch))
AND (inner_diff_rel.status = "deleted" OR r_root.status = "active")
WITH path, prop, r_prop, r_root
ORDER BY
ID(prop),
%(id_func)s(prop),
r_prop.branch = diff_rel.branch DESC,
r_root.branch = diff_rel.branch DESC,
r_prop.from DESC,
Expand All @@ -697,7 +697,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
OPTIONAL MATCH path = (
(inner_q:Root)<-[inner_diff_rel:IS_PART_OF]-(inner_p:Node)-[r_node]-(node)-[r_prop]-(prop)
)
WHERE ID(inner_p) = ID(p) AND ID(inner_diff_rel) = ID(diff_rel) AND ID(inner_q) = ID(q)
WHERE %(id_func)s(inner_p) = %(id_func)s(p) AND %(id_func)s(inner_diff_rel) = %(id_func)s(diff_rel) AND %(id_func)s(inner_q) = %(id_func)s(q)
AND type(r_node) IN ["HAS_ATTRIBUTE", "IS_RELATED"]
AND any(l in labels(node) WHERE l in ["Attribute", "Relationship"])
AND type(r_prop) IN ["IS_VISIBLE", "IS_PROTECTED", "HAS_SOURCE", "HAS_OWNER", "HAS_VALUE", "IS_RELATED"]
Expand All @@ -706,7 +706,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
r in [r_node, r_prop]
WHERE r.from <= $to_time AND r.branch IN $branch_names
)
AND [ID(inner_p), type(r_node)] <> [ID(prop), type(r_prop)]
AND [%(id_func)s(inner_p), type(r_node)] <> [%(id_func)s(prop), type(r_prop)]
// exclude paths where an active edge is below a deleted edge
AND (inner_diff_rel.status = "active" OR
(
Expand All @@ -717,8 +717,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
AND (r_prop.status = "deleted" OR r_node.status = "active")
WITH path, node, prop, r_prop, r_node
ORDER BY
ID(node),
ID(prop),
%(id_func)s(node),
%(id_func)s(prop),
r_prop.branch = diff_rel.branch DESC,
r_node.branch = diff_rel.branch DESC,
r_prop.from DESC,
Expand All @@ -730,6 +730,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
WITH p, q, diff_rel, full_diff_paths + latest_paths AS full_diff_paths
""" % {
"diff_rel_filter": diff_rel_filter,
"id_func": db.get_id_function_name(),
"p_node_where": p_node_where,
"n_node_where": n_node_where,
}
Expand Down
8 changes: 5 additions & 3 deletions backend/infrahub/core/query/ipam.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def rel_filter(rel_name: str) -> str:
-[r_attr:HAS_ATTRIBUTE]->(attr:Attribute)
-[r_attr_val:HAS_VALUE]->(av:AttributeValue)
)
WHERE ID(r_1) = ID(r_rel1)
AND ID(r_2) = ID(r_rel2)
WHERE %(id_func)s(r_1) = %(id_func)s(r_rel1)
AND %(id_func)s(r_2) = %(id_func)s(r_rel2)
AND ({rel_filter("r_attr")})
AND ({rel_filter("r_attr_val")})
AND attr.name IN ["prefix", "address"]
Expand All @@ -301,7 +301,9 @@ def rel_filter(rel_name: str) -> str:
deepest_branch_details[1] AS branch,
head(collect(is_active)) AS is_latest_active
WHERE is_latest_active = TRUE
"""
""" % {
"id_func": db.get_id_function_name(),
}
self.return_labels = ["pfx", "child", "av", "branch_level", "branch"]
self.add_to_query(query)

Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.order_by = []
self.params["node_kind"] = self.schema.kind

self.return_labels = ["n.uuid", "rb.branch", "ID(rb) as rb_id"]
self.return_labels = ["n.uuid", "rb.branch", f"{db.get_id_function_name()}(rb) as rb_id"]
where_clause_elements = []

branch_filter, branch_params = self.branch.get_query_filter_path(
Expand Down
9 changes: 6 additions & 3 deletions backend/infrahub/core/query/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from infrahub.core.query import Query, QueryType
from infrahub.core.query.subquery import build_subquery_filter, build_subquery_order
from infrahub.core.timestamp import Timestamp
from infrahub.core.utils import element_id_to_id, extract_field_filters
from infrahub.core.utils import extract_field_filters

if TYPE_CHECKING:
from uuid import UUID
Expand Down Expand Up @@ -391,8 +391,11 @@ async def query_init(self, db: InfrahubDatabase, **kwargs):
self.return_labels = ["s", "d", "rl"]

for prop_name, prop in self.data.properties.items():
self.add_to_query("MATCH (prop_%s) WHERE ID(prop_%s) = $prop_%s_id" % (prop_name, prop_name, prop_name))
self.params[f"prop_{prop_name}_id"] = element_id_to_id(prop.prop_db_id)
self.add_to_query(
"MATCH (prop_%(prop_name)s) WHERE %(id_func)s(prop_%(prop_name)s) = $prop_%(prop_name)s_id"
% {"prop_name": prop_name, "id_func": db.get_id_function_name()}
)
self.params[f"prop_{prop_name}_id"] = db.to_database_id(prop.prop_db_id)
self.return_labels.append(f"prop_{prop_name}")

self.params["rel_prop"] = self.get_relationship_properties_dict(status=RelationshipStatus.DELETED)
Expand Down
13 changes: 5 additions & 8 deletions backend/infrahub/core/query/standard_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,10 @@ def __init__(self, node_id: str, node_type: str, **kwargs: Any) -> None:
super().__init__(**kwargs)

async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
query = (
"""
MATCH (n:%s)
WHERE ID(n) = $node_id OR n.uuid = $node_id
"""
% self.node_type
)
query = """
MATCH (n:%(node_type)s)
WHERE %(id_func)s(n) = $node_id OR n.uuid = $node_id
""" % {"node_type": self.node_type, "id_func": db.get_id_function_name()}

self.params["node_id"] = self.node_id
self.add_to_query(query)
Expand Down Expand Up @@ -161,4 +158,4 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.add_to_query(query)

self.return_labels = ["n"]
self.order_by = ["ID(n)"]
self.order_by = [f"{db.get_id_function_name()}(n)"]
49 changes: 17 additions & 32 deletions backend/infrahub/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ async def add_relationship(
status=RelationshipStatus.ACTIVE,
):
create_rel_query = """
MATCH (s) WHERE ID(s) = $src_node_id
MATCH (d) WHERE ID(d) = $dst_node_id
MATCH (s) WHERE %(id_func)s(s) = $src_node_id
MATCH (d) WHERE %(id_func)s(d) = $dst_node_id
WITH s,d
CREATE (s)-[r:%s { branch: $branch, branch_level: $branch_level, from: $at, to: null, status: $status }]->(d)
RETURN ID(r)
""" % str(rel_type).upper()
CREATE (s)-[r:%(rel_type)s { branch: $branch, branch_level: $branch_level, from: $at, to: null, status: $status }]->(d)
RETURN %(id_func)s(r)
""" % {"id_func": db.get_id_function_name(), "rel_type": str(rel_type).upper()}

at = Timestamp(at)

params = {
"src_node_id": element_id_to_id(src_node_id),
"dst_node_id": element_id_to_id(dst_node_id),
"src_node_id": db.to_database_id(src_node_id),
"dst_node_id": db.to_database_id(dst_node_id),
"at": at.to_string(),
"branch": branch_name or registry.default_branch,
"branch_level": branch_level or 1,
Expand All @@ -65,18 +65,16 @@ async def update_relationships_to(ids: list[str], db: InfrahubDatabase, to: Time
if not ids:
return None

list_matches = [f"id(r) = {element_id_to_id(id)}" for id in ids]

to = Timestamp(to)

query = f"""
query = """
MATCH ()-[r]->()
WHERE {' or '.join(list_matches)}
WHERE %(id_func)s(r) IN $ids
SET r.to = $to
RETURN ID(r)
"""
RETURN %(id_func)s(r)
""" % {"id_func": db.get_id_function_name()}

params = {"to": to.to_string()}
params = {"to": to.to_string(), "ids": [db.to_database_id(_id) for _id in ids]}

return await db.execute_query(query=query, params=params, name="update_relationships_to")

Expand All @@ -98,20 +96,17 @@ async def get_paths_between_nodes(
relationships_str = ":" + "|".join(relationships)

query = """
MATCH p = (s)-[%s*%s]-(d)
WHERE ID(s) = $source_id AND ID(d) = $destination_id
MATCH p = (s)-[%(rel)s*%(length_limit)s]-(d)
WHERE %(id_func)s(s) = $source_id AND %(id_func)s(d) = $destination_id
RETURN p
""" % (
relationships_str.upper(),
length_limit,
)
""" % {"rel": relationships_str.upper(), "length_limit": length_limit, "id_func": db.get_id_function_name()}

if print_query:
print(query)

params = {
"source_id": element_id_to_id(source_id),
"destination_id": element_id_to_id(destination_id),
"source_id": db.to_database_id(source_id),
"destination_id": db.to_database_id(destination_id),
}

return await db.execute_query(query=query, params=params, name="get_paths_between_nodes")
Expand Down Expand Up @@ -170,16 +165,6 @@ async def delete_all_nodes(db: InfrahubDatabase):
return await db.execute_query(query=query, params=params, name="delete_all_nodes")


def element_id_to_id(element_id: Union[str, int]) -> int:
if isinstance(element_id, int):
return element_id

if isinstance(element_id, str) and ":" not in element_id:
return int(element_id)

return int(element_id.split(":")[2])


def extract_field_filters(field_name: str, filters: dict) -> dict[str, Any]:
"""Extract the filters for a given field (attribute or relationship) from a filters dict."""
return {
Expand Down
13 changes: 13 additions & 0 deletions backend/infrahub/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,19 @@ def render_uuid_generation(self, node_label: str, node_attr: str, index: int = 1
"""
return generate_uuid_query

def get_id_function_name(self) -> str:
if self.db_type == DatabaseType.NEO4J:
return "elementId"
return "ID"

def to_database_id(self, db_id: str | int) -> str | int:
if self.db_type == DatabaseType.NEO4J:
return db_id
try:
return int(db_id)
except ValueError:
return db_id


async def create_database(driver: AsyncDriver, database_name: str) -> None:
default_db = driver.session()
Expand Down
Loading
Loading