Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Transforms and Checks are executed with the correct timeout #5471

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backend/infrahub/api/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ async def generate_artifact(

service = request.app.state.service
model = RequestArtifactDefinitionGenerate(
artifact_definition=artifact_definition.id, branch=branch_params.branch.name, limit=payload.nodes
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value,
branch=branch_params.branch.name,
limit=payload.nodes,
)

await service.workflow.submit_workflow(workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model})
2 changes: 2 additions & 0 deletions backend/infrahub/api/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def transform_python(
commit=repository.commit.value, # type: ignore[attr-defined]
branch=branch_params.branch.name,
transform_location=f"{transform.file_path.value}::{transform.class_name.value}",
timeout=transform.timeout.value,
data=data,
)

Expand Down Expand Up @@ -140,6 +141,7 @@ async def transform_jinja2(
commit=repository.commit.value, # type: ignore[attr-defined]
branch=branch_params.branch.name,
template_location=transform.template_path.value,
timeout=transform.timeout.value,
data=data,
)

Expand Down
14 changes: 9 additions & 5 deletions backend/infrahub/computed_attribute/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from typing import TYPE_CHECKING, Any

import ujson
from infrahub_sdk.protocols import CoreNode # noqa: TC002
from infrahub_sdk.protocols import (
CoreNode, # noqa: TC002
CoreTransformPython,
)
from prefect import flow
from prefect.automations import AutomationCore
from prefect.client.orchestration import get_client
Expand Down Expand Up @@ -88,17 +91,18 @@ async def process_transform(

for attribute_name, transform_attribute in transform_attributes.items():
transform = await service.client.get(
kind="CoreTransformPython",
kind=CoreTransformPython,
branch=branch_name,
id=transform_attribute.transform,
prefetch_relationships=True,
populate_store=True,
)

if not transform:
continue

repo_node = await service.client.get(
kind=transform.repository.peer.typename,
kind=str(transform.repository.peer.typename),
branch=branch_name,
id=transform.repository.peer.id,
raise_when_missing=True,
Expand All @@ -108,7 +112,7 @@ async def process_transform(
repository_id=transform.repository.peer.id,
name=transform.repository.peer.name.value,
service=service,
repository_kind=transform.repository.peer.typename,
repository_kind=str(transform.repository.peer.typename),
commit=repo_node.commit.value,
)

Expand All @@ -120,7 +124,7 @@ async def process_transform(
subscribers=[object_id],
)

transformed_data = await repo.execute_python_transform(
transformed_data = await repo.execute_python_transform.with_options(timeout_seconds=transform.timeout.value)(
branch_name=branch_name,
commit=repo_node.commit.value,
location=f"{transform.file_path.value}::{transform.class_name.value}",
Expand Down
22 changes: 12 additions & 10 deletions backend/infrahub/generators/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from infrahub_sdk.exceptions import ModuleImportError
from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.protocols import CoreGeneratorInstance
Expand All @@ -16,17 +18,17 @@
from infrahub.git.repository import get_initialized_repo
from infrahub.services import InfrahubServices, services
from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN, REQUEST_GENERATOR_RUN
from infrahub.workflows.utils import add_branch_tag
from infrahub.workflows.utils import add_tags


@flow(
name="generator-run",
flow_run_name="Run generator {model.generator_definition.definition_name} for {model.target_name}",
flow_run_name="Run generator {model.generator_definition.definition_name}",
)
async def run_generator(model: RequestGeneratorRun) -> None:
service = services.service

await add_branch_tag(branch_name=model.branch_name)
await add_tags(branches=[model.branch_name], nodes=[model.target_id])

repository = await get_initialized_repo(
repository_id=model.repository_id,
Expand Down Expand Up @@ -70,10 +72,10 @@ async def run_generator(model: RequestGeneratorRun) -> None:
)
await generator.run(identifier=generator_definition.name)
generator_instance.status.value = GeneratorInstanceStatus.READY.value
except ModuleImportError:
generator_instance.status.value = GeneratorInstanceStatus.ERROR.value
except Exception: # pylint: disable=broad-exception-caught
except (ModuleImportError, Exception): # pylint: disable=broad-exception-caught
generator_instance.status.value = GeneratorInstanceStatus.ERROR.value
await generator_instance.update(do_full_update=True)
raise

await generator_instance.update(do_full_update=True)

Expand Down Expand Up @@ -116,11 +118,11 @@ async def _define_instance(model: RequestGeneratorRun, service: InfrahubServices
return instance


@flow(name="generator_definition_run", flow_run_name="Run all generators")
@flow(name="generator-definition-run", flow_run_name="Run all generators")
async def run_generator_definition(branch: str) -> None:
service = services.service

await add_branch_tag(branch_name=branch)
await add_tags(branches=[branch])

generators = await service.client.filters(
kind=InfrahubKind.GENERATORDEFINITION, prefetch_relationships=True, populate_store=True, branch=branch
Expand Down Expand Up @@ -148,13 +150,13 @@ async def run_generator_definition(branch: str) -> None:


@flow(
name="request_generator_definition_run",
name="request-generator-definition-run",
flow_run_name="Execute generator {model.generator_definition.definition_name}",
)
async def request_generator_definition_run(model: RequestGeneratorDefinitionRun) -> None:
service = services.service

await add_branch_tag(branch_name=model.branch)
await add_tags(branches=[model.branch], nodes=[model.generator_definition.definition_id])

group = await service.client.get(
kind=InfrahubKind.GENERICGROUP,
Expand Down
22 changes: 12 additions & 10 deletions backend/infrahub/git/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

from infrahub_sdk.checks import InfrahubCheck
from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.schema import InfrahubRepositoryArtifactDefinitionConfig
from infrahub_sdk.schema.repository import InfrahubRepositoryArtifactDefinitionConfig
from infrahub_sdk.transforms import InfrahubTransform

from infrahub.git.models import RequestArtifactGenerate
Expand Down Expand Up @@ -771,7 +771,7 @@ async def import_python_transforms(
if str(self.directory_root) not in sys.path:
sys.path.append(str(self.directory_root))

transforms = []
transforms: list[TransformPythonInformation] = []
log.info(f"Found {len(config_file.python_transforms)} Python transforms in the repository")

for transform in config_file.python_transforms:
Expand Down Expand Up @@ -801,7 +801,7 @@ async def import_python_transforms(
transform_definition_in_graph = {
transform.name.value: transform
for transform in await self.sdk.filters(
kind=InfrahubKind.TRANSFORMPYTHON, branch=branch_name, repository__ids=[str(self.id)]
kind=CoreTransformPython, branch=branch_name, repository__ids=[str(self.id)]
)
}

Expand Down Expand Up @@ -1029,7 +1029,7 @@ async def create_python_transform(
create_payload = self.sdk.schema.generate_payload_create(
schema=schema,
data=data,
source=self.id,
source=str(self.id),
is_protected=True,
)
obj = await self.sdk.create(kind=CoreTransformPython, branch=branch_name, **create_payload)
Expand Down Expand Up @@ -1222,12 +1222,14 @@ async def artifact_generate(
)

if transformation.typename == InfrahubKind.TRANSFORMJINJA2:
artifact_content = await self.render_jinja2_template(
commit=commit, location=transformation.template_path.value, data=response
)
artifact_content = await self.render_jinja2_template.with_options(
timeout_seconds=transformation.timeout.value
)(commit=commit, location=transformation.template_path.value, data=response)
elif transformation.typename == InfrahubKind.TRANSFORMPYTHON:
transformation_location = f"{transformation.file_path.value}::{transformation.class_name.value}"
artifact_content = await self.execute_python_transform(
artifact_content = await self.execute_python_transform.with_options(
timeout_seconds=transformation.timeout.value
)(
branch_name=branch_name,
commit=commit,
location=transformation_location,
Expand Down Expand Up @@ -1271,11 +1273,11 @@ async def render_artifact(
)

if message.transform_type == InfrahubKind.TRANSFORMJINJA2:
artifact_content = await self.render_jinja2_template(
artifact_content = await self.render_jinja2_template.with_options(timeout_seconds=message.timeout)(
commit=message.commit, location=message.transform_location, data=response
)
elif message.transform_type == InfrahubKind.TRANSFORMPYTHON:
artifact_content = await self.execute_python_transform(
artifact_content = await self.execute_python_transform.with_options(timeout_seconds=message.timeout)(
branch_name=message.branch_name,
commit=message.commit,
location=message.transform_location,
Expand Down
5 changes: 3 additions & 2 deletions backend/infrahub/git/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
class RequestArtifactDefinitionGenerate(BaseModel):
"""Sent to trigger the generation of artifacts for a given branch."""

artifact_definition: str = Field(..., description="The unique ID of the Artifact Definition")
artifact_definition_id: str = Field(..., description="The unique ID of the Artifact Definition")
artifact_definition_name: str = Field(..., description="The name of the Artifact Definition")
branch: str = Field(..., description="The branch to target")
limit: list[str] = Field(
default_factory=list,
Expand All @@ -18,7 +19,7 @@ class RequestArtifactGenerate(BaseModel):
"""Runs to generate an artifact"""

artifact_name: str = Field(..., description="Name of the artifact")
artifact_definition: str = Field(..., description="The the ID of the artifact definition")
artifact_definition: str = Field(..., description="The ID of the artifact definition")
commit: str = Field(..., description="The commit to target")
content_type: str = Field(..., description="Content type of the artifact")
transform_type: str = Field(..., description="The type of transform associated with this artifact")
Expand Down
24 changes: 16 additions & 8 deletions backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta

from infrahub_sdk import InfrahubClient
from infrahub_sdk.protocols import CoreRepository
from infrahub_sdk.protocols import CoreArtifactDefinition, CoreRepository
from prefect import flow, task
from prefect.automations import AutomationCore
from prefect.cache_policies import NONE
Expand Down Expand Up @@ -243,10 +243,14 @@ async def generate_artifact_definition(branch: str) -> None:
service = services.service
await add_branch_tag(branch_name=branch)

artifact_definitions = await service.client.all(kind=InfrahubKind.ARTIFACTDEFINITION, branch=branch, include=["id"])
artifact_definitions = await service.client.all(kind=CoreArtifactDefinition, branch=branch, include=["id"])

for artifact_definition in artifact_definitions:
model = RequestArtifactDefinitionGenerate(branch=branch, artifact_definition=artifact_definition.id)
model = RequestArtifactDefinitionGenerate(
branch=branch,
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value,
)
await service.workflow.submit_workflow(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model}
)
Expand Down Expand Up @@ -277,15 +281,19 @@ async def generate_artifact(model: RequestArtifactGenerate) -> None:
log.exception("Failed to generate artifact")
artifact.status.value = "Error"
await artifact.save()
raise


@flow(name="request_artifact_definitions_generate", flow_run_name="Trigger Generation of Artifacts for ")
@flow(
name="request_artifact_definitions_generate",
flow_run_name="Trigger Generation of Artifacts for {model.artifact_definition_name}",
)
async def generate_request_artifact_definition(model: RequestArtifactDefinitionGenerate) -> None:
service = services.service
await add_tags(branches=[model.branch])
await add_tags(branches=[model.branch], nodes=[model.artifact_definition_id])

artifact_definition = await service.client.get(
kind=InfrahubKind.ARTIFACTDEFINITION, id=model.artifact_definition, branch=model.branch
kind=InfrahubKind.ARTIFACTDEFINITION, id=model.artifact_definition_id, branch=model.branch
)

await artifact_definition.targets.fetch()
Expand All @@ -295,7 +303,7 @@ async def generate_request_artifact_definition(model: RequestArtifactDefinitionG

existing_artifacts = await service.client.filters(
kind=InfrahubKind.ARTIFACT,
definition__ids=[model.artifact_definition],
definition__ids=[model.artifact_definition_id],
include=["object"],
branch=model.branch,
)
Expand Down Expand Up @@ -334,7 +342,7 @@ async def generate_request_artifact_definition(model: RequestArtifactDefinitionG
request_artifact_generate_model = RequestArtifactGenerate(
artifact_name=artifact_definition.name.value,
artifact_id=artifact_id,
artifact_definition=model.artifact_definition,
artifact_definition=model.artifact_definition_id,
commit=repository.commit.value,
content_type=artifact_definition.content_type.value,
transform_type=transform.typename,
Expand Down
12 changes: 10 additions & 2 deletions backend/infrahub/graphql/mutations/artifact_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ async def mutate_create(
artifact_definition, result = await super().mutate_create(info=info, data=data, branch=branch)

if context.service:
model = RequestArtifactDefinitionGenerate(branch=branch.name, artifact_definition=artifact_definition.id)
model = RequestArtifactDefinitionGenerate(
branch=branch.name,
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value, # type: ignore[attr-defined]
)
await context.service.workflow.submit_workflow(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model}
)
Expand All @@ -76,7 +80,11 @@ async def mutate_update(
artifact_definition, result = await super().mutate_update(info=info, data=data, branch=branch)

if context.service:
model = RequestArtifactDefinitionGenerate(branch=branch.name, artifact_definition=artifact_definition.id)
model = RequestArtifactDefinitionGenerate(
branch=branch.name,
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value, # type: ignore[attr-defined]
)
await context.service.workflow.submit_workflow(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ class CheckRepositoryUserCheck(InfrahubMessage):
variables: dict = Field(default_factory=dict, description="Input variables when running the check")
name: str = Field(..., description="The name of the check")
branch_diff: ProposedChangeBranchDiff = Field(..., description="The calculated diff between the two branches")
timeout: int = Field(..., description="The timeout for the check")
7 changes: 5 additions & 2 deletions backend/infrahub/message_bus/operations/check/repository.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from infrahub_sdk.protocols import CoreCheckDefinition
from infrahub_sdk.uuidt import UUIDT
from prefect import flow
from prefect.logging import get_run_logger
Expand All @@ -23,7 +24,7 @@ async def check_definition(message: messages.CheckRepositoryCheckDefinition, ser
log = get_run_logger()

definition = await service.client.get(
kind=InfrahubKind.CHECKDEFINITION, id=message.check_definition_id, branch=message.branch_name
kind=CoreCheckDefinition, id=message.check_definition_id, branch=message.branch_name
)
proposed_change = await service.client.get(kind=InfrahubKind.PROPOSEDCHANGE, id=message.proposed_change)
validator_execution_id = str(UUIDT())
Expand Down Expand Up @@ -87,6 +88,7 @@ async def check_definition(message: messages.CheckRepositoryCheckDefinition, ser
proposed_change=message.proposed_change,
variables=member.extract(params=definition.parameters.value),
branch_diff=message.branch_diff,
timeout=definition.timeout.value,
)
)

Expand All @@ -108,6 +110,7 @@ async def check_definition(message: messages.CheckRepositoryCheckDefinition, ser
check_definition_id=message.check_definition_id,
proposed_change=message.proposed_change,
branch_diff=message.branch_diff,
timeout=definition.timeout.value,
)
)

Expand Down Expand Up @@ -229,7 +232,7 @@ async def user_check(message: messages.CheckRepositoryUserCheck, service: Infrah
severity = "critical"
log_entries = ""
try:
check_run = await repo.execute_python_check(
check_run = await repo.execute_python_check.with_options(timeout_seconds=message.timeout)(
branch_name=message.branch_name,
location=message.file_path,
class_name=message.class_name,
Expand Down
1 change: 1 addition & 0 deletions backend/infrahub/transformations/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DEFAULT_TRANSFORM_TIMEOUT = 10
4 changes: 3 additions & 1 deletion backend/infrahub/transformations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class TransformPythonData(BaseModel):
data: dict = Field(..., description="Input data for the template")
branch: str = Field(..., description="The branch to target")
transform_location: str = Field(..., description="Location of the transform within the repository")
commit: str = Field(..., description="The commit id to use when rendering the template")
commit: str = Field(..., description="The commit id to use when generating the artifact")
timeout: int = Field(..., description="The timeout value to use when generating the artifact")


class TransformJinjaTemplateData(BaseModel):
Expand All @@ -23,3 +24,4 @@ class TransformJinjaTemplateData(BaseModel):
branch: str = Field(..., description="The branch to target")
template_location: str = Field(..., description="Location of the template within the repository")
commit: str = Field(..., description="The commit id to use when rendering the template")
timeout: int = Field(..., description="The timeout value to use when rendering the template")
Loading
Loading