Skip to content

Commit

Permalink
Tests and fixes for InfrahubContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ogenstad committed Feb 10, 2025
1 parent 7f5dbef commit 13e9bc3
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 26 deletions.
2 changes: 2 additions & 0 deletions backend/infrahub/graphql/mutations/proposed_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ async def mutate_create(
source_branch=source_branch.name,
source_branch_sync_with_git=source_branch.sync_with_git,
destination_branch=destination_branch,
context=graphql_context.get_context(),
),
]

Expand Down Expand Up @@ -176,6 +177,7 @@ async def mutate(
source_branch_sync_with_git=source_branch.sync_with_git,
destination_branch=destination_branch,
check_type=check_type,
context=graphql_context.get_context(),
)
if graphql_context.service:
await graphql_context.service.message_bus.send(message=message)
Expand Down
1 change: 0 additions & 1 deletion backend/infrahub/services/adapters/workflow/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ async def execute_workflow(
inject_context_parameter(func=flow_func, parameters=parameters, context=context)

parameters = flow_func.validate_parameters(parameters=parameters)

return await flow_func(**parameters)

async def submit_workflow(
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/workers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def inject_service_parameter(func: Flow, parameters: dict[str, Any], service: In


def inject_context_parameter(func: Flow, parameters: dict[str, Any], context: InfrahubContext | None = None) -> None:
service_parameter_name = get_parameter_name(func=func, types=[InfrahubContext])
service_parameter_name = get_parameter_name(func=func, types=[InfrahubContext.__name__, InfrahubContext])
if service_parameter_name and context:
parameters[service_parameter_name] = context
return
Expand Down
40 changes: 27 additions & 13 deletions backend/tests/unit/api/test_10_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest

from infrahub import config
from infrahub.auth import AccountSession, AuthType
from infrahub.context import BranchContext, InfrahubContext
from infrahub.core.initialization import create_branch
from infrahub.groups.models import RequestGraphQLQueryGroupUpdate
from infrahub.workflows.catalogue import GRAPHQL_QUERY_GROUP_UPDATE
Expand All @@ -29,10 +31,10 @@ async def test_query_endpoint_group_no_params(
db: InfrahubDatabase,
client: TestClient,
admin_headers,
create_test_admin,
default_branch,
create_test_admin: Node,
default_branch: Branch,
car_person_data,
):
) -> None:
# Must execute in a with block to execute the startup/shutdown events
with (
client,
Expand All @@ -44,6 +46,13 @@ async def test_query_endpoint_group_no_params(
"/api/query/query01?update_group=true&subscribers=AAAAAA&subscribers=BBBBBB", headers=admin_headers
)

context = InfrahubContext(
branch=BranchContext(name=default_branch.name, id=str(default_branch.get_uuid())),
account=AccountSession(
authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API
),
)

assert "errors" not in response.json()
assert response.status_code == 200
assert response.json()["data"] is not None
Expand Down Expand Up @@ -71,17 +80,19 @@ async def test_query_endpoint_group_no_params(
)

expected_calls = [
call(
workflow=GRAPHQL_QUERY_GROUP_UPDATE,
parameters={"model": model},
),
call(workflow=GRAPHQL_QUERY_GROUP_UPDATE, parameters={"model": model}, context=context),
]
mock_submit_workflow.assert_has_calls(expected_calls)


async def test_query_endpoint_group_params(
db: InfrahubDatabase, client: TestClient, admin_headers, default_branch, create_test_admin, car_person_data
):
db: InfrahubDatabase,
client: TestClient,
admin_headers,
default_branch: Branch,
create_test_admin: Node,
car_person_data,
) -> None:
# Must execute in a with block to execute the startup/shutdown events
with (
client,
Expand Down Expand Up @@ -111,11 +122,14 @@ async def test_query_endpoint_group_params(
params={"person": "John"},
)

expected_calls = [
call(
workflow=GRAPHQL_QUERY_GROUP_UPDATE,
parameters={"model": model},
context = InfrahubContext(
branch=BranchContext(name=default_branch.name, id=str(default_branch.get_uuid())),
account=AccountSession(
authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API
),
)
expected_calls = [
call(workflow=GRAPHQL_QUERY_GROUP_UPDATE, parameters={"model": model}, context=context),
]
mock_submit_workflow.assert_has_calls(expected_calls)

Expand Down
16 changes: 14 additions & 2 deletions backend/tests/unit/api/test_11_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from starlette.testclient import TestClient

from infrahub import config
from infrahub.auth import AccountSession, AuthType
from infrahub.context import BranchContext, InfrahubContext
from infrahub.core import registry
from infrahub.core.branch import Branch
from infrahub.core.constants import InfrahubKind
from infrahub.core.node import Node
from infrahub.database import InfrahubDatabase
Expand Down Expand Up @@ -78,11 +81,11 @@ async def test_artifact_definition_endpoint(
self,
db: InfrahubDatabase,
admin_headers,
default_branch,
default_branch: Branch,
register_core_models_schema,
register_builtin_models_schema,
car_person_data_generic,
authentication_base,
authentication_base: Node,
client,
):
_, _, definition = await self.setup_artifact_definition(
Expand All @@ -107,6 +110,14 @@ async def test_artifact_definition_endpoint(
)

assert response.status_code == 200

context = InfrahubContext(
branch=BranchContext(name=default_branch.name, id=str(default_branch.get_uuid())),
account=AccountSession(
authenticated=True, account_id=authentication_base.id, session_id=None, auth_type=AuthType.API
),
)

expected_calls = [
call(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE,
Expand All @@ -118,6 +129,7 @@ async def test_artifact_definition_endpoint(
limit=[],
)
},
context=context,
),
]
mock_submit_workflow.assert_has_calls(expected_calls)
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2616,8 +2616,8 @@ async def authentication_base(
register_core_models_schema,
register_builtin_models_schema,
register_organization_schema,
):
pass
) -> Node:
return create_test_admin


@pytest.fixture
Expand Down
13 changes: 12 additions & 1 deletion backend/tests/unit/core/test_branch_rebase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from uuid import uuid4

import pytest

from infrahub.auth import AccountSession, AuthType
from infrahub.context import InfrahubContext
from infrahub.core.branch import Branch
from infrahub.core.branch.tasks import rebase_branch
from infrahub.core.constants import InfrahubKind
Expand Down Expand Up @@ -104,4 +108,11 @@ async def test_branch_rebase_diff_conflict(
service = await InfrahubServices.new(database=db, workflow=WorkflowLocalExecution())

with pytest.raises(ValidationError, match="contains conflicts with the default branch that must be addressed"):
await rebase_branch(branch=branch2.name, service=service)
await rebase_branch(
branch=branch2.name,
service=service,
context=InfrahubContext.init(
branch=default_branch,
account=AccountSession(account_id=str(uuid4()), auth_type=AuthType.NONE),
),
)
18 changes: 15 additions & 3 deletions backend/tests/unit/graphql/mutations/test_proposed_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from prefect.client.orchestration import get_client

from infrahub.auth import AccountSession
from infrahub.auth import AccountSession, AuthType
from infrahub.core.branch import Branch
from infrahub.core.constants import CheckType, InfrahubKind
from infrahub.core.initialization import create_branch
Expand Down Expand Up @@ -127,7 +127,9 @@ async def test_create_invalid_branch_combinations(db: InfrahubDatabase, default_
)


async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_models_schema: None):
async def test_trigger_proposed_change(
db: InfrahubDatabase, register_core_models_schema: None, create_test_admin: Node
) -> None:
branch_name = "triggered-proposed-change"
source_branch = Branch(name=branch_name)
await source_branch.save(db=db)
Expand All @@ -137,8 +139,15 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model
await proposed_change.save(db=db)
all_recorder = BusRecorder()
service = await InfrahubServices.new(database=db, message_bus=all_recorder)
account_session = AccountSession(
authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API
)
all_result = await graphql_mutation(
query=RUN_CHECK, db=db, variables={"proposed_change": proposed_change.id}, service=service
query=RUN_CHECK,
db=db,
variables={"proposed_change": proposed_change.id},
service=service,
account_session=account_session,
)
assert all_result.data
assert not all_result.errors
Expand All @@ -150,13 +159,15 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model
db=db,
variables={"proposed_change": proposed_change.id, "check_type": "ARTIFACT"},
service=service,
account_session=account_session,
)

update_status = await graphql_mutation(
query=UPDATE_PROPOSED_CHANGE,
db=db,
variables={"proposed_change": proposed_change.id, "state": "canceled"},
service=service,
account_session=account_session,
)

cancelled_recorder = BusRecorder()
Expand All @@ -166,6 +177,7 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model
db=db,
variables={"proposed_change": proposed_change.id, "check_type": "DATA"},
service=service,
account_session=account_session,
)

assert len(all_recorder.messages) == 1
Expand Down
33 changes: 30 additions & 3 deletions backend/tests/unit/graphql/test_mutation_artifact_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest

from infrahub.auth import AccountSession, AuthType
from infrahub.context import InfrahubContext
from infrahub.core.branch import Branch
from infrahub.core.constants import InfrahubKind
from infrahub.core.manager import NodeManager
Expand Down Expand Up @@ -68,6 +70,7 @@ async def test_create_artifact_definition(
default_branch: Branch,
register_core_models_schema,
car_person_data_generic,
create_test_admin: Node,
group1: Node,
transformation1: Node,
branch: Branch,
Expand Down Expand Up @@ -95,7 +98,12 @@ async def test_create_artifact_definition(
recorder = BusRecorder()
service = await InfrahubServices.new(message_bus=recorder, workflow=WorkflowLocalExecution())

gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch, service=service)
account_session = AccountSession(
authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API
)
gql_params = await prepare_graphql_params(
db=db, include_subscription=False, branch=branch, service=service, account_session=account_session
)

with patch(
"infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow"
Expand All @@ -109,13 +117,19 @@ async def test_create_artifact_definition(
)

assert result.errors is None
assert result.data
assert result.data["CoreArtifactDefinitionCreate"]["ok"] is True
ad_id = result.data["CoreArtifactDefinitionCreate"]["object"]["id"]

ad1 = await NodeManager.get_one(db=db, id=ad_id, include_owner=True, include_source=True, branch=branch)

assert ad1.name.value == "Artifact 01"

context = InfrahubContext.init(
branch=branch,
account=account_session,
)

expected_calls = [
call(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE,
Expand All @@ -127,6 +141,7 @@ async def test_create_artifact_definition(
limit=[],
)
},
context=context,
),
]
mock_submit_workflow.assert_has_calls(expected_calls)
Expand All @@ -137,6 +152,7 @@ async def test_update_artifact_definition(
default_branch: Branch,
register_core_models_schema,
car_person_data_generic,
create_test_admin: Node,
definition1: Node,
branch: Branch,
):
Expand All @@ -156,8 +172,12 @@ async def test_update_artifact_definition(

recorder = BusRecorder()
service = await InfrahubServices.new(message_bus=recorder, workflow=WorkflowLocalExecution())

gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch, service=service)
account_session = AccountSession(
authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API
)
gql_params = await prepare_graphql_params(
db=db, include_subscription=False, branch=branch, service=service, account_session=account_session
)
with patch(
"infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow"
) as mock_submit_workflow:
Expand All @@ -170,6 +190,7 @@ async def test_update_artifact_definition(
)

assert result.errors is None
assert result.data
assert result.data["CoreArtifactDefinitionUpdate"]["ok"] is True

ad1_post = await NodeManager.get_one(
Expand All @@ -178,6 +199,11 @@ async def test_update_artifact_definition(

assert ad1_post.artifact_name.value == "myartifact2"

context = InfrahubContext.init(
branch=branch,
account=account_session,
)

expected_calls = [
call(
workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE,
Expand All @@ -189,6 +215,7 @@ async def test_update_artifact_definition(
limit=[],
)
},
context=context,
),
]
mock_submit_workflow.assert_has_calls(expected_calls)

0 comments on commit 13e9bc3

Please sign in to comment.