Skip to content

Commit

Permalink
Ensure Transform and Checks are executed with the correct timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
dgarros committed Jan 20, 2025
1 parent 4d7d824 commit d3561ec
Show file tree
Hide file tree
Showing 24 changed files with 170 additions and 84 deletions.
5 changes: 4 additions & 1 deletion backend/infrahub/api/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ async def generate_artifact(

service = request.app.state.service
model = RequestArtifactDefinitionGenerate(
artifact_definition=artifact_definition.id, branch=branch_params.branch.name, limit=payload.nodes
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value,
branch=branch_params.branch.name,
limit=payload.nodes,
)

await service.workflow.submit_workflow(workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model})
2 changes: 2 additions & 0 deletions backend/infrahub/api/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def transform_python(
commit=repository.commit.value, # type: ignore[attr-defined]
branch=branch_params.branch.name,
transform_location=f"{transform.file_path.value}::{transform.class_name.value}",
timeout=transform.timeout.value,
data=data,
)

Expand Down Expand Up @@ -140,6 +141,7 @@ async def transform_jinja2(
commit=repository.commit.value, # type: ignore[attr-defined]
branch=branch_params.branch.name,
template_location=transform.template_path.value,
timeout=transform.timeout.value,
data=data,
)

Expand Down
14 changes: 9 additions & 5 deletions backend/infrahub/computed_attribute/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from typing import TYPE_CHECKING, Any

