diff --git a/examples/simple.py b/examples/simple.py index 8cc9f62..722aee3 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -1,8 +1,9 @@ # mypy: disable-error-code="no-any-return" # flake8: noqa: A003 -from typing import List, Any, Dict -from fastapi import FastAPI, APIRouter +from typing import Any, Callable, Coroutine, Dict, List +from fastapi import APIRouter, FastAPI, Request, Response +from fastapi.routing import APIRoute from pydantic import BaseModel from fastapi_versionizer.versionizer import Versionizer, api_version @@ -36,6 +37,25 @@ def __init__(self) -> None: self.items: Dict[int, Any] = {} +class CounterDb: + def __init__(self) -> None: + self.counter = 0 + + def increment(self) -> None: + self.counter += 1 + +counter_db = CounterDb() + +class CounterRoute(APIRoute): + def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + counter_db.increment() + return await original_route_handler(request) + + return custom_route_handler + db = DB() app = FastAPI( title='test', @@ -47,7 +67,8 @@ def __init__(self) -> None: ) users_router = APIRouter( prefix='/users', - tags=['Users'] + tags=['Users'], + route_class=CounterRoute, ) items_router = APIRouter( prefix='/items', diff --git a/fastapi_versionizer/versionizer.py b/fastapi_versionizer/versionizer.py index 639b633..b93f524 100644 --- a/fastapi_versionizer/versionizer.py +++ b/fastapi_versionizer/versionizer.py @@ -324,6 +324,8 @@ def _add_route_to_router( version: Tuple[int, int] ) -> None: kwargs = dict(route.__dict__) + if route.__class__ != APIRoute: + kwargs['route_class_override'] = route.__class__ deprecated_in_version = getattr(route.endpoint, '_deprecate_in_version', None) if deprecated_in_version is not None: diff --git a/tests/test_simple.py b/tests/test_simple.py index 174077f..4837a5c 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient from unittest import TestCase -from examples.simple import app, versions +from examples.simple import app, counter_db, versions class TestSimpleExample(TestCase): @@ -30,6 +30,8 @@ def test_simple_example(self) -> None: self.assertEqual(404, test_client.get('/v2/versions').status_code) self.assertEqual(404, test_client.get('/latest/versions').status_code) + self.assertEqual(0, counter_db.counter) + # versions route self.assertDictEqual( { @@ -129,6 +131,8 @@ def test_simple_example(self) -> None: test_client.get('/latest/users/3').json() ) + self.assertEqual(9, counter_db.counter) + # docs self.assertEqual(200, test_client.get('/swagger').status_code) self.assertEqual(200, test_client.get('/v1/swagger').status_code)