diff --git a/backend/infrahub/graphql/mutations/proposed_change.py b/backend/infrahub/graphql/mutations/proposed_change.py index 77a2d9fb44..abe0e747d4 100644 --- a/backend/infrahub/graphql/mutations/proposed_change.py +++ b/backend/infrahub/graphql/mutations/proposed_change.py @@ -71,6 +71,7 @@ async def mutate_create( source_branch=source_branch.name, source_branch_sync_with_git=source_branch.sync_with_git, destination_branch=destination_branch, + context=graphql_context.get_context(), ), ] @@ -176,6 +177,7 @@ async def mutate( source_branch_sync_with_git=source_branch.sync_with_git, destination_branch=destination_branch, check_type=check_type, + context=graphql_context.get_context(), ) if graphql_context.service: await graphql_context.service.message_bus.send(message=message) diff --git a/backend/infrahub/services/adapters/workflow/local.py b/backend/infrahub/services/adapters/workflow/local.py index ef425bb5bd..68c3d4e8df 100644 --- a/backend/infrahub/services/adapters/workflow/local.py +++ b/backend/infrahub/services/adapters/workflow/local.py @@ -35,7 +35,6 @@ async def execute_workflow( inject_context_parameter(func=flow_func, parameters=parameters, context=context) parameters = flow_func.validate_parameters(parameters=parameters) - return await flow_func(**parameters) async def submit_workflow( diff --git a/backend/infrahub/workers/utils.py b/backend/infrahub/workers/utils.py index 80abb14538..2027bd8748 100644 --- a/backend/infrahub/workers/utils.py +++ b/backend/infrahub/workers/utils.py @@ -30,7 +30,7 @@ def inject_service_parameter(func: Flow, parameters: dict[str, Any], service: In def inject_context_parameter(func: Flow, parameters: dict[str, Any], context: InfrahubContext | None = None) -> None: - service_parameter_name = get_parameter_name(func=func, types=[InfrahubContext]) + service_parameter_name = get_parameter_name(func=func, types=[InfrahubContext.__name__, InfrahubContext]) if service_parameter_name and context: parameters[service_parameter_name] = context return diff --git a/backend/tests/unit/api/test_10_query.py b/backend/tests/unit/api/test_10_query.py index 8096a33de7..5d8d58e44e 100644 --- a/backend/tests/unit/api/test_10_query.py +++ b/backend/tests/unit/api/test_10_query.py @@ -6,6 +6,8 @@ import pytest from infrahub import config +from infrahub.auth import AccountSession, AuthType +from infrahub.context import BranchContext, InfrahubContext from infrahub.core.initialization import create_branch from infrahub.groups.models import RequestGraphQLQueryGroupUpdate from infrahub.workflows.catalogue import GRAPHQL_QUERY_GROUP_UPDATE @@ -29,10 +31,10 @@ async def test_query_endpoint_group_no_params( db: InfrahubDatabase, client: TestClient, admin_headers, - create_test_admin, - default_branch, + create_test_admin: Node, + default_branch: Branch, car_person_data, -): +) -> None: # Must execute in a with block to execute the startup/shutdown events with ( client, @@ -44,6 +46,13 @@ async def test_query_endpoint_group_no_params( "/api/query/query01?update_group=true&subscribers=AAAAAA&subscribers=BBBBBB", headers=admin_headers ) + context = InfrahubContext( + branch=BranchContext(name=default_branch.name, id=str(default_branch.get_uuid())), + account=AccountSession( + authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API + ), + ) + assert "errors" not in response.json() assert response.status_code == 200 assert response.json()["data"] is not None @@ -71,17 +80,19 @@ async def test_query_endpoint_group_no_params( ) expected_calls = [ - call( - workflow=GRAPHQL_QUERY_GROUP_UPDATE, - parameters={"model": model}, - ), + call(workflow=GRAPHQL_QUERY_GROUP_UPDATE, parameters={"model": model}, context=context), ] mock_submit_workflow.assert_has_calls(expected_calls) async def test_query_endpoint_group_params( - db: InfrahubDatabase, client: TestClient, admin_headers, default_branch, create_test_admin, car_person_data -): + db: InfrahubDatabase, + client: TestClient, + admin_headers, + default_branch: Branch, + create_test_admin: Node, + car_person_data, +) -> None: # Must execute in a with block to execute the startup/shutdown events with ( client, @@ -111,11 +122,14 @@ async def test_query_endpoint_group_params( params={"person": "John"}, ) - expected_calls = [ - call( - workflow=GRAPHQL_QUERY_GROUP_UPDATE, - parameters={"model": model}, + context = InfrahubContext( + branch=BranchContext(name=default_branch.name, id=str(default_branch.get_uuid())), + account=AccountSession( + authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API ), + ) + expected_calls = [ + call(workflow=GRAPHQL_QUERY_GROUP_UPDATE, parameters={"model": model}, context=context), ] mock_submit_workflow.assert_has_calls(expected_calls) diff --git a/backend/tests/unit/api/test_11_artifact.py b/backend/tests/unit/api/test_11_artifact.py index 5a43a8e896..7ffa5cc859 100644 --- a/backend/tests/unit/api/test_11_artifact.py +++ b/backend/tests/unit/api/test_11_artifact.py @@ -4,7 +4,10 @@ from starlette.testclient import TestClient from infrahub import config +from infrahub.auth import AccountSession, AuthType +from infrahub.context import BranchContext, InfrahubContext from infrahub.core import registry +from infrahub.core.branch import Branch from infrahub.core.constants import InfrahubKind from infrahub.core.node import Node from infrahub.database import InfrahubDatabase @@ -78,11 +81,11 @@ async def test_artifact_definition_endpoint( self, db: InfrahubDatabase, admin_headers, - default_branch, + default_branch: Branch, register_core_models_schema, register_builtin_models_schema, car_person_data_generic, - authentication_base, + authentication_base: Node, client, ): _, _, definition = await self.setup_artifact_definition( @@ -107,6 +110,14 @@ async def test_artifact_definition_endpoint( ) assert response.status_code == 200 + + context = InfrahubContext( + branch=BranchContext(name=default_branch.name, id=str(default_branch.get_uuid())), + account=AccountSession( + authenticated=True, account_id=authentication_base.id, session_id=None, auth_type=AuthType.API + ), + ) + expected_calls = [ call( workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, @@ -118,6 +129,7 @@ async def test_artifact_definition_endpoint( limit=[], ) }, + context=context, ), ] mock_submit_workflow.assert_has_calls(expected_calls) diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index a97535037f..b6c62ccf43 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -2616,8 +2616,8 @@ async def authentication_base( register_core_models_schema, register_builtin_models_schema, register_organization_schema, -): - pass +) -> Node: + return create_test_admin @pytest.fixture diff --git a/backend/tests/unit/core/test_branch_rebase.py b/backend/tests/unit/core/test_branch_rebase.py index 9a6dfffbd0..902560df3e 100644 --- a/backend/tests/unit/core/test_branch_rebase.py +++ b/backend/tests/unit/core/test_branch_rebase.py @@ -1,5 +1,9 @@ +from uuid import uuid4 + import pytest +from infrahub.auth import AccountSession, AuthType +from infrahub.context import InfrahubContext from infrahub.core.branch import Branch from infrahub.core.branch.tasks import rebase_branch from infrahub.core.constants import InfrahubKind @@ -104,4 +108,11 @@ async def test_branch_rebase_diff_conflict( service = await InfrahubServices.new(database=db, workflow=WorkflowLocalExecution()) with pytest.raises(ValidationError, match="contains conflicts with the default branch that must be addressed"): - await rebase_branch(branch=branch2.name, service=service) + await rebase_branch( + branch=branch2.name, + service=service, + context=InfrahubContext.init( + branch=default_branch, + account=AccountSession(account_id=str(uuid4()), auth_type=AuthType.NONE), + ), + ) diff --git a/backend/tests/unit/graphql/mutations/test_proposed_change.py b/backend/tests/unit/graphql/mutations/test_proposed_change.py index 201b2ba399..e41ac8cf5f 100644 --- a/backend/tests/unit/graphql/mutations/test_proposed_change.py +++ b/backend/tests/unit/graphql/mutations/test_proposed_change.py @@ -2,7 +2,7 @@ from prefect.client.orchestration import get_client -from infrahub.auth import AccountSession +from infrahub.auth import AccountSession, AuthType from infrahub.core.branch import Branch from infrahub.core.constants import CheckType, InfrahubKind from infrahub.core.initialization import create_branch @@ -127,7 +127,9 @@ async def test_create_invalid_branch_combinations(db: InfrahubDatabase, default_ ) -async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_models_schema: None): +async def test_trigger_proposed_change( + db: InfrahubDatabase, register_core_models_schema: None, create_test_admin: Node +) -> None: branch_name = "triggered-proposed-change" source_branch = Branch(name=branch_name) await source_branch.save(db=db) @@ -137,8 +139,15 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model await proposed_change.save(db=db) all_recorder = BusRecorder() service = await InfrahubServices.new(database=db, message_bus=all_recorder) + account_session = AccountSession( + authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API + ) all_result = await graphql_mutation( - query=RUN_CHECK, db=db, variables={"proposed_change": proposed_change.id}, service=service + query=RUN_CHECK, + db=db, + variables={"proposed_change": proposed_change.id}, + service=service, + account_session=account_session, ) assert all_result.data assert not all_result.errors @@ -150,6 +159,7 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model db=db, variables={"proposed_change": proposed_change.id, "check_type": "ARTIFACT"}, service=service, + account_session=account_session, ) update_status = await graphql_mutation( @@ -157,6 +167,7 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model db=db, variables={"proposed_change": proposed_change.id, "state": "canceled"}, service=service, + account_session=account_session, ) cancelled_recorder = BusRecorder() @@ -166,6 +177,7 @@ async def test_trigger_proposed_change(db: InfrahubDatabase, register_core_model db=db, variables={"proposed_change": proposed_change.id, "check_type": "DATA"}, service=service, + account_session=account_session, ) assert len(all_recorder.messages) == 1 diff --git a/backend/tests/unit/graphql/test_mutation_artifact_definition.py b/backend/tests/unit/graphql/test_mutation_artifact_definition.py index fdf495840e..754eaf72c5 100644 --- a/backend/tests/unit/graphql/test_mutation_artifact_definition.py +++ b/backend/tests/unit/graphql/test_mutation_artifact_definition.py @@ -2,6 +2,8 @@ import pytest +from infrahub.auth import AccountSession, AuthType +from infrahub.context import InfrahubContext from infrahub.core.branch import Branch from infrahub.core.constants import InfrahubKind from infrahub.core.manager import NodeManager @@ -68,6 +70,7 @@ async def test_create_artifact_definition( default_branch: Branch, register_core_models_schema, car_person_data_generic, + create_test_admin: Node, group1: Node, transformation1: Node, branch: Branch, @@ -95,7 +98,12 @@ async def test_create_artifact_definition( recorder = BusRecorder() service = await InfrahubServices.new(message_bus=recorder, workflow=WorkflowLocalExecution()) - gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch, service=service) + account_session = AccountSession( + authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API + ) + gql_params = await prepare_graphql_params( + db=db, include_subscription=False, branch=branch, service=service, account_session=account_session + ) with patch( "infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow" @@ -109,6 +117,7 @@ async def test_create_artifact_definition( ) assert result.errors is None + assert result.data assert result.data["CoreArtifactDefinitionCreate"]["ok"] is True ad_id = result.data["CoreArtifactDefinitionCreate"]["object"]["id"] @@ -116,6 +125,11 @@ async def test_create_artifact_definition( assert ad1.name.value == "Artifact 01" + context = InfrahubContext.init( + branch=branch, + account=account_session, + ) + expected_calls = [ call( workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, @@ -127,6 +141,7 @@ async def test_create_artifact_definition( limit=[], ) }, + context=context, ), ] mock_submit_workflow.assert_has_calls(expected_calls) @@ -137,6 +152,7 @@ async def test_update_artifact_definition( default_branch: Branch, register_core_models_schema, car_person_data_generic, + create_test_admin: Node, definition1: Node, branch: Branch, ): @@ -156,8 +172,12 @@ async def test_update_artifact_definition( recorder = BusRecorder() service = await InfrahubServices.new(message_bus=recorder, workflow=WorkflowLocalExecution()) - - gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch, service=service) + account_session = AccountSession( + authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API + ) + gql_params = await prepare_graphql_params( + db=db, include_subscription=False, branch=branch, service=service, account_session=account_session + ) with patch( "infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow" ) as mock_submit_workflow: @@ -170,6 +190,7 @@ async def test_update_artifact_definition( ) assert result.errors is None + assert result.data assert result.data["CoreArtifactDefinitionUpdate"]["ok"] is True ad1_post = await NodeManager.get_one( @@ -178,6 +199,11 @@ async def test_update_artifact_definition( assert ad1_post.artifact_name.value == "myartifact2" + context = InfrahubContext.init( + branch=branch, + account=account_session, + ) + expected_calls = [ call( workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, @@ -189,6 +215,7 @@ async def test_update_artifact_definition( limit=[], ) }, + context=context, ), ] mock_submit_workflow.assert_has_calls(expected_calls)