diff --git a/.vale/styles/spelling-exceptions.txt b/.vale/styles/spelling-exceptions.txt index ab64dd7296..d018438b7b 100644 --- a/.vale/styles/spelling-exceptions.txt +++ b/.vale/styles/spelling-exceptions.txt @@ -54,6 +54,7 @@ hfids hostname httpx human_friendly_id +id_token idempotency include_in_menu Infrahub diff --git a/backend/infrahub/api/oidc.py b/backend/infrahub/api/oidc.py index 719d2cf588..338b9c45ce 100644 --- a/backend/infrahub/api/oidc.py +++ b/backend/infrahub/api/oidc.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from urllib.parse import urljoin +import jwt import ujson from authlib.integrations.httpx_client import AsyncOAuth2Client from fastapi import APIRouter, Depends, Request, Response @@ -138,7 +139,7 @@ async def token( with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span: span.set_attribute("token_request_data", ujson.dumps(token_response.json())) - payload = token_response.json() + payload: dict[str, Any] = token_response.json() headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"} @@ -148,8 +149,10 @@ async def token( userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers) _validate_response(response=userinfo_response) - user_info = userinfo_response.json() - sso_groups = user_info.get("groups", []) + user_info: dict[str, Any] = userinfo_response.json() + sso_groups = user_info.get("groups") or await _get_id_token_groups( + oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id + ) if not sso_groups and config.SETTINGS.security.sso_user_default_group: sso_groups = [config.SETTINGS.security.sso_user_default_group] @@ -185,3 +188,28 @@ def _validate_response(response: httpx.Response) -> None: body=response.json(), ) raise GatewayError(message="Invalid response from Authentication provider") + + +async def _get_id_token_groups( + oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str +) -> list[str]: + id_token = payload.get("id_token") + if not id_token: + return [] + jwks = await service.http.get(url=str(oidc_config.jwks_uri)) + + jwk_client = jwt.PyJWKClient(uri=str(oidc_config.jwks_uri), cache_jwk_set=True) + if jwk_client.jwk_set_cache: + jwk_client.jwk_set_cache.put(jwks.json()) + + signing_key = jwk_client.get_signing_key_from_jwt(id_token) + + decoded_token: dict[str, Any] = jwt.decode( + jwt=id_token, + key=signing_key.key, + algorithms=oidc_config.id_token_signing_alg_values_supported, + audience=client_id, + issuer=str(oidc_config.issuer), + ) + + return decoded_token.get("groups", []) diff --git a/backend/tests/adapters/http.py b/backend/tests/adapters/http.py new file mode 100644 index 0000000000..066f59733b --- /dev/null +++ b/backend/tests/adapters/http.py @@ -0,0 +1,34 @@ +from typing import Any + +import httpx + +from infrahub.services.adapters.http import InfrahubHTTP + + +class MemoryHTTP(InfrahubHTTP): + def __init__(self) -> None: + self._get_response: dict[str, httpx.Response] = {} + self._post_response: dict[str, httpx.Response] = {} + + async def get( + self, + url: str, + headers: dict[str, Any] | None = None, + ) -> httpx.Response: + return self._get_response[url] + + async def post( + self, + url: str, + data: Any | None = None, + json: Any | None = None, + headers: dict[str, Any] | None = None, + verify: bool | None = None, + ) -> httpx.Response: + return self._post_response[url] + + def add_get_response(self, url: str, response: httpx.Response) -> None: + self._get_response[url] = response + + def add_post_response(self, url: str, response: httpx.Response) -> None: + self._post_response[url] = response diff --git a/backend/tests/unit/api/test_oidc.py b/backend/tests/unit/api/test_oidc.py new file mode 100644 index 0000000000..65a4eac409 --- /dev/null +++ b/backend/tests/unit/api/test_oidc.py @@ -0,0 +1,149 @@ +import json +import time +import uuid +from typing import Any + +import httpx +from jwcrypto import jwk, jwt +from pydantic_core import Url + +from infrahub.api.oidc import OIDCDiscoveryConfig, _get_id_token_groups +from infrahub.services import InfrahubServices +from tests.adapters.http import MemoryHTTP + + +async def test_get_id_token_groups_for_oidc() -> None: + memory_http = MemoryHTTP() + service = InfrahubServices(http=memory_http) + client_id = "testing-oicd-1234" + + helper = OIDCTestHelper() + token_response = helper.generate_token_response( + username="testuser", + groups=["operators"], + client_id=client_id, + issuer=str(OIDC_CONFIG.issuer), + ) + + memory_http.add_get_response( + url=str(OIDC_CONFIG.jwks_uri), + response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)), + ) + + groups = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response, + client_id=client_id, + ) + + assert groups == ["operators"] + + +async def test_get_id_token_groups_for_oidc_no_id_token() -> None: + memory_http = MemoryHTTP() + service = InfrahubServices(http=memory_http) + client_id = "testing-oicd-1234" + + helper = OIDCTestHelper() + token_response = helper.generate_token_response( + username="testuser", + groups=["operators"], + client_id=client_id, + issuer=str(OIDC_CONFIG.issuer), + ) + token_response.pop("id_token") + + memory_http.add_get_response( + url=str(OIDC_CONFIG.jwks_uri), + response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)), + ) + + groups = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response, + client_id=client_id, + ) + + assert groups == [] + + +class OIDCTestHelper: + def __init__(self) -> None: + self.key: jwk.JWK = jwk.JWK.generate(kty="RSA", size=2048) + self.kid = str(uuid.uuid4()) + + self.jwks_payload = { + "keys": [ + { + **json.loads(self.key.export_public()), + "kid": self.kid, + } + ] + } + + def generate_token_response(self, username: str, groups: list[str], client_id: str, issuer: str) -> dict[str, Any]: + current_time = int(time.time()) + expiration_time = current_time + 600 + + id_token = jwt.JWT( + header={"alg": "RS256", "kid": self.kid}, + claims={ + "sub": str(uuid.uuid4()), + "aud": client_id, + "iss": issuer, + "exp": expiration_time, + "iat": current_time, + "auth_time": current_time, + "name": username, + "groups": groups, + }, + ) + id_token.make_signed_token(self.key) + + return { + "access_token": id_token.serialize(), + "expires_in": 600, + "refresh_expires_in": 1800, + "id_token": id_token.serialize(), + "token_type": "Bearer", + "scope": "openid profile email", + } + + +OIDC_CONFIG = OIDCDiscoveryConfig( + issuer=Url("https://oidc.example.com/realms/infrahub-oidc"), + authorization_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/auth"), + token_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token"), + userinfo_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/userinfo"), + jwks_uri=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/certs"), + revocation_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/revoke"), + registration_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/clients-registrations/openid-connect"), + introspection_endpoint=Url( + "https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token/introspect" + ), + end_session_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/logout"), + frontchannel_logout_supported=True, + frontchannel_logout_session_supported=True, + grant_types_supported=["authorization_code", "implicit"], + response_types_supported=["code", "id_token", "token"], + subject_types_supported=["public"], + id_token_signing_alg_values_supported=["RS256"], + scopes_supported=["openid", "profile", "email"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + claims_supported=["sub", "name", "email"], + acr_values_supported=["1"], + request_parameter_supported=True, + request_uri_parameter_supported=True, + require_request_uri_registration=True, + code_challenge_methods_supported=["S256"], + tls_client_certificate_bound_access_tokens=True, + mtls_endpoint_aliases={ + "token_endpoint": Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token"), + "revocation_endpoint": Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/revoke"), + "introspection_endpoint": Url( + "https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token/introspect" + ), + }, +) diff --git a/changelog/5464.added.md b/changelog/5464.added.md new file mode 100644 index 0000000000..43bd088908 --- /dev/null +++ b/changelog/5464.added.md @@ -0,0 +1 @@ +Allow OIDC providers to fall back to id_token for group membership reports if they are not provided within the userinfo URL. This allows for group support using Azure. diff --git a/poetry.lock b/poetry.lock index 5c580b06d4..cd7239f59b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aio-pika" @@ -2127,6 +2127,21 @@ files = [ [package.dependencies] referencing = ">=0.31.0" +[[package]] +name = "jwcrypto" +version = "1.5.6" +description = "Implementation of JOSE Web standards" +optional = false +python-versions = ">= 3.8" +files = [ + {file = "jwcrypto-1.5.6-py3-none-any.whl", hash = "sha256:150d2b0ebbdb8f40b77f543fb44ffd2baeff48788be71f67f03566692fd55789"}, + {file = "jwcrypto-1.5.6.tar.gz", hash = "sha256:771a87762a0c081ae6166958a954f80848820b2ab066937dc8b8379d65b1b039"}, +] + +[package.dependencies] +cryptography = ">=3.4" +typing-extensions = ">=4.5.0" + [[package]] name = "kiwisolver" version = "1.4.7" @@ -5909,4 +5924,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, < 3.13" -content-hash = "aa89046e03c3dc2765fd0136ff213b1619c7e5099009a301ced67d4afe9bf494" +content-hash = "6f3bafc0265af0711b0459b9cfdd8457de22d58fe18e62a42293069821f86731" diff --git a/pyproject.toml b/pyproject.toml index 697985620c..057342e847 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ ruamel-yaml = "^0.18.6" pytest-httpx = ">=0.30" docker = "^7.1.0" psutil = "^6.1.0" +jwcrypto = "1.5.6" [tool.poetry.scripts] infrahub = "infrahub.cli:app"