Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend InfrahubTask query to return multiple related nodes #5378

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions backend/infrahub/graphql/types/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from graphene import Enum, Field, Float, List, ObjectType, String
from graphene import Enum, Field, Float, List, NonNull, ObjectType, String
from graphene.types.generic import GenericScalar
from prefect.client.schemas.objects import StateType

Expand Down Expand Up @@ -28,9 +28,21 @@ class Task(ObjectType):
start_time = String(required=False)


class TaskRelatedNode(ObjectType):
id = String(required=True)
kind = String(required=True)


class TaskNode(Task):
related_node = String(required=False)
related_node_kind = String(required=False)
related_node = String(
required=False,
deprecation_reason="This field is deprecated and it will be removed in a future release, use related_nodes instead",
)
related_node_kind = String(
required=False,
deprecation_reason="This field is deprecated and it will be removed in a future release, use related_nodes instead",
)
related_nodes = List(of_type=NonNull(TaskRelatedNode))
logs = Field(TaskLogEdge)


Expand Down
36 changes: 33 additions & 3 deletions backend/infrahub/task_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,42 @@
from .constants import LOG_LEVEL_MAPPING


class RelatedNodeInfo(BaseModel):
id: str
kind: str | None = None


class RelatedNodesInfo(BaseModel):
id: dict[UUID, str] = Field(default_factory=dict)
kind: dict[UUID, str | None] = Field(default_factory=dict)
flows: dict[UUID, dict[str, RelatedNodeInfo]] = Field(default_factory=lambda: defaultdict(dict))
nodes: dict[str, RelatedNodeInfo] = Field(default_factory=dict)

def add_nodes(self, flow_id: UUID, node_ids: list[str]) -> None:
for node_id in node_ids:
self.add_node(flow_id=flow_id, node_id=node_id)

def add_node(self, flow_id: UUID, node_id: str) -> None:
if node_id not in self.nodes:
node = RelatedNodeInfo(id=node_id)
self.nodes[node_id] = node
self.flows[flow_id][node_id] = self.nodes[node_id]

def get_related_nodes(self, flow_id: UUID) -> list[RelatedNodeInfo]:
if flow_id not in self.flows or len(self.flows[flow_id].keys()) == 0:
return []
return list(self.flows[flow_id].values())

def get_related_nodes_as_dict(self, flow_id: UUID) -> list[dict[str, str | None]]:
if flow_id not in self.flows or len(self.flows[flow_id].keys()) == 0:
return []
return [item.model_dump() for item in list(self.flows[flow_id].values())]

def get_first_related_node(self, flow_id: UUID) -> RelatedNodeInfo | None:
if nodes := self.get_related_nodes(flow_id=flow_id):
return nodes[0]
return None

def get_unique_related_node_ids(self) -> list[str]:
return list(set(list(self.id.values())))
return list(self.nodes.keys())


class FlowLogs(BaseModel):
Expand Down
14 changes: 9 additions & 5 deletions backend/infrahub/task_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,16 @@ async def _get_related_nodes(cls, db: InfrahubDatabase, flows: list[FlowRun]) ->
]
if not related_node_ids:
continue
related_nodes.id[flow.id] = related_node_ids[0]
related_nodes.add_nodes(flow_id=flow.id, node_ids=related_node_ids)

if unique_related_node_ids := related_nodes.get_unique_related_node_ids():
query = await NodeGetKindQuery.init(db=db, ids=unique_related_node_ids)
await query.execute(db=db)
unique_related_node_ids_kind = await query.get_node_kind_map()

for flow_id, node_id in related_nodes.id.items():
related_nodes.kind[flow_id] = unique_related_node_ids_kind.get(node_id, None)
for node_id, node_kind in unique_related_node_ids_kind.items():
if node_id in related_nodes.nodes:
related_nodes.nodes[node_id].kind = node_kind

return related_nodes

