Skip to content

Commit

Permalink
Merge pull request #5874 from opsmill/pog-external-graphql-context-IF…
Browse files Browse the repository at this point in the history
…C-1303

Provide option to set GraphQL context information as input to mutations
  • Loading branch information
ogenstad authored Mar 3, 2025
2 parents 50b47f3 + 5e0857b commit 84f1c09
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 12 deletions.
26 changes: 26 additions & 0 deletions backend/infrahub/graphql/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from infrahub.core.constants import InfrahubKind
from infrahub.core.manager import NodeManager
from infrahub.exceptions import NodeNotFoundError, ValidationError

if TYPE_CHECKING:
from .initialization import GraphqlContext
from .types.context import ContextInput


async def apply_external_context(graphql_context: GraphqlContext, context_input: ContextInput | None) -> None:
"""Applies context provided by an external mutation to the GraphQL context"""
if not context_input or not context_input.account:
return

try:
account = await NodeManager.get_one_by_id_or_default_filter(
db=graphql_context.db, id=str(context_input.account.id), kind=InfrahubKind.GENERICACCOUNT
)
except NodeNotFoundError as exc:
raise ValidationError(input_value="Unable to set context for account that doesn't exist") from exc

graphql_context.active_account_session.account_id = account.id
13 changes: 4 additions & 9 deletions backend/infrahub/graphql/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
)
from .types.attribute import BaseAttribute as BaseAttributeType
from .types.attribute import TextAttributeType
from .types.context import ContextInput
from .types.event import EVENT_TYPES

