From 5e0857b65b39b815e4f9391b93c16ab84dd6ee22 Mon Sep 17 00:00:00 2001 From: Patrick Ogenstad Date: Thu, 27 Feb 2025 10:47:58 +0100 Subject: [PATCH] Provide option to set GraphQL context information as input to mutations --- backend/infrahub/graphql/context.py | 26 ++++++ backend/infrahub/graphql/manager.py | 13 +-- backend/infrahub/graphql/mutations/branch.py | 27 +++++- .../graphql/mutations/computed_attribute.py | 5 ++ backend/infrahub/graphql/mutations/diff.py | 5 ++ .../graphql/mutations/diff_conflict.py | 5 ++ .../infrahub/graphql/mutations/generator.py | 6 +- backend/infrahub/graphql/mutations/main.py | 13 ++- .../graphql/mutations/relationship.py | 12 +++ backend/infrahub/graphql/mutations/schema.py | 15 ++++ backend/infrahub/graphql/types/context.py | 12 +++ .../mutations/test_mutation_context.py | 88 +++++++++++++++++++ 12 files changed, 215 insertions(+), 12 deletions(-) create mode 100644 backend/infrahub/graphql/context.py create mode 100644 backend/infrahub/graphql/types/context.py create mode 100644 backend/tests/unit/graphql/mutations/test_mutation_context.py diff --git a/backend/infrahub/graphql/context.py b/backend/infrahub/graphql/context.py new file mode 100644 index 0000000000..67cb6b54b3 --- /dev/null +++ b/backend/infrahub/graphql/context.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from infrahub.core.constants import InfrahubKind +from infrahub.core.manager import NodeManager +from infrahub.exceptions import NodeNotFoundError, ValidationError + +if TYPE_CHECKING: + from .initialization import GraphqlContext + from .types.context import ContextInput + + +async def apply_external_context(graphql_context: GraphqlContext, context_input: ContextInput | None) -> None: + """Applies context provided by an external mutation to the GraphQL context""" + if not context_input or not context_input.account: + return + + try: + account = await NodeManager.get_one_by_id_or_default_filter( + db=graphql_context.db, id=str(context_input.account.id), kind=InfrahubKind.GENERICACCOUNT + ) + except NodeNotFoundError as exc: + raise ValidationError(input_value="Unable to set context for account that doesn't exist") from exc + + graphql_context.active_account_session.account_id = account.id diff --git a/backend/infrahub/graphql/manager.py b/backend/infrahub/graphql/manager.py index 44a78d8847..3e92d4a3be 100644 --- a/backend/infrahub/graphql/manager.py +++ b/backend/infrahub/graphql/manager.py @@ -60,6 +60,7 @@ ) from .types.attribute import BaseAttribute as BaseAttributeType from .types.attribute import TextAttributeType +from .types.context import ContextInput from .types.event import EVENT_TYPES if TYPE_CHECKING: @@ -809,9 +810,7 @@ def generate_graphql_mutation_create( meta_attrs: dict[str, Any] = {"schema": schema, "name": name, "description": schema.description} main_attrs["Meta"] = type("Meta", (object,), meta_attrs) - args_attrs = { - "data": input_type(required=True), - } + args_attrs = {"data": input_type(required=True), "context": ContextInput(required=False)} main_attrs["Arguments"] = type("Arguments", (object,), args_attrs) return type(name, (base_class,), main_attrs) @@ -832,9 +831,7 @@ def generate_graphql_mutation_update( meta_attrs: dict[str, Any] = {"schema": schema, "name": name, "description": schema.description} main_attrs["Meta"] = type("Meta", (object,), meta_attrs) - args_attrs = { - "data": input_type(required=True), - } + args_attrs = {"data": input_type(required=True), "context": ContextInput(required=False)} main_attrs["Arguments"] = type("Arguments", (object,), args_attrs) return type(name, (base_class,), main_attrs) @@ -851,9 +848,7 @@ def generate_graphql_mutation_delete( meta_attrs = {"schema": schema, "name": name, "description": schema.description} main_attrs["Meta"] = type("Meta", (object,), meta_attrs) - args_attrs: dict[str, Any] = { - "data": DeleteInput(required=True), - } + args_attrs: dict[str, Any] = {"data": DeleteInput(required=True), "context": ContextInput(required=False)} main_attrs["Arguments"] = type("Arguments", (object,), args_attrs) return type(name, (base_class,), main_attrs) diff --git a/backend/infrahub/graphql/mutations/branch.py b/backend/infrahub/graphql/mutations/branch.py index ff009d715b..22a1c89f24 100644 --- a/backend/infrahub/graphql/mutations/branch.py +++ b/backend/infrahub/graphql/mutations/branch.py @@ -9,6 +9,8 @@ from infrahub.core.branch import Branch from infrahub.database import retry_db_transaction +from infrahub.graphql.context import apply_external_context +from infrahub.graphql.types.context import ContextInput from infrahub.log import get_logger from infrahub.workflows.catalogue import ( BRANCH_CREATE, @@ -44,6 +46,7 @@ class BranchCreateInput(InputObjectType): class BranchCreate(Mutation): class Arguments: data = BranchCreateInput(required=True) + context = ContextInput(required=False) background_execution = Boolean(required=False, deprecation_reason="Please use `wait_until_completion` instead") wait_until_completion = Boolean(required=False) @@ -58,6 +61,7 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: BranchCreateInput, + context: ContextInput | None = None, background_execution: bool = False, wait_until_completion: bool = True, ) -> Self: @@ -65,6 +69,7 @@ async def mutate( task: dict | None = None model = BranchCreateModel(**data) + await apply_external_context(graphql_context=graphql_context, context_input=context) if background_execution or not wait_until_completion: workflow = await graphql_context.active_service.workflow.submit_workflow( @@ -96,6 +101,7 @@ class BranchUpdateInput(InputObjectType): class BranchDelete(Mutation): class Arguments: data = BranchNameInput(required=True) + context = ContextInput(required=False) wait_until_completion = Boolean(required=False) ok = Boolean() @@ -107,10 +113,12 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: BranchNameInput, + context: ContextInput | None = None, wait_until_completion: bool = True, ) -> Self: graphql_context: GraphqlContext = info.context obj = await Branch.get_by_name(db=graphql_context.db, name=str(data.name)) + await apply_external_context(graphql_context=graphql_context, context_input=context) if wait_until_completion: await graphql_context.active_service.workflow.execute_workflow( @@ -127,15 +135,23 @@ async def mutate( class BranchUpdate(Mutation): class Arguments: data = BranchUpdateInput(required=True) + context = ContextInput(required=False) ok = Boolean() @classmethod @retry_db_transaction(name="branch_update") - async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: BranchNameInput) -> Self: # noqa: ARG003 + async def mutate( + cls, + root: dict, # noqa: ARG003 + info: GraphQLResolveInfo, + data: BranchNameInput, + context: ContextInput | None = None, + ) -> Self: graphql_context: GraphqlContext = info.context obj = await Branch.get_by_name(db=graphql_context.db, name=data["name"]) + await apply_external_context(graphql_context=graphql_context, context_input=context) to_extract = ["description"] for field_name in to_extract: @@ -151,6 +167,7 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: BranchNameInpu class BranchRebase(Mutation): class Arguments: data = BranchNameInput(required=True) + context = ContextInput(required=False) wait_until_completion = Boolean(required=False) ok = Boolean() @@ -163,11 +180,13 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: BranchNameInput, + context: ContextInput | None = None, wait_until_completion: bool = True, ) -> Self: graphql_context: GraphqlContext = info.context obj = await Branch.get_by_name(db=graphql_context.db, name=str(data.name)) + await apply_external_context(graphql_context=graphql_context, context_input=context) task: dict | None = None if wait_until_completion: @@ -192,6 +211,7 @@ async def mutate( class BranchValidate(Mutation): class Arguments: data = BranchNameInput(required=True) + context = ContextInput(required=False) wait_until_completion = Boolean(required=False) ok = Boolean() @@ -205,11 +225,13 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: BranchNameInput, + context: ContextInput | None = None, wait_until_completion: bool = True, ) -> Self: graphql_context: GraphqlContext = info.context obj = await Branch.get_by_name(db=graphql_context.db, name=str(data.name)) + await apply_external_context(graphql_context=graphql_context, context_input=context) task: dict | None = None ok = True @@ -231,6 +253,7 @@ async def mutate( class BranchMerge(Mutation): class Arguments: data = BranchNameInput(required=True) + context = ContextInput(required=False) wait_until_completion = Boolean(required=False) ok = Boolean() @@ -243,11 +266,13 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: BranchNameInput, + context: ContextInput | None = None, wait_until_completion: bool = True, ) -> Self: branch_name = data["name"] task: dict | None = None graphql_context: GraphqlContext = info.context + await apply_external_context(graphql_context=graphql_context, context_input=context) if wait_until_completion: await graphql_context.active_service.workflow.execute_workflow( diff --git a/backend/infrahub/graphql/mutations/computed_attribute.py b/backend/infrahub/graphql/mutations/computed_attribute.py index 64432512cb..41458f2ac9 100644 --- a/backend/infrahub/graphql/mutations/computed_attribute.py +++ b/backend/infrahub/graphql/mutations/computed_attribute.py @@ -12,6 +12,8 @@ from infrahub.events import EventMeta from infrahub.events.node_action import NodeMutatedEvent from infrahub.exceptions import NodeNotFoundError, ValidationError +from infrahub.graphql.context import apply_external_context +from infrahub.graphql.types.context import ContextInput from infrahub.log import get_log_data from infrahub.worker import WORKER_IDENTITY @@ -31,6 +33,7 @@ class InfrahubComputedAttributeUpdateInput(InputObjectType): class UpdateComputedAttribute(Mutation): class Arguments: data = InfrahubComputedAttributeUpdateInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -41,6 +44,7 @@ async def mutate( _: dict, info: GraphQLResolveInfo, data: InfrahubComputedAttributeUpdateInput, + context: ContextInput | None = None, ) -> UpdateComputedAttribute: graphql_context: GraphqlContext = info.context node_schema = registry.schema.get_node_schema( @@ -63,6 +67,7 @@ async def mutate( else PermissionDecision.ALLOW_OTHER.value, ) ) + await apply_external_context(graphql_context=graphql_context, context_input=context) if not ( target_node := await NodeManager.get_one( diff --git a/backend/infrahub/graphql/mutations/diff.py b/backend/infrahub/graphql/mutations/diff.py index b13df7c69d..a7e56890a9 100644 --- a/backend/infrahub/graphql/mutations/diff.py +++ b/backend/infrahub/graphql/mutations/diff.py @@ -11,6 +11,8 @@ from infrahub.core.timestamp import Timestamp from infrahub.dependencies.registry import get_component_registry from infrahub.exceptions import ValidationError +from infrahub.graphql.context import apply_external_context +from infrahub.graphql.types.context import ContextInput from infrahub.workflows.catalogue import DIFF_UPDATE from ..types.task import TaskInfo @@ -30,6 +32,7 @@ class DiffUpdateInput(InputObjectType): class DiffUpdateMutation(Mutation): class Arguments: data = DiffUpdateInput(required=True) + context = ContextInput(required=False) wait_until_completion = Boolean(required=False) ok = Boolean() @@ -41,9 +44,11 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: DiffUpdateInput, + context: ContextInput | None = None, wait_until_completion: bool = False, ) -> dict[str, bool | dict[str, str]]: graphql_context: GraphqlContext = info.context + await apply_external_context(graphql_context=graphql_context, context_input=context) if data.wait_for_completion is True: wait_until_completion = True diff --git a/backend/infrahub/graphql/mutations/diff_conflict.py b/backend/infrahub/graphql/mutations/diff_conflict.py index 30b62501dd..63504c97fe 100644 --- a/backend/infrahub/graphql/mutations/diff_conflict.py +++ b/backend/infrahub/graphql/mutations/diff_conflict.py @@ -11,7 +11,9 @@ from infrahub.database import retry_db_transaction from infrahub.dependencies.registry import get_component_registry from infrahub.exceptions import ProcessingError +from infrahub.graphql.context import apply_external_context from infrahub.graphql.enums import ConflictSelection as GraphQlConflictSelection +from infrahub.graphql.types.context import ContextInput if TYPE_CHECKING: from graphql import GraphQLResolveInfo @@ -29,6 +31,7 @@ class ResolveDiffConflictInput(InputObjectType): class ResolveDiffConflict(Mutation): class Arguments: data = ResolveDiffConflictInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -39,8 +42,10 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: ResolveDiffConflictInput, + context: ContextInput | None = None, ) -> ResolveDiffConflict: graphql_context: GraphqlContext = info.context + await apply_external_context(graphql_context=graphql_context, context_input=context) component_registry = get_component_registry() diff_repo = await component_registry.get_component( diff --git a/backend/infrahub/graphql/mutations/generator.py b/backend/infrahub/graphql/mutations/generator.py index 02203d0aab..16e627077f 100644 --- a/backend/infrahub/graphql/mutations/generator.py +++ b/backend/infrahub/graphql/mutations/generator.py @@ -6,6 +6,8 @@ from infrahub.core.manager import NodeManager from infrahub.generators.models import ProposedChangeGeneratorDefinition, RequestGeneratorDefinitionRun +from infrahub.graphql.context import apply_external_context +from infrahub.graphql.types.context import ContextInput from infrahub.graphql.types.task import TaskInfo from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN @@ -23,6 +25,7 @@ class GeneratorDefinitionRequestRunInput(InputObjectType): class GeneratorDefinitionRequestRun(Mutation): class Arguments: data = GeneratorDefinitionRequestRunInput(required=True) + context = ContextInput(required=False) wait_until_completion = Boolean(required=False) ok = Boolean() @@ -34,11 +37,12 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: GeneratorDefinitionRequestRunInput, + context: ContextInput | None = None, wait_until_completion: bool = True, ) -> GeneratorDefinitionRequestRun: graphql_context: GraphqlContext = info.context db = graphql_context.db - + await apply_external_context(graphql_context=graphql_context, context_input=context) generator_definition = await NodeManager.get_one( id=str(data.id), db=db, branch=graphql_context.branch, prefetch_relationships=True, raise_on_error=True ) diff --git a/backend/infrahub/graphql/mutations/main.py b/backend/infrahub/graphql/mutations/main.py index 79101bcf43..e3d0684c9e 100644 --- a/backend/infrahub/graphql/mutations/main.py +++ b/backend/infrahub/graphql/mutations/main.py @@ -24,6 +24,7 @@ from infrahub.dependencies.registry import get_component_registry from infrahub.events import EventMeta, NodeMutatedEvent from infrahub.exceptions import ValidationError +from infrahub.graphql.context import apply_external_context from infrahub.lock import InfrahubMultiLock, build_object_lock_name from infrahub.log import get_log_data, get_logger from infrahub.worker import WORKER_IDENTITY @@ -40,6 +41,7 @@ from infrahub.core.relationship.model import RelationshipManager from infrahub.core.schema.schema_branch import SchemaBranch from infrahub.database import InfrahubDatabase + from infrahub.graphql.types.context import ContextInput from ..initialization import GraphqlContext from .node_getter.interface import MutationNodeGetterInterface @@ -66,8 +68,17 @@ class InfrahubMutationOptions(MutationOptions): class InfrahubMutationMixin: @classmethod - async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectType, *args: Any, **kwargs): # noqa: ARG003 + async def mutate( + cls, + root: dict, # noqa: ARG003 + info: GraphQLResolveInfo, + data: InputObjectType, + context: ContextInput | None = None, + *args: Any, # noqa: ARG003 + **kwargs, + ): graphql_context: GraphqlContext = info.context + await apply_external_context(graphql_context=graphql_context, context_input=context) obj = None mutation = None diff --git a/backend/infrahub/graphql/mutations/relationship.py b/backend/infrahub/graphql/mutations/relationship.py index e832f9a4ac..71d93e632b 100644 --- a/backend/infrahub/graphql/mutations/relationship.py +++ b/backend/infrahub/graphql/mutations/relationship.py @@ -30,6 +30,8 @@ from infrahub.events.group_action import GroupMemberAddedEvent, GroupMemberRemovedEvent from infrahub.events.models import EventNode from infrahub.exceptions import NodeNotFoundError, ValidationError +from infrahub.graphql.context import apply_external_context +from infrahub.graphql.types.context import ContextInput from infrahub.permissions import get_global_permission_for_kind from ..types import RelatedNodeInput @@ -64,6 +66,7 @@ class RelationshipNodesInput(InputObjectType): class RelationshipAdd(Mutation): class Arguments: data = RelationshipNodesInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -74,6 +77,7 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: RelationshipNodesInput, + context: ContextInput | None = None, ) -> Self: graphql_context: GraphqlContext = info.context relationship_name = str(data.name) @@ -83,6 +87,9 @@ async def mutate( await _validate_permissions(info=info, source_node=source, peers=nodes) await _validate_peer_types(info=info, data=data, source_node=source, peers=nodes) + # This has to be done after validating the permissions + await apply_external_context(graphql_context=graphql_context, context_input=context) + rel_schema = source.get_schema().get_relationship(name=relationship_name) display_label: str = await source.render_display_label(db=graphql_context.db) node_changelog = NodeChangelog( @@ -157,6 +164,7 @@ async def mutate( class RelationshipRemove(Mutation): class Arguments: data = RelationshipNodesInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -167,6 +175,7 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: RelationshipNodesInput, + context: ContextInput | None = None, ) -> Self: graphql_context: GraphqlContext = info.context relationship_name = str(data.name) @@ -176,6 +185,9 @@ async def mutate( await _validate_permissions(info=info, source_node=source, peers=nodes) await _validate_peer_types(info=info, data=data, source_node=source, peers=nodes) + # This has to be done after validating the permissions + await apply_external_context(graphql_context=graphql_context, context_input=context) + rel_schema = source.get_schema().get_relationship(name=relationship_name) display_label: str = await source.render_display_label(db=graphql_context.db) node_changelog = NodeChangelog( diff --git a/backend/infrahub/graphql/mutations/schema.py b/backend/infrahub/graphql/mutations/schema.py index fa6c470d66..44565ff9a4 100644 --- a/backend/infrahub/graphql/mutations/schema.py +++ b/backend/infrahub/graphql/mutations/schema.py @@ -13,6 +13,8 @@ from infrahub.events import EventMeta from infrahub.events.schema_action import SchemaUpdatedEvent from infrahub.exceptions import ValidationError +from infrahub.graphql.context import apply_external_context +from infrahub.graphql.types.context import ContextInput from infrahub.log import get_log_data, get_logger from infrahub.worker import WORKER_IDENTITY @@ -51,6 +53,7 @@ class SchemaDropdownAddInput(SchemaDropdownRemoveInput): class SchemaDropdownAdd(Mutation): class Arguments: data = SchemaDropdownAddInput(required=True) + context = ContextInput(required=False) ok = Boolean() object = Field(DropdownFields) @@ -62,9 +65,12 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: SchemaDropdownAddInput, + context: ContextInput | None = None, ) -> Self: graphql_context: GraphqlContext = info.context + await apply_external_context(graphql_context=graphql_context, context_input=context) + kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name) attribute = str(data.attribute) validate_kind_dropdown(kind=kind, attribute=attribute) @@ -109,6 +115,7 @@ async def mutate( class SchemaDropdownRemove(Mutation): class Arguments: data = SchemaDropdownRemoveInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -119,10 +126,12 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: SchemaDropdownRemoveInput, + context: ContextInput | None = None, ) -> dict[str, bool]: graphql_context: GraphqlContext = info.context kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name) + await apply_external_context(graphql_context=graphql_context, context_input=context) attribute = str(data.attribute) validate_kind_dropdown(kind=kind, attribute=attribute) @@ -161,6 +170,7 @@ async def mutate( class SchemaEnumAdd(Mutation): class Arguments: data = SchemaEnumInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -171,10 +181,12 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: SchemaEnumInput, + context: ContextInput | None = None, ) -> dict[str, bool]: graphql_context: GraphqlContext = info.context kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name) + await apply_external_context(graphql_context=graphql_context, context_input=context) attribute = str(data.attribute) enum = str(data.enum) @@ -203,6 +215,7 @@ async def mutate( class SchemaEnumRemove(Mutation): class Arguments: data = SchemaEnumInput(required=True) + context = ContextInput(required=False) ok = Boolean() @@ -213,10 +226,12 @@ async def mutate( root: dict, # noqa: ARG003 info: GraphQLResolveInfo, data: SchemaEnumInput, + context: ContextInput | None = None, ) -> dict[str, bool]: graphql_context: GraphqlContext = info.context kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name) + await apply_external_context(graphql_context=graphql_context, context_input=context) attribute = str(data.attribute) enum = str(data.enum) diff --git a/backend/infrahub/graphql/types/context.py b/backend/infrahub/graphql/types/context.py new file mode 100644 index 0000000000..067ae786d8 --- /dev/null +++ b/backend/infrahub/graphql/types/context.py @@ -0,0 +1,12 @@ +from graphene import InputObjectType, String + + +class ContextAccountInput(InputObjectType): + id = String(required=True, description="The Infrahub ID of the account") + + +class ContextInput(InputObjectType): + account = ContextAccountInput( + required=False, + description="The account context can be used to override the account information that will be associated with the mutation", + ) diff --git a/backend/tests/unit/graphql/mutations/test_mutation_context.py b/backend/tests/unit/graphql/mutations/test_mutation_context.py new file mode 100644 index 0000000000..1b8f2de634 --- /dev/null +++ b/backend/tests/unit/graphql/mutations/test_mutation_context.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from infrahub.core.branch import Branch +from infrahub.database import InfrahubDatabase +from infrahub.events.node_action import NodeMutatedEvent +from infrahub.graphql.initialization import prepare_graphql_params +from infrahub.services import InfrahubServices +from tests.adapters.event import MemoryInfrahubEvent +from tests.helpers.graphql import graphql + +if TYPE_CHECKING: + from infrahub.auth import AccountSession + from infrahub.core.branch import Branch + from infrahub.core.node import Node + from infrahub.database import InfrahubDatabase + + +async def test_add_context_invalid_account( + db: InfrahubDatabase, + default_branch: Branch, + car_person_schema: None, + first_account: Node, +): + query = """ + mutation { + TestPersonCreate(data: {name: { value: "John"}, height: {value: 182}}, context: { account: { id: "very-invalid" }}) { + ok + object { + id + } + } + } + """ + gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=default_branch) + result = await graphql( + schema=gql_params.schema, + source=query, + context_value=gql_params.context, + root_value=None, + variable_values={}, + ) + assert result.errors + assert result.errors[0].message == "Unable to set context for account that doesn't exist" + + +async def test_add_context_valid_account( + db: InfrahubDatabase, + default_branch: Branch, + car_person_schema: None, + enable_broker_config: None, + session_first_account: AccountSession, + first_account: Node, + second_account: Node, +): + query = """ + mutation { + TestPersonCreate(data: {name: { value: "John"}, height: {value: 182}}, context: { account: { id: "%s" }}) { + ok + object { + id + } + } + } + """ % (second_account.id) + + memory_event = MemoryInfrahubEvent() + service = await InfrahubServices.new(event=memory_event) + gql_params = await prepare_graphql_params( + db=db, include_subscription=False, branch=default_branch, service=service, account_session=session_first_account + ) + result = await graphql( + schema=gql_params.schema, + source=query, + context_value=gql_params.context, + root_value=None, + variable_values={}, + ) + + assert result.errors is None + assert gql_params.context.background + await gql_params.context.background() + + assert len(memory_event.events) == 1 + node_event = memory_event.events[0] + assert isinstance(node_event, NodeMutatedEvent) + assert node_event.meta.account_id == second_account.id