Expand Down Expand Up @@ -238,6 +239,8 @@ async def query(
if log_fields:
logs = logs_flow.to_graphql(flow_id=flow.id)

related_node = related_nodes_info.get_first_related_node(flow_id=flow.id)

nodes.append(
{
"node": {
Expand All @@ -251,8 +254,9 @@ async def query(
"branch": await cls._extract_branch_name(flow=flow),
"tags": flow.tags,
"workflow": workflow_names.get(flow.flow_id, None),
"related_node": related_nodes_info.id.get(flow.id, None),
"related_node_kind": related_nodes_info.kind.get(flow.id, None),
"related_node": related_node.id if related_node else None,
"related_node_kind": related_node.kind if related_node else None,
"related_nodes": related_nodes_info.get_related_nodes_as_dict(flow_id=flow.id),
"created_at": flow.created.to_iso8601_string(), # type: ignore
"updated_at": flow.updated.to_iso8601_string(), # type: ignore
"start_time": flow.start_time.to_iso8601_string() if flow.start_time else None,
Expand Down
69 changes: 40 additions & 29 deletions backend/tests/unit/graphql/queries/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,6 @@
from infrahub.workflows.constants import TAG_NAMESPACE, WorkflowTag
from tests.helpers.graphql import graphql

CREATE_TASK = """
mutation CreateTask(
$conclusion: TaskConclusion!,
$title: String!,
$task_id: UUID,
$created_by: String,
$related_node: String!,
$logs: [RelatedTaskLogCreateInput]
) {
InfrahubTaskCreate(
data: {
id: $task_id,
created_by: $created_by,
title: $title,
conclusion: $conclusion,
related_node: $related_node,
logs: $logs
}
) {
ok
object {
id
}
}
}
"""

QUERY_TASK = """
query TaskQuery(
$related_nodes: [String]
Expand All @@ -64,6 +37,10 @@
parameters
related_node
related_node_kind
related_nodes {
id
kind
}
title
updated_at
start_time
Expand All @@ -86,6 +63,10 @@
id
related_node
related_node_kind
related_nodes {
id
kind
}
title
updated_at
logs {
Expand Down Expand Up @@ -157,7 +138,7 @@ async def delete_flow_runs(prefect_client: PrefectClient):


@pytest.fixture
async def flow_runs_data(prefect_client: PrefectClient, tag_blue, account_bob):
async def flow_runs_data(prefect_client: PrefectClient, tag_blue, tag_red, account_bob):
branch1_tag = WorkflowTag.BRANCH.render(identifier="branch1")
db_tag = WorkflowTag.DATABASE_CHANGE.render()
items = [
Expand Down Expand Up @@ -200,7 +181,13 @@ async def flow_runs_data(prefect_client: PrefectClient, tag_blue, account_bob):
flow=dummy_flow_broken,
name="dummy-completed-account-br1-db",
parameters={"firstname": "xxxx", "lastname": "zzzzz"},
tags=[TAG_NAMESPACE, WorkflowTag.RELATED_NODE.render(identifier=account_bob.get_id()), branch1_tag, db_tag],
tags=[
TAG_NAMESPACE,
WorkflowTag.RELATED_NODE.render(identifier=account_bob.get_id()),
WorkflowTag.RELATED_NODE.render(identifier=tag_red.get_id()),
branch1_tag,
db_tag,
],
state=State(type="COMPLETED"),
),
await prefect_client.create_flow_run(
Expand Down Expand Up @@ -489,6 +476,7 @@ async def test_task_query_filter_node(
default_branch: Branch,
register_core_models_schema: None,
tag_blue,
tag_red,
account_bob,
account_bill,
flow_runs_data,
Expand All @@ -515,6 +503,12 @@ async def test_task_query_filter_node(
"parameters": {"firstname": "xxxx", "lastname": "yyy"},
"related_node": tag_blue.get_id(),
"related_node_kind": "BuiltinTag",
"related_nodes": [
{
"id": tag_blue.get_id(),
"kind": "BuiltinTag",
},
],
"title": flow.name,
"updated_at": flow.updated.to_iso8601_string(),
"start_time": None,
Expand Down Expand Up @@ -544,12 +538,23 @@ async def test_task_query_filter_node(
"tags": [
"infrahub.app",
f"infrahub.app/node/{account_bob.get_id()}",
f"infrahub.app/node/{tag_red.get_id()}",
"infrahub.app/branch/branch1",
"infrahub.app/database-change",
],
"parameters": {"firstname": "xxxx", "lastname": "zzzzz"},
"related_node": account_bob.get_id(),
"related_node_kind": "CoreAccount",
"related_nodes": [
{
"id": account_bob.get_id(),
"kind": "CoreAccount",
},
{
"id": tag_red.get_id(),
"kind": "BuiltinTag",
},
],
"title": flow.name,
"updated_at": flow.updated.to_iso8601_string(),
"start_time": None,
Expand Down Expand Up @@ -688,6 +693,12 @@ async def test_task_query_progress(
"parameters": {"firstname": "xxxx", "lastname": "yyy"},
"related_node": tag_red.get_id(),
"related_node_kind": "BuiltinTag",
"related_nodes": [
{
"id": tag_red.get_id(),
"kind": "BuiltinTag",
},
],
"title": flow.name,
"updated_at": flow.updated.to_iso8601_string(),
"start_time": flow.start_time.to_iso8601_string(),
Expand Down
1 change: 1 addition & 0 deletions changelog/+task-deprecated.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The fields `related_node` and `related_node_kind` on the GraphQL query `InfrahubTask` have been deprecated, please use `related_nodes` instead.
1 change: 1 addition & 0 deletions changelog/+tasknode.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The query InfrahubTask in GraphQL, introduced a new `related_nodes` field to retrieve multiple related nodes per task.
Loading