From 9644918365e94bd05a0edb9a8fac99939dd322de Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Wed, 8 Jan 2025 06:35:52 +0100 Subject: [PATCH] Extend InfrahubTask query to return multiple related nodes --- backend/infrahub/graphql/types/task.py | 18 ++++- backend/infrahub/task_manager/models.py | 36 +++++++++- backend/infrahub/task_manager/task.py | 14 ++-- .../tests/unit/graphql/queries/test_task.py | 69 +++++++++++-------- changelog/+task-deprecated.changed.md | 1 + changelog/+tasknode.added.md | 1 + 6 files changed, 99 insertions(+), 40 deletions(-) create mode 100644 changelog/+task-deprecated.changed.md create mode 100644 changelog/+tasknode.added.md diff --git a/backend/infrahub/graphql/types/task.py b/backend/infrahub/graphql/types/task.py index 7ac37441ef..78c581640f 100644 --- a/backend/infrahub/graphql/types/task.py +++ b/backend/infrahub/graphql/types/task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from graphene import Enum, Field, Float, List, ObjectType, String +from graphene import Enum, Field, Float, List, NonNull, ObjectType, String from graphene.types.generic import GenericScalar from prefect.client.schemas.objects import StateType @@ -28,9 +28,21 @@ class Task(ObjectType): start_time = String(required=False) +class TaskRelatedNode(ObjectType): + id = String(required=True) + kind = String(required=True) + + class TaskNode(Task): - related_node = String(required=False) - related_node_kind = String(required=False) + related_node = String( + required=False, + deprecation_reason="This field is deprecated and it will be removed in a future release, use related_nodes instead", + ) + related_node_kind = String( + required=False, + deprecation_reason="This field is deprecated and it will be removed in a future release, use related_nodes instead", + ) + related_nodes = List(of_type=NonNull(TaskRelatedNode)) logs = Field(TaskLogEdge) diff --git a/backend/infrahub/task_manager/models.py b/backend/infrahub/task_manager/models.py index af8d8eef75..824b5defce 100644 --- a/backend/infrahub/task_manager/models.py +++ b/backend/infrahub/task_manager/models.py @@ -8,12 +8,42 @@ from .constants import LOG_LEVEL_MAPPING +class RelatedNodeInfo(BaseModel): + id: str + kind: str | None = None + + class RelatedNodesInfo(BaseModel): - id: dict[UUID, str] = Field(default_factory=dict) - kind: dict[UUID, str | None] = Field(default_factory=dict) + flows: dict[UUID, dict[str, RelatedNodeInfo]] = Field(default_factory=lambda: defaultdict(dict)) + nodes: dict[str, RelatedNodeInfo] = Field(default_factory=dict) + + def add_nodes(self, flow_id: UUID, node_ids: list[str]) -> None: + for node_id in node_ids: + self.add_node(flow_id=flow_id, node_id=node_id) + + def add_node(self, flow_id: UUID, node_id: str) -> None: + if node_id not in self.nodes: + node = RelatedNodeInfo(id=node_id) + self.nodes[node_id] = node + self.flows[flow_id][node_id] = self.nodes[node_id] + + def get_related_nodes(self, flow_id: UUID) -> list[RelatedNodeInfo]: + if flow_id not in self.flows or len(self.flows[flow_id].keys()) == 0: + return [] + return list(self.flows[flow_id].values()) + + def get_related_nodes_as_dict(self, flow_id: UUID) -> list[dict[str, str | None]]: + if flow_id not in self.flows or len(self.flows[flow_id].keys()) == 0: + return [] + return [item.model_dump() for item in list(self.flows[flow_id].values())] + + def get_first_related_node(self, flow_id: UUID) -> RelatedNodeInfo | None: + if nodes := self.get_related_nodes(flow_id=flow_id): + return nodes[0] + return None def get_unique_related_node_ids(self) -> list[str]: - return list(set(list(self.id.values()))) + return list(self.nodes.keys()) class FlowLogs(BaseModel): diff --git a/backend/infrahub/task_manager/task.py b/backend/infrahub/task_manager/task.py index a259dd29b3..207b47557c 100644 --- a/backend/infrahub/task_manager/task.py +++ b/backend/infrahub/task_manager/task.py @@ -69,15 +69,16 @@ async def _get_related_nodes(cls, db: InfrahubDatabase, flows: list[FlowRun]) -> ] if not related_node_ids: continue - related_nodes.id[flow.id] = related_node_ids[0] + related_nodes.add_nodes(flow_id=flow.id, node_ids=related_node_ids) if unique_related_node_ids := related_nodes.get_unique_related_node_ids(): query = await NodeGetKindQuery.init(db=db, ids=unique_related_node_ids) await query.execute(db=db) unique_related_node_ids_kind = await query.get_node_kind_map() - for flow_id, node_id in related_nodes.id.items(): - related_nodes.kind[flow_id] = unique_related_node_ids_kind.get(node_id, None) + for node_id, node_kind in unique_related_node_ids_kind.items(): + if node_id in related_nodes.nodes: + related_nodes.nodes[node_id].kind = node_kind return related_nodes @@ -238,6 +239,8 @@ async def query( if log_fields: logs = logs_flow.to_graphql(flow_id=flow.id) + related_node = related_nodes_info.get_first_related_node(flow_id=flow.id) + nodes.append( { "node": { @@ -251,8 +254,9 @@ async def query( "branch": await cls._extract_branch_name(flow=flow), "tags": flow.tags, "workflow": workflow_names.get(flow.flow_id, None), - "related_node": related_nodes_info.id.get(flow.id, None), - "related_node_kind": related_nodes_info.kind.get(flow.id, None), + "related_node": related_node.id if related_node else None, + "related_node_kind": related_node.kind if related_node else None, + "related_nodes": related_nodes_info.get_related_nodes_as_dict(flow_id=flow.id), "created_at": flow.created.to_iso8601_string(), # type: ignore "updated_at": flow.updated.to_iso8601_string(), # type: ignore "start_time": flow.start_time.to_iso8601_string() if flow.start_time else None, diff --git a/backend/tests/unit/graphql/queries/test_task.py b/backend/tests/unit/graphql/queries/test_task.py index f5e4ec3d8d..df36e688b5 100644 --- a/backend/tests/unit/graphql/queries/test_task.py +++ b/backend/tests/unit/graphql/queries/test_task.py @@ -18,33 +18,6 @@ from infrahub.workflows.constants import TAG_NAMESPACE, WorkflowTag from tests.helpers.graphql import graphql -CREATE_TASK = """ -mutation CreateTask( - $conclusion: TaskConclusion!, - $title: String!, - $task_id: UUID, - $created_by: String, - $related_node: String!, - $logs: [RelatedTaskLogCreateInput] - ) { - InfrahubTaskCreate( - data: { - id: $task_id, - created_by: $created_by, - title: $title, - conclusion: $conclusion, - related_node: $related_node, - logs: $logs - } - ) { - ok - object { - id - } - } -} -""" - QUERY_TASK = """ query TaskQuery( $related_nodes: [String] @@ -64,6 +37,10 @@ parameters related_node related_node_kind + related_nodes { + id + kind + } title updated_at start_time @@ -86,6 +63,10 @@ id related_node related_node_kind + related_nodes { + id + kind + } title updated_at logs { @@ -157,7 +138,7 @@ async def delete_flow_runs(prefect_client: PrefectClient): @pytest.fixture -async def flow_runs_data(prefect_client: PrefectClient, tag_blue, account_bob): +async def flow_runs_data(prefect_client: PrefectClient, tag_blue, tag_red, account_bob): branch1_tag = WorkflowTag.BRANCH.render(identifier="branch1") db_tag = WorkflowTag.DATABASE_CHANGE.render() items = [ @@ -200,7 +181,13 @@ async def flow_runs_data(prefect_client: PrefectClient, tag_blue, account_bob): flow=dummy_flow_broken, name="dummy-completed-account-br1-db", parameters={"firstname": "xxxx", "lastname": "zzzzz"}, - tags=[TAG_NAMESPACE, WorkflowTag.RELATED_NODE.render(identifier=account_bob.get_id()), branch1_tag, db_tag], + tags=[ + TAG_NAMESPACE, + WorkflowTag.RELATED_NODE.render(identifier=account_bob.get_id()), + WorkflowTag.RELATED_NODE.render(identifier=tag_red.get_id()), + branch1_tag, + db_tag, + ], state=State(type="COMPLETED"), ), await prefect_client.create_flow_run( @@ -489,6 +476,7 @@ async def test_task_query_filter_node( default_branch: Branch, register_core_models_schema: None, tag_blue, + tag_red, account_bob, account_bill, flow_runs_data, @@ -515,6 +503,12 @@ async def test_task_query_filter_node( "parameters": {"firstname": "xxxx", "lastname": "yyy"}, "related_node": tag_blue.get_id(), "related_node_kind": "BuiltinTag", + "related_nodes": [ + { + "id": tag_blue.get_id(), + "kind": "BuiltinTag", + }, + ], "title": flow.name, "updated_at": flow.updated.to_iso8601_string(), "start_time": None, @@ -544,12 +538,23 @@ async def test_task_query_filter_node( "tags": [ "infrahub.app", f"infrahub.app/node/{account_bob.get_id()}", + f"infrahub.app/node/{tag_red.get_id()}", "infrahub.app/branch/branch1", "infrahub.app/database-change", ], "parameters": {"firstname": "xxxx", "lastname": "zzzzz"}, "related_node": account_bob.get_id(), "related_node_kind": "CoreAccount", + "related_nodes": [ + { + "id": account_bob.get_id(), + "kind": "CoreAccount", + }, + { + "id": tag_red.get_id(), + "kind": "BuiltinTag", + }, + ], "title": flow.name, "updated_at": flow.updated.to_iso8601_string(), "start_time": None, @@ -688,6 +693,12 @@ async def test_task_query_progress( "parameters": {"firstname": "xxxx", "lastname": "yyy"}, "related_node": tag_red.get_id(), "related_node_kind": "BuiltinTag", + "related_nodes": [ + { + "id": tag_red.get_id(), + "kind": "BuiltinTag", + }, + ], "title": flow.name, "updated_at": flow.updated.to_iso8601_string(), "start_time": flow.start_time.to_iso8601_string(), diff --git a/changelog/+task-deprecated.changed.md b/changelog/+task-deprecated.changed.md new file mode 100644 index 0000000000..24ff554d3d --- /dev/null +++ b/changelog/+task-deprecated.changed.md @@ -0,0 +1 @@ +The fields `related_node` and `related_node_kind` on the GraphQL query `InfrahubTask` have been deprecated, please use `related_nodes` instead. \ No newline at end of file diff --git a/changelog/+tasknode.added.md b/changelog/+tasknode.added.md new file mode 100644 index 0000000000..d66ab52cb6 --- /dev/null +++ b/changelog/+tasknode.added.md @@ -0,0 +1 @@ +The query InfrahubTask in GraphQL, introduced a new `related_nodes` field to retrieve multiple related nodes per task. \ No newline at end of file