diff --git a/drf_jsonapi/mixins.py b/drf_jsonapi/mixins.py index 3312f01..a1a0bf4 100644 --- a/drf_jsonapi/mixins.py +++ b/drf_jsonapi/mixins.py @@ -69,7 +69,7 @@ def list(self, request): self.document.instance.data = serializer.data self.document.instance.included = serializer.included - return Response(self.document.data) + return Response(self.document.data, context={"collection": page}) class ProcessRelationshipsMixin: @@ -137,7 +137,11 @@ def create(self, request): serializer.instance = resource self.document.instance.data = serializer.data - return Response(self.document.data, status=status.HTTP_201_CREATED) + return Response( + self.document.data, + status=status.HTTP_201_CREATED, + context={"resource": resource}, + ) class RetrieveMixin: @@ -162,7 +166,7 @@ def retrieve(self, request, pk): self.document.instance.data = serializer.data self.document.instance.included = serializer.included - return Response(self.document.data) + return Response(self.document.data, context={"resource": resource}) class PartialUpdateMixin(ProcessRelationshipsMixin): @@ -199,7 +203,7 @@ def partial_update(self, request, *args, **kwargs): self.document.instance.data = serializer.data - return Response(self.document.data) + return Response(self.document.data, context={"resource": resource}) class DestroyMixin: @@ -208,10 +212,11 @@ class DestroyMixin: """ def destroy(self, request, pk): - resource = self.get_resource(request, pk) - resource.delete() + resource = self.get_resource(request, pk).delete() - return Response(status=status.HTTP_204_NO_CONTENT) + return Response( + status=status.HTTP_204_NO_CONTENT, context={"resource": resource} + ) class RelationshipRetrieveMixin: @@ -238,7 +243,7 @@ def relationship_retrieve(self, request, pk, relationship): serializer = resource_identifier(serializer_class)(related, many=handler.many) self.document.instance.data = serializer.data - return Response(self.document.data) + return Response(self.document.data, context={"resource": resource}) class RelationshipCreateMixin: @@ -268,7 +273,9 @@ def relationship_create(self, request, pk, relationship): handler.add_related(resource, related, request) - return Response(status=status.HTTP_204_NO_CONTENT) + return Response( + status=status.HTTP_204_NO_CONTENT, context={"resource": resource} + ) class RelationshipUpdateMixin: @@ -295,7 +302,9 @@ def relationship_update(self, request, pk, relationship): if not handler.many: resource.save() - return Response(status=status.HTTP_204_NO_CONTENT) + return Response( + status=status.HTTP_204_NO_CONTENT, context={"resource": resource} + ) class RelationshipDestroyMixin: @@ -325,4 +334,6 @@ def relationship_destroy(self, request, pk, relationship): handler.remove_related(resource, related, request) - return Response(status=status.HTTP_204_NO_CONTENT) + return Response( + status=status.HTTP_204_NO_CONTENT, context={"resource": resource} + ) diff --git a/drf_jsonapi/response.py b/drf_jsonapi/response.py index 45c682f..b476a7f 100644 --- a/drf_jsonapi/response.py +++ b/drf_jsonapi/response.py @@ -1,5 +1,21 @@ +from rest_framework import status from rest_framework.response import Response as BaseResponse class Response(BaseResponse): - pass + def __init__(self, *args, **kwargs): + context = kwargs.pop("context", {}) + super(Response, self).__init__(*args, **kwargs) + self.context = context + + @property + def ok(self): + """Returns True if :attr:`status_code` is less than 400, False if not.""" + if self.status_code >= status.HTTP_400_BAD_REQUEST: + return False + return True + + @property + def created(self): + """Returns True if :attr:`status_code` is 201, False if not.""" + return self.status_code == status.HTTP_201_CREATED diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f5f7859..f114e0e 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -88,6 +88,7 @@ def test_retrieve_mixin(self): view = TestViewSet.as_view({"get": "retrieve"}) response = view(request, pk=1) self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.context["resource"], TestModel) def test_create_mixin(self): factory = APIRequestFactory() @@ -105,6 +106,7 @@ def test_create_mixin(self): response = view(request) response.render() self.assertEqual(response.status_code, 201) + self.assertIsInstance(response.context["resource"], TestModel) def test_create_mixin_with_relationships_to_many(self): factory = APIRequestFactory() @@ -272,6 +274,7 @@ def test_partial_update_mixin(self): response = view(request, pk=1) response.render() self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.context["resource"], TestModel) def test_partial_update_mixin_with_relationships(self): factory = APIRequestFactory() @@ -370,6 +373,7 @@ def test_destroy_mixin(self): view = TestViewSet.as_view({"delete": "destroy"}) response = view(request, pk=1) self.assertEqual(response.status_code, 204) + self.assertIsInstance(response.context["resource"], TestModel) class RelationshipCreateMixinTestCase(TestCase): @@ -378,21 +382,22 @@ class TestManyView(mixins.RelationshipCreateMixin, ViewSet): relationship = "related_things" def get_resource(self, request, pk): - return True + return TestModel() class TestOneView(mixins.RelationshipCreateMixin, ViewSet): serializer_class = TestModelSerializer relationship = "related_thing" def get_resource(self, request, pk): - return True + return TestModel() def test_create_valid(self): factory = APIRequestFactory() request = factory.post("/test_resources", {}, format="json") request.data = {"data": [{"type": "test_resource", "id": "test_id"}]} view = self.TestManyView() - view.relationship_create(request, 1, "related_things") + response = view.relationship_create(request, 1, "related_things") + self.assertIsInstance(response.context["resource"], TestModel) def test_create_to_many_invalid(self): factory = APIRequestFactory() @@ -439,7 +444,8 @@ def test_patch_one_valid(self): request = factory.patch("/test_resources", {}, format="json") request.data = {"data": {"type": "test_resource", "id": "test_id"}} view = self.TestOneView() - view.relationship_update(request, 1, "related_thing") + response = view.relationship_update(request, 1, "related_thing") + self.assertIsInstance(response.context["resource"], TestModel) def test_patch_one_invalid(self): factory = APIRequestFactory() @@ -456,14 +462,14 @@ class TestManyView(mixins.RelationshipDestroyMixin, ViewSet): relationship = "related_things" def get_resource(self, request, pk): - return True + return TestModel() class TestOneView(mixins.RelationshipDestroyMixin, ViewSet): serializer_class = TestModelSerializer relationship = "related_thing" def get_resource(self, request, pk): - return True + return TestModel() def test_destroy_to_one_invalid(self): factory = APIRequestFactory() @@ -483,7 +489,8 @@ def test_destroy_to_many_valid(self): ) view = self.TestManyView() request.data = {"data": {"type": "test_resource", "id": "5"}} - view.relationship_destroy(request, 1, "related_things") + response = view.relationship_destroy(request, 1, "related_things") + self.assertIsInstance(response.context["resource"], TestModel) def test_destroy_to_many_valid_iterator(self): factory = APIRequestFactory() @@ -523,6 +530,7 @@ def test_list_mixin_one(self): view = self.TestOneView.as_view({"get": "relationship_retrieve"}) response = view(request, 1, "related_thing") self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.context["resource"], TestModel) def test_list_mixin_many(self): factory = APIRequestFactory() diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 0000000..b6d56b2 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,27 @@ +from django.test import TestCase +from rest_framework import status + +from drf_jsonapi.response import Response + + +class ResponseTestCase(TestCase): + def test_init_with_context(self): + context = {"key": "value"} + response = Response(context=context) + self.assertDictEqual(response.context, context) + + def test_ok_success(self): + response = Response(status=status.HTTP_200_OK) + self.assertTrue(response.ok) + + def test_ok_error(self): + response = Response(status=status.HTTP_400_BAD_REQUEST) + self.assertFalse(response.ok) + + def test_created_true(self): + response = Response(status=status.HTTP_201_CREATED) + self.assertTrue(response.created) + + def test_created_false(self): + response = Response(status=status.HTTP_200_OK) + self.assertFalse(response.created)