diff --git a/neon_hana/app/routers/util.py b/neon_hana/app/routers/util.py index a4bf674..e5bc7cd 100644 --- a/neon_hana/app/routers/util.py +++ b/neon_hana/app/routers/util.py @@ -34,11 +34,13 @@ @util_route.get("/client_ip", response_class=PlainTextResponse) async def api_client_ip(request: Request) -> str: + # Validation will fail, but this increments the rate-limiting client_manager.validate_auth("", request.client.host) return request.client.host @util_route.get("/headers") async def api_headers(request: Request): + # Validation will fail, but this increments the rate-limiting client_manager.validate_auth("", request.client.host) return request.headers diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 286f431..91eaa72 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -58,10 +58,9 @@ def __init__(self, config: dict, mq_connector: Optional[MQServiceManager] = None): self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage()) - # TODO: Is `authorized_clients` useful to track? # Keep a dict of `client_id` to auth tokens that have authenticated to # this instance - self.authorized_clients: Dict[str, HanaToken] = dict() + self._authorized_clients: Dict[str, AuthenticationResponse] = dict() self._access_token_lifetime = config.get("access_token_ttl", 3600 * 24) self._refresh_token_lifetime = config.get("refresh_token_ttl", 3600 * 24 * 90) @@ -79,6 +78,16 @@ def __init__(self, config: dict, self._stream_check_lock = Lock() self._mq_connector = mq_connector + @property + def authorized_clients(self) -> Dict[str, AuthenticationResponse]: + """ + Dict of `client_id` to `AuthenticationResponse` objects for clients + known by this instance. NOTE: Refresh tokens are not reliably stored + here and should never be retrievable after generation for security. + """ + # TODO: Is `authorized_clients` useful to track? + return self._authorized_clients + def _create_tokens(self, user_id: str, client_id: str, @@ -92,9 +101,9 @@ def _create_tokens(self, expiration_timestamp = creation_timestamp + self._access_token_lifetime refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST, - diana=AccessRoles.GUEST, - node=AccessRoles.GUEST, - llm=AccessRoles.GUEST) + diana=AccessRoles.GUEST, + node=AccessRoles.GUEST, + llm=AccessRoles.GUEST) token_name = token_name or kwargs.get("name") or \ datetime.fromtimestamp(creation_timestamp).isoformat() access_token_data = HanaToken(iss=self._jwt_issuer, @@ -174,9 +183,9 @@ def check_auth_request(self, client_id: str, username: str, @param origin_ip: Origin IP address of request @return: response tokens, permissions, and other metadata """ - # if client_id in self.authorized_clients: - # print(f"Using cached client: {self.authorized_clients[client_id]}") - # return self.authorized_clients[client_id] + if client_id in self.authorized_clients: + print(f"Using cached client: {self.authorized_clients[client_id]}") + return self.authorized_clients[client_id] ratelimit_id = f"auth{origin_ip}" if not self.rate_limiter.get_all_buckets(ratelimit_id): @@ -208,13 +217,15 @@ def check_auth_request(self, client_id: str, username: str, "token_name": token_name, "last_refresh_timestamp": create_time} access, refresh, config = self._create_tokens(**encode_data) - self.authorized_clients[client_id] = config + + auth_response = AuthenticationResponse(username=user.username, + client_id=client_id, + access_token=access, + refresh_token=refresh, + expiration=config.refresh_expiration_timestamp) + self.authorized_clients[client_id] = auth_response self._add_token_to_userdb(user, config) - return AuthenticationResponse(username=user.username, - client_id=client_id, - access_token=access, - refresh_token=refresh, - expiration=config.refresh_expiration_timestamp) + return auth_response def check_refresh_request(self, access_token: str, refresh_token: str, client_id: str) -> AuthenticationResponse: @@ -263,11 +274,14 @@ def check_refresh_request(self, access_token: str, refresh_token: str, else: username = token_data.sub access, refresh, config = self._create_tokens(**encode_data) - return AuthenticationResponse(username=username, + + auth_response = AuthenticationResponse(username=username, client_id=client_id, access_token=access, refresh_token=refresh, expiration=config.refresh_expiration_timestamp) + self._authorized_clients[client_id] = auth_response + return auth_response def _add_token_to_userdb(self, user: User, new_token: TokenConfig): if self._mq_connector is None: @@ -310,7 +324,9 @@ def validate_auth(self, token: str, origin_ip: str) -> bool: if auth.exp < time(): self.authorized_clients.pop(auth.client_id, None) return False - self.authorized_clients[auth.client_id] = auth + self.authorized_clients[auth.client_id] = AuthenticationResponse( + username=auth.sub, client_id=auth.client_id, access_token=token, + refresh_token="", expiration=auth.exp) return True except DecodeError: # Invalid token supplied diff --git a/tests/test_auth.py b/tests/test_auth.py index f1b5ed3..0d46b3c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -47,16 +47,16 @@ def test_check_auth_request(self): # Check simple auth auth_resp_1 = self.client_manager.check_auth_request(**request_1) - # self.assertEqual(self.client_manager.authorized_clients[client_1], - # auth_resp_1.access_token) + self.assertEqual(self.client_manager.authorized_clients[client_1], + auth_resp_1) self.assertEqual(auth_resp_1.username, 'guest') self.assertEqual(auth_resp_1.client_id, client_1) # Check auth from different client auth_resp_2 = self.client_manager.check_auth_request(**request_2) self.assertNotEquals(auth_resp_1, auth_resp_2) - # self.assertEqual(self.client_manager.authorized_clients[client_2], - # auth_resp_2.access_token) + self.assertEqual(self.client_manager.authorized_clients[client_2], + auth_resp_2) self.assertEqual(auth_resp_2.username, 'guest') self.assertEqual(auth_resp_2.client_id, client_2)