import ujson
from infrahub_sdk.protocols import CoreNode # noqa: TC002
from infrahub_sdk.protocols import (
CoreNode, # noqa: TC002
CoreTransformPython,
)
from prefect import flow
from prefect.automations import AutomationCore
from prefect.client.orchestration import get_client
Expand Down Expand Up @@ -88,17 +91,18 @@ async def process_transform(

for attribute_name, transform_attribute in transform_attributes.items():
transform = await service.client.get(
kind="CoreTransformPython",
kind=CoreTransformPython,
branch=branch_name,
id=transform_attribute.transform,
prefetch_relationships=True,
populate_store=True,
)

if not transform:
continue

repo_node = await service.client.get(
kind=transform.repository.peer.typename,
kind=str(transform.repository.peer.typename),
branch=branch_name,
id=transform.repository.peer.id,
raise_when_missing=True,
Expand All @@ -108,7 +112,7 @@ async def process_transform(
repository_id=transform.repository.peer.id,
name=transform.repository.peer.name.value,
service=service,
repository_kind=transform.repository.peer.typename,
repository_kind=str(transform.repository.peer.typename),
commit=repo_node.commit.value,
)

Expand All @@ -120,7 +124,7 @@ async def process_transform(
subscribers=[object_id],
)

transformed_data = await repo.execute_python_transform(
transformed_data = await repo.execute_python_transform.with_options(timeout_seconds=transform.timeout.value)(
branch_name=branch_name,
commit=repo_node.commit.value,
location=f"{transform.file_path.value}::{transform.class_name.value}",
Expand Down
22 changes: 12 additions & 10 deletions backend/infrahub/generators/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from infrahub_sdk.exceptions import ModuleImportError
from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.protocols import CoreGeneratorInstance
Expand All @@ -16,17 +18,17 @@
from infrahub.git.repository import get_initialized_repo
from infrahub.services import InfrahubServices, services
from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN, REQUEST_GENERATOR_RUN
from infrahub.workflows.utils import add_branch_tag
from infrahub.workflows.utils import add_tags


@flow(
name="generator-run",
flow_run_name="Run generator {model.generator_definition.definition_name} for {model.target_name}",
flow_run_name="Run generator {model.generator_definition.definition_name}",
)
async def run_generator(model: RequestGeneratorRun) -> None:
service = services.service

await add_branch_tag(branch_name=model.branch_name)
await add_tags(branches=[model.branch_name], nodes=[model.target_id])

repository = await get_initialized_repo(
repository_id=model.repository_id,
Expand Down Expand Up @@ -70,10 +72,10 @@ async def run_generator(model: RequestGeneratorRun) -> None:
)
await generator.run(identifier=generator_definition.name)
generator_instance.status.value = GeneratorInstanceStatus.READY.value
except ModuleImportError:
generator_instance.status.value = GeneratorInstanceStatus.ERROR.value
except Exception: # pylint: disable=broad-exception-caught
except (ModuleImportError, Exception): # pylint: disable=broad-exception-caught
generator_instance.status.value = GeneratorInstanceStatus.ERROR.value
await generator_instance.update(do_full_update=True)
raise

await generator_instance.update(do_full_update=True)

Expand Down Expand Up @@ -116,11 +118,11 @@ async def _define_instance(model: RequestGeneratorRun, service: InfrahubServices
return instance


@flow(name="generator_definition_run", flow_run_name="Run all generators")
@flow(name="generator-definition-run", flow_run_name="Run all generators")
async def run_generator_definition(branch: str) -> None:
service = services.service

await add_branch_tag(branch_name=branch)
await add_tags(branches=[branch])

generators = await service.client.filters(
kind=InfrahubKind.GENERATORDEFINITION, prefetch_relationships=True, populate_store=True, branch=branch
Expand Down Expand Up @@ -148,13 +150,13 @@ async def run_generator_definition(branch: str) -> None:


@flow(
name="request_generator_definition_run",
name="request-generator-definition-run",
flow_run_name="Execute generator {model.generator_definition.definition_name}",
)
async def request_generator_definition_run(model: RequestGeneratorDefinitionRun) -> None:
service = services.service

await add_branch_tag(branch_name=model.branch)
await add_tags(branches=[model.branch], nodes=[model.generator_definition.definition_id])

group = await service.client.get(
kind=InfrahubKind.GENERICGROUP,
Expand Down
22 changes: 12 additions & 10 deletions backend/infrahub/git/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

from infrahub_sdk.checks import InfrahubCheck
from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.schema import InfrahubRepositoryArtifactDefinitionConfig
from infrahub_sdk.schema.repository import InfrahubRepositoryArtifactDefinitionConfig
from infrahub_sdk.transforms import InfrahubTransform

from infrahub.git.models import RequestArtifactGenerate
Expand Down Expand Up @@ -771,7 +771,7 @@ async def import_python_transforms(
if str(self.directory_root) not in sys.path:
sys.path.append(str(self.directory_root))

transforms = []
transforms: list[TransformPythonInformation] = []
log.info(f"Found {len(config_file.python_transforms)} Python transforms in the repository")

for transform in config_file.python_transforms:
Expand Down Expand Up @@ -801,7 +801,7 @@ async def import_python_transforms(
transform_definition_in_graph = {
transform.name.value: transform
for transform in await self.sdk.filters(
kind=InfrahubKind.TRANSFORMPYTHON, branch=branch_name, repository__ids=[str(self.id)]
kind=CoreTransformPython, branch=branch_name, repository__ids=[str(self.id)]
)
}

Expand Down Expand Up @@ -1029,7 +1029,7 @@ async def create_python_transform(
create_payload = self.sdk.schema.generate_payload_create(
schema=schema,
data=data,
source=self.id,
source=str(self.id),
is_protected=True,
)
obj = await self.sdk.create(kind=CoreTransformPython, branch=branch_name, **create_payload)
Expand Down Expand Up @@ -1222,12 +1222,14 @@ async def artifact_generate(
)

if transformation.typename == InfrahubKind.TRANSFORMJINJA2:
artifact_content = await self.render_jinja2_template(
commit=commit, location=transformation.template_path.value, data=response
)
artifact_content = await self.render_jinja2_template.with_options(
timeout_seconds=transformation.timeout.value
)(commit=commit, location=transformation.template_path.value, data=response)
elif transformation.typename == InfrahubKind.TRANSFORMPYTHON:
transformation_location = f"{transformation.file_path.value}::{transformation.class_name.value}"
artifact_content = await self.execute_python_transform(
artifact_content = await self.execute_python_transform.with_options(
timeout_seconds=transformation.timeout.value
)(
branch_name=branch_name,
commit=commit,
location=transformation_location,
Expand Down Expand Up @@ -1271,11 +1273,11 @@ async def render_artifact(
)

if message.transform_type == InfrahubKind.TRANSFORMJINJA2:
artifact_content = await self.render_jinja2_template(
artifact_content = await self.render_jinja2_template.with_options(timeout_seconds=message.timeout)(
commit=message.commit, location=message.transform_location, data=response
)
elif message.transform_type == InfrahubKind.TRANSFORMPYTHON:
artifact_content = await self.execute_python_transform(
artifact_content = await self.execute_python_transform.with_options(timeout_seconds=message.timeout)(
branch_name=message.branch_name,
commit=message.commit,
location=message.transform_location,
Expand Down
5 changes: 3 additions & 2 deletions backend/infrahub/git/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
class RequestArtifactDefinitionGenerate(BaseModel):
"""Sent to trigger the generation of artifacts for a given branch."""

artifact_definition: str = Field(..., description="The unique ID of the Artifact Definition")
artifact_definition_id: str = Field(..., description="The unique ID of the Artifact Definition")
artifact_definition_name: str = Field(..., description="The name of the Artifact Definition")
branch: str = Field(..., description="The branch to target")
limit: list[str] = Field(
default_factory=list,
Expand All @@ -18,7 +19,7 @@ class RequestArtifactGenerate(BaseModel):
"""Runs to generate an artifact"""

artifact_name: str = Field(..., description="Name of the artifact")
artifact_definition: str = Field(..., description="The the ID of the artifact definition")
artifact_definition: str = Field(..., description="The ID of the artifact definition")
commit: str = Field(..., description="The commit to target")
content_type: str = Field(..., description="Content type of the artifact")
transform_type: str = Field(..., description="The type of transform associated with this artifact")
Expand Down
24 changes: 16 additions & 8 deletions backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta

from infrahub_sdk import InfrahubClient
from infrahub_sdk.protocols import CoreRepository
from infrahub_sdk.protocols import CoreArtifactDefinition, CoreRepository
from prefect import flow, task
from prefect.automations import AutomationCore
from prefect.cache_policies import NONE
Expand Down Expand Up @@ -243,10 +243,14 @@ async def generate_artifact_definition(branch: str) -> None:
service = services.service
await add_branch_tag(branch_name=branch)

artifact_definitions = await service.client.all(kind=InfrahubKind.ARTIFACTDEFINITION, branch=branch, include=["id"])
artifact_definitions = await service.client.all(kind=CoreArtifactDefinition, branch=branch, include=["id"])

for artifact_definition in artifact_definitions:
model = RequestArtifactDefinitionGenerate(branch=branch, artifact_definition=artifact_definition.id)
model = RequestArtifactDefinitionGenerate(
branch=branch,
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value,
)
await service.workflow.submit_workflow(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model}
)
Expand Down Expand Up @@ -277,15 +281,19 @@ async def generate_artifact(model: RequestArtifactGenerate) -> None:
log.exception("Failed to generate artifact")
artifact.status.value = "Error"
await artifact.save()
raise


@flow(name="request_artifact_definitions_generate", flow_run_name="Trigger Generation of Artifacts for ")
@flow(
name="request_artifact_definitions_generate",
flow_run_name="Trigger Generation of Artifacts for {model.artifact_definition_name}",
)
async def generate_request_artifact_definition(model: RequestArtifactDefinitionGenerate) -> None:
service = services.service
await add_tags(branches=[model.branch])
await add_tags(branches=[model.branch], nodes=[model.artifact_definition_id])

artifact_definition = await service.client.get(
kind=InfrahubKind.ARTIFACTDEFINITION, id=model.artifact_definition, branch=model.branch
kind=InfrahubKind.ARTIFACTDEFINITION, id=model.artifact_definition_id, branch=model.branch
)

await artifact_definition.targets.fetch()
Expand All @@ -295,7 +303,7 @@ async def generate_request_artifact_definition(model: RequestArtifactDefinitionG

existing_artifacts = await service.client.filters(
kind=InfrahubKind.ARTIFACT,
definition__ids=[model.artifact_definition],
definition__ids=[model.artifact_definition_id],
include=["object"],
branch=model.branch,
)
Expand Down Expand Up @@ -334,7 +342,7 @@ async def generate_request_artifact_definition(model: RequestArtifactDefinitionG
request_artifact_generate_model = RequestArtifactGenerate(
artifact_name=artifact_definition.name.value,
artifact_id=artifact_id,
artifact_definition=model.artifact_definition,
artifact_definition=model.artifact_definition_id,
commit=repository.commit.value,
content_type=artifact_definition.content_type.value,
transform_type=transform.typename,
Expand Down
12 changes: 10 additions & 2 deletions backend/infrahub/graphql/mutations/artifact_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ async def mutate_create(
artifact_definition, result = await super().mutate_create(info=info, data=data, branch=branch)

if context.service:
model = RequestArtifactDefinitionGenerate(branch=branch.name, artifact_definition=artifact_definition.id)
model = RequestArtifactDefinitionGenerate(
branch=branch.name,
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value, # type: ignore[attr-defined]
)
await context.service.workflow.submit_workflow(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model}
)
Expand All @@ -76,7 +80,11 @@ async def mutate_update(
artifact_definition, result = await super().mutate_update(info=info, data=data, branch=branch)

if context.service:
model = RequestArtifactDefinitionGenerate(branch=branch.name, artifact_definition=artifact_definition.id)
model = RequestArtifactDefinitionGenerate(
branch=branch.name,
artifact_definition_id=artifact_definition.id,
artifact_definition_name=artifact_definition.name.value, # type: ignore[attr-defined]
)
await context.service.workflow.submit_workflow(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ class CheckRepositoryUserCheck(InfrahubMessage):
variables: dict = Field(default_factory=dict, description="Input variables when running the check")
name: str = Field(..., description="The name of the check")
branch_diff: ProposedChangeBranchDiff = Field(..., description="The calculated diff between the two branches")
timeout: int = Field(..., description="The timeout for the check")
7 changes: 5 additions & 2 deletions backend/infrahub/message_bus/operations/check/repository.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from infrahub_sdk.protocols import CoreCheckDefinition
from infrahub_sdk.uuidt import UUIDT
from prefect import flow
from prefect.logging import get_run_logger
Expand All @@ -23,7 +24,7 @@ async def check_definition(message: messages.CheckRepositoryCheckDefinition, ser
log = get_run_logger()

definition = await service.client.get(
kind=InfrahubKind.CHECKDEFINITION, id=message.check_definition_id, branch=message.branch_name
kind=CoreCheckDefinition, id=message.check_definition_id, branch=message.branch_name
)
proposed_change = await service.client.get(kind=InfrahubKind.PROPOSEDCHANGE, id=message.proposed_change)
validator_execution_id = str(UUIDT())
Expand Down Expand Up @@ -87,6 +88,7 @@ async def check_definition(message: messages.CheckRepositoryCheckDefinition, ser
proposed_change=message.proposed_change,
variables=member.extract(params=definition.parameters.value),
branch_diff=message.branch_diff,
timeout=definition.timeout.value,
)
)

Expand All @@ -108,6 +110,7 @@ async def check_definition(message: messages.CheckRepositoryCheckDefinition, ser
check_definition_id=message.check_definition_id,
proposed_change=message.proposed_change,
branch_diff=message.branch_diff,
timeout=definition.timeout.value,
)
)

Expand Down Expand Up @@ -229,7 +232,7 @@ async def user_check(message: messages.CheckRepositoryUserCheck, service: Infrah
severity = "critical"
log_entries = ""
try:
check_run = await repo.execute_python_check(
check_run = await repo.execute_python_check.with_options(timeout_seconds=message.timeout)(
branch_name=message.branch_name,
location=message.file_path,
class_name=message.class_name,
Expand Down
1 change: 1 addition & 0 deletions backend/infrahub/transformations/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DEFAULT_TRANSFORM_TIMEOUT = 10
4 changes: 3 additions & 1 deletion backend/infrahub/transformations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class TransformPythonData(BaseModel):
data: dict = Field(..., description="Input data for the template")
branch: str = Field(..., description="The branch to target")
transform_location: str = Field(..., description="Location of the transform within the repository")
commit: str = Field(..., description="The commit id to use when rendering the template")
commit: str = Field(..., description="The commit id to use when generating the artifact")
timeout: int = Field(..., description="The timeout value to use when generating the artifact")


class TransformJinjaTemplateData(BaseModel):
Expand All @@ -23,3 +24,4 @@ class TransformJinjaTemplateData(BaseModel):
branch: str = Field(..., description="The branch to target")
template_location: str = Field(..., description="Location of the template within the repository")
commit: str = Field(..., description="The commit id to use when rendering the template")
timeout: int = Field(..., description="The timeout value to use when rendering the template")
Loading

0 comments on commit d3561ec

Please sign in to comment.