Skip to content

Commit

Permalink
fix(backend): ignore prefect typing errors
Browse files Browse the repository at this point in the history
Prefect 3.1.15 introduced a change that breaks mypy with
error: Incompatible types in "await" (actual type "State[Coroutine[Any, Any, None]]", expected type
"Awaitable[Any]")  [misc]

Signed-off-by: Fatih Acar <fatih@opsmill.com>
  • Loading branch information
fatih-acar committed Mar 3, 2025
1 parent 2f38211 commit 9c9bbe3
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 35 deletions.
4 changes: 2 additions & 2 deletions backend/infrahub/computed_attribute/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def process_transform(
service=service,
repository_kind=str(transform.repository.peer.typename),
commit=repo_node.commit.value,
)
) # type: ignore[misc]

data = await service.client.query_gql_query(
name=transform.query.peer.name.value,
Expand All @@ -131,7 +131,7 @@ async def process_transform(
location=f"{transform.file_path.value}::{transform.class_name.value}",
data=data,
client=service.client,
)
) # type: ignore[misc]

await service.client.execute_graphql(
query=UPDATE_ATTRIBUTE,
Expand Down
30 changes: 15 additions & 15 deletions backend/infrahub/git/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,27 +171,27 @@ async def import_objects_from_files(
self.create_commit_worktree(commit)
await self._update_sync_status(branch_name=infrahub_branch_name, status=RepositorySyncStatus.SYNCING)

config_file = await self.get_repository_config(branch_name=infrahub_branch_name, commit=commit)
config_file = await self.get_repository_config(branch_name=infrahub_branch_name, commit=commit) # type: ignore[misc]
sync_status = RepositorySyncStatus.IN_SYNC if config_file else RepositorySyncStatus.ERROR_IMPORT
error: Exception | None = None

try:
if config_file:
await self.import_schema_files(branch_name=infrahub_branch_name, commit=commit, config_file=config_file)
await self.import_schema_files(branch_name=infrahub_branch_name, commit=commit, config_file=config_file) # type: ignore[misc]

await self.import_all_graphql_query(
branch_name=infrahub_branch_name, commit=commit, config_file=config_file
)
) # type: ignore[misc]

await self.import_all_python_files( # type: ignore[call-overload]
branch_name=infrahub_branch_name, commit=commit, config_file=config_file
)
) # type: ignore[misc]
await self.import_jinja2_transforms(
branch_name=infrahub_branch_name, commit=commit, config_file=config_file
)
) # type: ignore[misc]
await self.import_artifact_definitions(
branch_name=infrahub_branch_name, commit=commit, config_file=config_file
)
) # type: ignore[misc]

except Exception as exc:
sync_status = RepositorySyncStatus.ERROR_IMPORT
Expand Down Expand Up @@ -636,7 +636,7 @@ async def import_python_check_definitions(
module=module,
file_path=file_info.relative_path_file,
check_definition=check,
)
) # type: ignore[misc]
)

local_check_definitions = {check.name: check for check in checks}
Expand Down Expand Up @@ -797,7 +797,7 @@ async def import_python_transforms(
module=module,
file_path=file_info.relative_path_file,
transform=transform,
)
) # type: ignore[misc]
)

local_transform_definitions = {transform.name: transform for transform in transforms}
Expand Down Expand Up @@ -1070,9 +1070,9 @@ async def import_all_python_files(
) -> None:
await add_tags(branches=[branch_name], nodes=[str(self.id)])

await self.import_python_check_definitions(branch_name=branch_name, commit=commit, config_file=config_file)
await self.import_python_transforms(branch_name=branch_name, commit=commit, config_file=config_file)
await self.import_generator_definitions(branch_name=branch_name, commit=commit, config_file=config_file)
await self.import_python_check_definitions(branch_name=branch_name, commit=commit, config_file=config_file) # type: ignore[misc]
await self.import_python_transforms(branch_name=branch_name, commit=commit, config_file=config_file) # type: ignore[misc]
await self.import_generator_definitions(branch_name=branch_name, commit=commit, config_file=config_file) # type: ignore[misc]