if TYPE_CHECKING:
Expand Down Expand Up @@ -809,9 +810,7 @@ def generate_graphql_mutation_create(
meta_attrs: dict[str, Any] = {"schema": schema, "name": name, "description": schema.description}
main_attrs["Meta"] = type("Meta", (object,), meta_attrs)

args_attrs = {
"data": input_type(required=True),
}
args_attrs = {"data": input_type(required=True), "context": ContextInput(required=False)}
main_attrs["Arguments"] = type("Arguments", (object,), args_attrs)

return type(name, (base_class,), main_attrs)
Expand All @@ -832,9 +831,7 @@ def generate_graphql_mutation_update(
meta_attrs: dict[str, Any] = {"schema": schema, "name": name, "description": schema.description}
main_attrs["Meta"] = type("Meta", (object,), meta_attrs)

args_attrs = {
"data": input_type(required=True),
}
args_attrs = {"data": input_type(required=True), "context": ContextInput(required=False)}
main_attrs["Arguments"] = type("Arguments", (object,), args_attrs)

return type(name, (base_class,), main_attrs)
Expand All @@ -851,9 +848,7 @@ def generate_graphql_mutation_delete(
meta_attrs = {"schema": schema, "name": name, "description": schema.description}
main_attrs["Meta"] = type("Meta", (object,), meta_attrs)

args_attrs: dict[str, Any] = {
"data": DeleteInput(required=True),
}
args_attrs: dict[str, Any] = {"data": DeleteInput(required=True), "context": ContextInput(required=False)}
main_attrs["Arguments"] = type("Arguments", (object,), args_attrs)

return type(name, (base_class,), main_attrs)
Expand Down
27 changes: 26 additions & 1 deletion backend/infrahub/graphql/mutations/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from infrahub.core.branch import Branch
from infrahub.database import retry_db_transaction
from infrahub.graphql.context import apply_external_context
from infrahub.graphql.types.context import ContextInput
from infrahub.log import get_logger
from infrahub.workflows.catalogue import (
BRANCH_CREATE,
Expand Down Expand Up @@ -44,6 +46,7 @@ class BranchCreateInput(InputObjectType):
class BranchCreate(Mutation):
class Arguments:
data = BranchCreateInput(required=True)
context = ContextInput(required=False)
background_execution = Boolean(required=False, deprecation_reason="Please use `wait_until_completion` instead")
wait_until_completion = Boolean(required=False)

Expand All @@ -58,13 +61,15 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: BranchCreateInput,
context: ContextInput | None = None,
background_execution: bool = False,
wait_until_completion: bool = True,
) -> Self:
graphql_context: GraphqlContext = info.context
task: dict | None = None

model = BranchCreateModel(**data)
await apply_external_context(graphql_context=graphql_context, context_input=context)

if background_execution or not wait_until_completion:
workflow = await graphql_context.active_service.workflow.submit_workflow(
Expand Down Expand Up @@ -96,6 +101,7 @@ class BranchUpdateInput(InputObjectType):
class BranchDelete(Mutation):
class Arguments:
data = BranchNameInput(required=True)
context = ContextInput(required=False)
wait_until_completion = Boolean(required=False)

ok = Boolean()
Expand All @@ -107,10 +113,12 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: BranchNameInput,
context: ContextInput | None = None,
wait_until_completion: bool = True,
) -> Self:
graphql_context: GraphqlContext = info.context
obj = await Branch.get_by_name(db=graphql_context.db, name=str(data.name))
await apply_external_context(graphql_context=graphql_context, context_input=context)

if wait_until_completion:
await graphql_context.active_service.workflow.execute_workflow(
Expand All @@ -127,15 +135,23 @@ async def mutate(
class BranchUpdate(Mutation):
class Arguments:
data = BranchUpdateInput(required=True)
context = ContextInput(required=False)

ok = Boolean()

@classmethod
@retry_db_transaction(name="branch_update")
async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: BranchNameInput) -> Self: # noqa: ARG003
async def mutate(
cls,
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: BranchNameInput,
context: ContextInput | None = None,
) -> Self:
graphql_context: GraphqlContext = info.context

obj = await Branch.get_by_name(db=graphql_context.db, name=data["name"])
await apply_external_context(graphql_context=graphql_context, context_input=context)

to_extract = ["description"]
for field_name in to_extract:
Expand All @@ -151,6 +167,7 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: BranchNameInpu
class BranchRebase(Mutation):
class Arguments:
data = BranchNameInput(required=True)
context = ContextInput(required=False)
wait_until_completion = Boolean(required=False)

ok = Boolean()
Expand All @@ -163,11 +180,13 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: BranchNameInput,
context: ContextInput | None = None,
wait_until_completion: bool = True,
) -> Self:
graphql_context: GraphqlContext = info.context

obj = await Branch.get_by_name(db=graphql_context.db, name=str(data.name))
await apply_external_context(graphql_context=graphql_context, context_input=context)
task: dict | None = None

if wait_until_completion:
Expand All @@ -192,6 +211,7 @@ async def mutate(
class BranchValidate(Mutation):
class Arguments:
data = BranchNameInput(required=True)
context = ContextInput(required=False)
wait_until_completion = Boolean(required=False)

ok = Boolean()
Expand All @@ -205,11 +225,13 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: BranchNameInput,
context: ContextInput | None = None,
wait_until_completion: bool = True,
) -> Self:
graphql_context: GraphqlContext = info.context

obj = await Branch.get_by_name(db=graphql_context.db, name=str(data.name))
await apply_external_context(graphql_context=graphql_context, context_input=context)
task: dict | None = None
ok = True

Expand All @@ -231,6 +253,7 @@ async def mutate(
class BranchMerge(Mutation):
class Arguments:
data = BranchNameInput(required=True)
context = ContextInput(required=False)
wait_until_completion = Boolean(required=False)

ok = Boolean()
Expand All @@ -243,11 +266,13 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: BranchNameInput,
context: ContextInput | None = None,
wait_until_completion: bool = True,
) -> Self:
branch_name = data["name"]
task: dict | None = None
graphql_context: GraphqlContext = info.context
await apply_external_context(graphql_context=graphql_context, context_input=context)

if wait_until_completion:
await graphql_context.active_service.workflow.execute_workflow(
Expand Down
5 changes: 5 additions & 0 deletions backend/infrahub/graphql/mutations/computed_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from infrahub.events import EventMeta
from infrahub.events.node_action import NodeMutatedEvent
from infrahub.exceptions import NodeNotFoundError, ValidationError
from infrahub.graphql.context import apply_external_context
from infrahub.graphql.types.context import ContextInput
from infrahub.log import get_log_data
from infrahub.worker import WORKER_IDENTITY

Expand All @@ -31,6 +33,7 @@ class InfrahubComputedAttributeUpdateInput(InputObjectType):
class UpdateComputedAttribute(Mutation):
class Arguments:
data = InfrahubComputedAttributeUpdateInput(required=True)
context = ContextInput(required=False)

ok = Boolean()

Expand All @@ -41,6 +44,7 @@ async def mutate(
_: dict,
info: GraphQLResolveInfo,
data: InfrahubComputedAttributeUpdateInput,
context: ContextInput | None = None,
) -> UpdateComputedAttribute:
graphql_context: GraphqlContext = info.context
node_schema = registry.schema.get_node_schema(
Expand All @@ -63,6 +67,7 @@ async def mutate(
else PermissionDecision.ALLOW_OTHER.value,
)
)
await apply_external_context(graphql_context=graphql_context, context_input=context)

if not (
target_node := await NodeManager.get_one(
Expand Down
5 changes: 5 additions & 0 deletions backend/infrahub/graphql/mutations/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from infrahub.core.timestamp import Timestamp
from infrahub.dependencies.registry import get_component_registry
from infrahub.exceptions import ValidationError
from infrahub.graphql.context import apply_external_context
from infrahub.graphql.types.context import ContextInput
from infrahub.workflows.catalogue import DIFF_UPDATE

from ..types.task import TaskInfo
Expand All @@ -30,6 +32,7 @@ class DiffUpdateInput(InputObjectType):
class DiffUpdateMutation(Mutation):
class Arguments:
data = DiffUpdateInput(required=True)
context = ContextInput(required=False)
wait_until_completion = Boolean(required=False)

ok = Boolean()
Expand All @@ -41,9 +44,11 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: DiffUpdateInput,
context: ContextInput | None = None,
wait_until_completion: bool = False,
) -> dict[str, bool | dict[str, str]]:
graphql_context: GraphqlContext = info.context
await apply_external_context(graphql_context=graphql_context, context_input=context)

if data.wait_for_completion is True:
wait_until_completion = True
Expand Down
5 changes: 5 additions & 0 deletions backend/infrahub/graphql/mutations/diff_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from infrahub.database import retry_db_transaction
from infrahub.dependencies.registry import get_component_registry
from infrahub.exceptions import ProcessingError
from infrahub.graphql.context import apply_external_context
from infrahub.graphql.enums import ConflictSelection as GraphQlConflictSelection
from infrahub.graphql.types.context import ContextInput

if TYPE_CHECKING:
from graphql import GraphQLResolveInfo
Expand All @@ -29,6 +31,7 @@ class ResolveDiffConflictInput(InputObjectType):
class ResolveDiffConflict(Mutation):
class Arguments:
data = ResolveDiffConflictInput(required=True)
context = ContextInput(required=False)

ok = Boolean()

Expand All @@ -39,8 +42,10 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: ResolveDiffConflictInput,
context: ContextInput | None = None,
) -> ResolveDiffConflict:
graphql_context: GraphqlContext = info.context
await apply_external_context(graphql_context=graphql_context, context_input=context)

component_registry = get_component_registry()
diff_repo = await component_registry.get_component(
Expand Down
6 changes: 5 additions & 1 deletion backend/infrahub/graphql/mutations/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from infrahub.core.manager import NodeManager
from infrahub.generators.models import ProposedChangeGeneratorDefinition, RequestGeneratorDefinitionRun
from infrahub.graphql.context import apply_external_context
from infrahub.graphql.types.context import ContextInput
from infrahub.graphql.types.task import TaskInfo
from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN

Expand All @@ -23,6 +25,7 @@ class GeneratorDefinitionRequestRunInput(InputObjectType):
class GeneratorDefinitionRequestRun(Mutation):
class Arguments:
data = GeneratorDefinitionRequestRunInput(required=True)
context = ContextInput(required=False)
wait_until_completion = Boolean(required=False)

ok = Boolean()
Expand All @@ -34,11 +37,12 @@ async def mutate(
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: GeneratorDefinitionRequestRunInput,
context: ContextInput | None = None,
wait_until_completion: bool = True,
) -> GeneratorDefinitionRequestRun:
graphql_context: GraphqlContext = info.context
db = graphql_context.db

await apply_external_context(graphql_context=graphql_context, context_input=context)
generator_definition = await NodeManager.get_one(
id=str(data.id), db=db, branch=graphql_context.branch, prefetch_relationships=True, raise_on_error=True
)
Expand Down
13 changes: 12 additions & 1 deletion backend/infrahub/graphql/mutations/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from infrahub.dependencies.registry import get_component_registry
from infrahub.events import EventMeta, NodeMutatedEvent
from infrahub.exceptions import ValidationError
from infrahub.graphql.context import apply_external_context
from infrahub.lock import InfrahubMultiLock, build_object_lock_name
from infrahub.log import get_log_data, get_logger
from infrahub.worker import WORKER_IDENTITY
Expand All @@ -40,6 +41,7 @@
from infrahub.core.relationship.model import RelationshipManager
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.database import InfrahubDatabase
from infrahub.graphql.types.context import ContextInput

from ..initialization import GraphqlContext
from .node_getter.interface import MutationNodeGetterInterface
Expand All @@ -66,8 +68,17 @@ class InfrahubMutationOptions(MutationOptions):

class InfrahubMutationMixin:
@classmethod
async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectType, *args: Any, **kwargs): # noqa: ARG003
async def mutate(
cls,
root: dict, # noqa: ARG003
info: GraphQLResolveInfo,
data: InputObjectType,
context: ContextInput | None = None,
*args: Any, # noqa: ARG003
**kwargs,
):
graphql_context: GraphqlContext = info.context
await apply_external_context(graphql_context=graphql_context, context_input=context)

obj = None
mutation = None
Expand Down
Loading

0 comments on commit 84f1c09

Please sign in to comment.