Skip to content

Commit

Permalink
feat: schema also validates the correct locations of directives now
Browse files Browse the repository at this point in the history
  • Loading branch information
mak626 committed Jan 15, 2024
1 parent 2ae721b commit 068f266
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
68 changes: 52 additions & 16 deletions graphene_directives/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def type_attribute_to_field_name(self, attribute: str) -> str:
def _add_argument_decorators(
self,
entity_name: str,
allowed_locations: list[str],
required_directive_field_types: set[DirectiveLocation],
args: dict[str, GraphQLArgument],
) -> str:
Expand Down Expand Up @@ -169,13 +168,19 @@ def _add_argument_decorators(
directive, "_graphene_directive"
)

if required_directive_field_types in set(directive.locations):
if (
not required_directive_field_types.intersection(
set(directive.locations)
)
and len(required_directive_field_types) != 0
):
raise DirectiveValidationError(
", ".join(
"\n".join(
[
f"{str(directive)} cannot be used at argument {name} level",
allowed_locations,
f"at {entity_name}",
f"\tat {entity_name}",
f"\tallowed: {directive.locations}",
f"\trequired: {required_directive_field_types}",
]
)
)
Expand Down Expand Up @@ -214,19 +219,19 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str:
entity_type = self.graphql_schema.get_type(entity_name)
get_field_graphene_type = self.field_name_to_type_attribute(graphene_type)

required_directive_field_types = set()
required_directive_locations = set()

if is_object_type(entity_type) or is_interface_type(entity_type):
required_directive_field_types.union(
required_directive_locations.union(
{
DirectiveLocation.FIELD_DEFINITION,
DirectiveLocation.ARGUMENT_DEFINITION,
}
)
elif is_enum_type(entity_type):
required_directive_field_types.add(DirectiveLocation.ENUM_VALUE)
required_directive_locations.add(DirectiveLocation.ENUM_VALUE)
elif is_input_type(entity_type):
required_directive_field_types.add(
required_directive_locations.add(
DirectiveLocation.INPUT_FIELD_DEFINITION
)
else:
Expand All @@ -238,7 +243,6 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str:
fields: dict = entity_type.fields

str_fields = []
allowed_locations = [str(t) for t in required_directive_field_types]

for field_name, field in fields.items():
if is_enum_type(entity_type):
Expand Down Expand Up @@ -276,8 +280,7 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str:
)
replacement_args = self._add_argument_decorators(
entity_name=entity_name,
allowed_locations=allowed_locations,
required_directive_field_types=required_directive_field_types,
required_directive_field_types=required_directive_locations,
args=arg_field.args,
)
str_field = str_field.replace(
Expand All @@ -304,16 +307,23 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str:
directive, "_graphene_directive"
)

if required_directive_field_types in set(directive.locations):
if (
not required_directive_locations.intersection(
set(directive.locations)
)
and len(required_directive_locations) != 0
):
raise DirectiveValidationError(
", ".join(
"\n".join(
[
f"{str(directive)} cannot be used at field level",
allowed_locations,
f"at {entity_name}",
f"\tat {entity_name}",
f"\tallowed: {directive.locations}",
f"\trequired: {required_directive_locations}",
]
)
)

for directive_value in directive_values:
if (
meta_data.field_validator is not None
Expand Down Expand Up @@ -380,18 +390,26 @@ def add_non_field_decorators(
entity_name = non_field._meta.name # noqa
entity_type = self.graphql_schema.get_type(entity_name)

required_directive_locations = set()

if is_scalar_type(entity_type):
non_field_pattern = rf"(scalar {entity_name})"
required_directive_locations.add(DirectiveLocation.SCALAR)
elif is_union_type(entity_type):
non_field_pattern = rf"(union {entity_name} )"
required_directive_locations.add(DirectiveLocation.UNION)
elif is_object_type(entity_type):
non_field_pattern = rf"(type {entity_name} [^\{{]*)"
required_directive_locations.add(DirectiveLocation.OBJECT)
elif is_interface_type(entity_type):
non_field_pattern = rf"(interface {entity_name} [^\{{]*)"
required_directive_locations.add(DirectiveLocation.INTERFACE)
elif is_enum_type(entity_type):
non_field_pattern = rf"(enum {entity_name} [^\{{]*)"
required_directive_locations.add(DirectiveLocation.ENUM)
elif is_input_type(entity_type):
non_field_pattern = rf"(input {entity_name} [^\{{]*)"
required_directive_locations.add(DirectiveLocation.INPUT_OBJECT)
else:
continue

Expand All @@ -404,6 +422,24 @@ def add_non_field_decorators(
directive_values = get_non_field_attribute_value(
non_field, directive
)

if (
not required_directive_locations.intersection(
set(directive.locations)
)
and len(required_directive_locations) != 0
):
raise DirectiveValidationError(
"\n".join(
[
f"{str(directive)} cannot be used at non field level",
f"\tat {entity_name}",
f"\tallowed: {directive.locations}",
f"\trequired: {required_directive_locations}",
]
)
)

for directive_value in directive_values:
if (
meta_data.non_field_validator is not None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphene-directives"
version = "0.4.5"
version = "0.4.6"
packages = [{include = "graphene_directives"}]
description = "Schema Directives implementation for graphene"
authors = ["Strollby <developers@strollby.com>"]
Expand Down

0 comments on commit 068f266

Please sign in to comment.