@task(name="jinja2-template-render", task_run_name="Render Jinja2 template", cache_policy=NONE) # type: ignore[arg-type]
async def render_jinja2_template(self, commit: str, location: str, data: dict) -> str:
Expand Down Expand Up @@ -1227,7 +1227,7 @@ async def artifact_generate(
if transformation.typename == InfrahubKind.TRANSFORMJINJA2:
artifact_content = await self.render_jinja2_template.with_options(
timeout_seconds=transformation.timeout.value
)(commit=commit, location=transformation.template_path.value, data=response)
)(commit=commit, location=transformation.template_path.value, data=response) # type: ignore[misc]
elif transformation.typename == InfrahubKind.TRANSFORMPYTHON:
transformation_location = f"{transformation.file_path.value}::{transformation.class_name.value}"
artifact_content = await self.execute_python_transform.with_options(
Expand All @@ -1238,7 +1238,7 @@ async def artifact_generate(
location=transformation_location,
data=response,
client=self.sdk,
)
) # type: ignore[misc]

if definition.content_type.value == ContentType.APPLICATION_JSON.value and isinstance(artifact_content, dict):
artifact_content_str = ujson.dumps(artifact_content, indent=2)
Expand Down Expand Up @@ -1289,15 +1289,15 @@ async def render_artifact(
if message.transform_type == InfrahubKind.TRANSFORMJINJA2:
artifact_content = await self.render_jinja2_template.with_options(timeout_seconds=message.timeout)(
commit=message.commit, location=message.transform_location, data=response
)
) # type: ignore[misc]
elif message.transform_type == InfrahubKind.TRANSFORMPYTHON:
artifact_content = await self.execute_python_transform.with_options(timeout_seconds=message.timeout)(
branch_name=message.branch_name,
commit=message.commit,
location=message.transform_location,
data=response,
client=self.sdk,
)
) # type: ignore[misc]

if message.content_type == ContentType.APPLICATION_JSON.value and isinstance(artifact_content, dict):
artifact_content_str = ujson.dumps(artifact_content, indent=2)
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ async def run_user_check(model: UserCheckData, service: InfrahubServices) -> Val
client=service.client,
commit=model.commit,
params=model.variables,
)
) # type: ignore[misc]
if check_run.passed:
conclusion = ValidatorConclusion.SUCCESS
severity = "info"
Expand Down
4 changes: 2 additions & 2 deletions backend/infrahub/transformations/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def transform_python(message: TransformPythonData, service: InfrahubServic
location=message.transform_location,
data=message.data,
client=service.client,
)
) # type: ignore[misc]

return transformed_data

Expand All @@ -49,6 +49,6 @@ async def transform_render_jinja2_template(message: TransformJinjaTemplateData,

rendered_template = await repo.render_jinja2_template.with_options(timeout_seconds=message.timeout)(
commit=message.commit, location=message.template_location, data={"data": message.data}
)
) # type: ignore[misc]

return rendered_template
2 changes: 1 addition & 1 deletion backend/infrahub/webhook/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ async def _prepare_payload(self, data: dict[str, Any], context: EventContext, se
location=f"{self.transform_file}::{self.transform_class}",
data={"data": data, **context.model_dump()},
client=service.client,
)
) # type: ignore[misc]

@classmethod
def from_object(cls, obj: CoreCustomWebhook, transform: CoreTransformPython) -> Self:
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/webhook/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def configure_webhook_all(service: InfrahubServices) -> None:
triggers=triggers,
trigger_type=TriggerType.WEBHOOK,
deprecated_triggers=[AUTOMATION_NAME_RUN],
)
) # type: ignore[misc]

log.info(f"{len(triggers)} Webhooks automation configuration completed")

