From 068f2660ab207cbbb452114a525d64a4bc962ee8 Mon Sep 17 00:00:00 2001 From: M Aswin Kishore <60577077+mak626@users.noreply.github.com> Date: Mon, 15 Jan 2024 23:30:40 +0530 Subject: [PATCH] feat: schema also validates the correct locations of directives now --- graphene_directives/schema.py | 68 ++++++++++++++++++++++++++--------- pyproject.toml | 2 +- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/graphene_directives/schema.py b/graphene_directives/schema.py index 21aa6d9..753fe43 100644 --- a/graphene_directives/schema.py +++ b/graphene_directives/schema.py @@ -134,7 +134,6 @@ def type_attribute_to_field_name(self, attribute: str) -> str: def _add_argument_decorators( self, entity_name: str, - allowed_locations: list[str], required_directive_field_types: set[DirectiveLocation], args: dict[str, GraphQLArgument], ) -> str: @@ -169,13 +168,19 @@ def _add_argument_decorators( directive, "_graphene_directive" ) - if required_directive_field_types in set(directive.locations): + if ( + not required_directive_field_types.intersection( + set(directive.locations) + ) + and len(required_directive_field_types) != 0 + ): raise DirectiveValidationError( - ", ".join( + "\n".join( [ f"{str(directive)} cannot be used at argument {name} level", - allowed_locations, - f"at {entity_name}", + f"\tat {entity_name}", + f"\tallowed: {directive.locations}", + f"\trequired: {required_directive_field_types}", ] ) ) @@ -214,19 +219,19 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str: entity_type = self.graphql_schema.get_type(entity_name) get_field_graphene_type = self.field_name_to_type_attribute(graphene_type) - required_directive_field_types = set() + required_directive_locations = set() if is_object_type(entity_type) or is_interface_type(entity_type): - required_directive_field_types.union( + required_directive_locations.union( { DirectiveLocation.FIELD_DEFINITION, DirectiveLocation.ARGUMENT_DEFINITION, } ) elif is_enum_type(entity_type): - required_directive_field_types.add(DirectiveLocation.ENUM_VALUE) + required_directive_locations.add(DirectiveLocation.ENUM_VALUE) elif is_input_type(entity_type): - required_directive_field_types.add( + required_directive_locations.add( DirectiveLocation.INPUT_FIELD_DEFINITION ) else: @@ -238,7 +243,6 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str: fields: dict = entity_type.fields str_fields = [] - allowed_locations = [str(t) for t in required_directive_field_types] for field_name, field in fields.items(): if is_enum_type(entity_type): @@ -276,8 +280,7 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str: ) replacement_args = self._add_argument_decorators( entity_name=entity_name, - allowed_locations=allowed_locations, - required_directive_field_types=required_directive_field_types, + required_directive_field_types=required_directive_locations, args=arg_field.args, ) str_field = str_field.replace( @@ -304,16 +307,23 @@ def _add_field_decorators(self, graphene_types: set, string_schema: str) -> str: directive, "_graphene_directive" ) - if required_directive_field_types in set(directive.locations): + if ( + not required_directive_locations.intersection( + set(directive.locations) + ) + and len(required_directive_locations) != 0 + ): raise DirectiveValidationError( - ", ".join( + "\n".join( [ f"{str(directive)} cannot be used at field level", - allowed_locations, - f"at {entity_name}", + f"\tat {entity_name}", + f"\tallowed: {directive.locations}", + f"\trequired: {required_directive_locations}", ] ) ) + for directive_value in directive_values: if ( meta_data.field_validator is not None @@ -380,18 +390,26 @@ def add_non_field_decorators( entity_name = non_field._meta.name # noqa entity_type = self.graphql_schema.get_type(entity_name) + required_directive_locations = set() + if is_scalar_type(entity_type): non_field_pattern = rf"(scalar {entity_name})" + required_directive_locations.add(DirectiveLocation.SCALAR) elif is_union_type(entity_type): non_field_pattern = rf"(union {entity_name} )" + required_directive_locations.add(DirectiveLocation.UNION) elif is_object_type(entity_type): non_field_pattern = rf"(type {entity_name} [^\{{]*)" + required_directive_locations.add(DirectiveLocation.OBJECT) elif is_interface_type(entity_type): non_field_pattern = rf"(interface {entity_name} [^\{{]*)" + required_directive_locations.add(DirectiveLocation.INTERFACE) elif is_enum_type(entity_type): non_field_pattern = rf"(enum {entity_name} [^\{{]*)" + required_directive_locations.add(DirectiveLocation.ENUM) elif is_input_type(entity_type): non_field_pattern = rf"(input {entity_name} [^\{{]*)" + required_directive_locations.add(DirectiveLocation.INPUT_OBJECT) else: continue @@ -404,6 +422,24 @@ def add_non_field_decorators( directive_values = get_non_field_attribute_value( non_field, directive ) + + if ( + not required_directive_locations.intersection( + set(directive.locations) + ) + and len(required_directive_locations) != 0 + ): + raise DirectiveValidationError( + "\n".join( + [ + f"{str(directive)} cannot be used at non field level", + f"\tat {entity_name}", + f"\tallowed: {directive.locations}", + f"\trequired: {required_directive_locations}", + ] + ) + ) + for directive_value in directive_values: if ( meta_data.non_field_validator is not None diff --git a/pyproject.toml b/pyproject.toml index 12d1d1b..6f27bfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphene-directives" -version = "0.4.5" +version = "0.4.6" packages = [{include = "graphene_directives"}] description = "Schema Directives implementation for graphene" authors = ["Strollby "]