Expand Down
26 changes: 13 additions & 13 deletions backend/tests/integration/git/test_git_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,39 +117,39 @@ async def repo(

async def test_import_schema_files(self, db: InfrahubDatabase, client: InfrahubClient, repo: InfrahubRepository):
commit = repo.get_commit_value(branch_name="main")
config_file = await repo.get_repository_config(branch_name="main", commit=commit)
config_file = await repo.get_repository_config(branch_name="main", commit=commit) # type: ignore[misc]
assert config_file
await repo.import_schema_files(branch_name="main", commit=commit, config_file=config_file)
await repo.import_schema_files(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]

assert await client.schema.get(kind="DemoEdgeFabric", refresh=True)

async def test_import_schema_files_from_directory(
self, db: InfrahubDatabase, client: InfrahubClient, repo: InfrahubRepository
):
commit = repo.get_commit_value(branch_name="main")
config_file = await repo.get_repository_config(branch_name="main", commit=commit)
config_file = await repo.get_repository_config(branch_name="main", commit=commit) # type: ignore[misc]
assert config_file

config_file.schemas = [Path("schemas")]
await repo.import_schema_files(branch_name="main", commit=commit, config_file=config_file)
await repo.import_schema_files(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]

assert await client.schema.get(kind="DemoEdgeFabric", refresh=True)

async def test_import_all_graphql_query(
self, db: InfrahubDatabase, client: InfrahubClient, repo: InfrahubRepository
):
commit = repo.get_commit_value(branch_name="main")
config_file = await repo.get_repository_config(branch_name="main", commit=commit)
config_file = await repo.get_repository_config(branch_name="main", commit=commit) # type: ignore[misc]
assert config_file

await repo.import_all_graphql_query(branch_name="main", commit=commit, config_file=config_file)
await repo.import_all_graphql_query(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]

queries = await client.all(kind=CoreGraphQLQuery)
assert len(queries) == 5

# Validate if the function is idempotent, another import just after the first one shouldn't change anything
nbr_relationships_before = await count_relationships(db=db)
await repo.import_all_graphql_query(branch_name="main", commit=commit, config_file=config_file)
await repo.import_all_graphql_query(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]
assert await count_relationships(db=db) == nbr_relationships_before

# 1. Modify an object to validate if its being properly updated
Expand All @@ -167,7 +167,7 @@ async def test_import_all_graphql_query(
)
await obj.save(db=db)

await repo.import_all_graphql_query(branch_name="main", commit=commit, config_file=config_file)
await repo.import_all_graphql_query(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]

modified_query = await client.get(kind=CoreGraphQLQuery, id=queries[0].id)
assert modified_query.query.value == value_before_change
Expand All @@ -179,7 +179,7 @@ async def test_import_all_python_files(
self, db: InfrahubDatabase, client: InfrahubClient, repo: InfrahubRepository, query_99
):
commit = repo.get_commit_value(branch_name="main")
config_file = await repo.get_repository_config(branch_name="main", commit=commit)
config_file = await repo.get_repository_config(branch_name="main", commit=commit) # type: ignore[misc]
assert config_file

await repo.import_all_python_files(branch_name="main", commit=commit, config_file=config_file) # type: ignore[call-overload]
Expand Down Expand Up @@ -256,16 +256,16 @@ async def test_import_all_yaml_files(
self, db: InfrahubDatabase, client: InfrahubClient, repo: InfrahubRepository, query_99
):
commit = repo.get_commit_value(branch_name="main")
config_file = await repo.get_repository_config(branch_name="main", commit=commit)
config_file = await repo.get_repository_config(branch_name="main", commit=commit) # type: ignore[misc]
assert config_file
await repo.import_jinja2_transforms(branch_name="main", commit=commit, config_file=config_file)
await repo.import_jinja2_transforms(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]

rfiles = await client.all(kind=CoreTransformJinja2)
assert len(rfiles) == 2

# Validate if the function is idempotent, another import just after the first one shouldn't change anything
nbr_relationships_before = await count_relationships(db=db)
await repo.import_jinja2_transforms(branch_name="main", commit=commit, config_file=config_file)
await repo.import_jinja2_transforms(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]
assert await count_relationships(db=db) == nbr_relationships_before

# 1. Modify an object to validate if its being properly updated
Expand All @@ -286,7 +286,7 @@ async def test_import_all_yaml_files(
)
await obj.save(db=db)

await repo.import_jinja2_transforms(branch_name="main", commit=commit, config_file=config_file)
await repo.import_jinja2_transforms(branch_name="main", commit=commit, config_file=config_file) # type: ignore[misc]

modified_rfile = await client.get(kind=CoreTransformJinja2, id=rfiles[0].id)
assert modified_rfile.template_path.value == rfile_template_path_value_before_change
Expand Down

0 comments on commit 9c9bbe3

Please sign in to comment.