Skip to content

Commit

Permalink
Allow OIDC fallback to id_token for group support
Browse files Browse the repository at this point in the history
Fixes #5464
  • Loading branch information
ogenstad committed Jan 21, 2025
1 parent 4929486 commit cdcf22d
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 6 deletions.
1 change: 1 addition & 0 deletions .vale/styles/spelling-exceptions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ hfids
hostname
httpx
human_friendly_id
id_token
idempotency
include_in_menu
Infrahub
Expand Down
36 changes: 32 additions & 4 deletions backend/infrahub/api/oidc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

import jwt
import ujson
from authlib.integrations.httpx_client import AsyncOAuth2Client
from fastapi import APIRouter, Depends, Request, Response
Expand Down Expand Up @@ -138,7 +139,7 @@ async def token(

with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
payload = token_response.json()
payload: dict[str, Any] = token_response.json()

headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"}

Expand All @@ -148,8 +149,10 @@ async def token(
userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers)

_validate_response(response=userinfo_response)
user_info = userinfo_response.json()
sso_groups = user_info.get("groups", [])
user_info: dict[str, Any] = userinfo_response.json()
sso_groups = user_info.get("groups") or await _get_id_token_groups(
oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id
)

if not sso_groups and config.SETTINGS.security.sso_user_default_group:
sso_groups = [config.SETTINGS.security.sso_user_default_group]
Expand Down Expand Up @@ -185,3 +188,28 @@ def _validate_response(response: httpx.Response) -> None:
body=response.json(),
)
raise GatewayError(message="Invalid response from Authentication provider")


async def _get_id_token_groups(
oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
) -> list[str]:
id_token = payload.get("id_token")
if not id_token:
return []
jwks = await service.http.get(url=str(oidc_config.jwks_uri))

jwk_client = jwt.PyJWKClient(uri=str(oidc_config.jwks_uri), cache_jwk_set=True)
if jwk_client.jwk_set_cache:
jwk_client.jwk_set_cache.put(jwks.json())

signing_key = jwk_client.get_signing_key_from_jwt(id_token)

decoded_token: dict[str, Any] = jwt.decode(
jwt=id_token,
key=signing_key.key,
algorithms=oidc_config.id_token_signing_alg_values_supported,
audience=client_id,
issuer=str(oidc_config.issuer),
)

return decoded_token.get("groups", [])
34 changes: 34 additions & 0 deletions backend/tests/adapters/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any

import httpx

from infrahub.services.adapters.http import InfrahubHTTP


class MemoryHTTP(InfrahubHTTP):
def __init__(self) -> None:
self._get_response: dict[str, httpx.Response] = {}
self._post_response: dict[str, httpx.Response] = {}

async def get(
self,
url: str,
headers: dict[str, Any] | None = None,
) -> httpx.Response:
return self._get_response[url]

async def post(
self,
url: str,
data: Any | None = None,
json: Any | None = None,
headers: dict[str, Any] | None = None,
verify: bool | None = None,
) -> httpx.Response:
return self._post_response[url]

def add_get_response(self, url: str, response: httpx.Response) -> None:
self._get_response[url] = response

def add_post_response(self, url: str, response: httpx.Response) -> None:
self._post_response[url] = response
149 changes: 149 additions & 0 deletions backend/tests/unit/api/test_oidc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import json
import time
import uuid
from typing import Any

import httpx
from jwcrypto import jwk, jwt
from pydantic_core import Url

from infrahub.api.oidc import OIDCDiscoveryConfig, _get_id_token_groups
from infrahub.services import InfrahubServices
from tests.adapters.http import MemoryHTTP


async def test_get_id_token_groups_for_oidc() -> None:
memory_http = MemoryHTTP()
service = InfrahubServices(http=memory_http)
client_id = "testing-oicd-1234"

helper = OIDCTestHelper()
token_response = helper.generate_token_response(
username="testuser",
groups=["operators"],
client_id=client_id,
issuer=str(OIDC_CONFIG.issuer),
)

memory_http.add_get_response(
url=str(OIDC_CONFIG.jwks_uri),
response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)),
)

groups = await _get_id_token_groups(
oidc_config=OIDC_CONFIG,
service=service,
payload=token_response,
client_id=client_id,
)

assert groups == ["operators"]


async def test_get_id_token_groups_for_oidc_no_id_token() -> None:
memory_http = MemoryHTTP()
service = InfrahubServices(http=memory_http)
client_id = "testing-oicd-1234"

helper = OIDCTestHelper()
token_response = helper.generate_token_response(
username="testuser",
groups=["operators"],
client_id=client_id,
issuer=str(OIDC_CONFIG.issuer),
)
token_response.pop("id_token")

memory_http.add_get_response(
url=str(OIDC_CONFIG.jwks_uri),
response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)),
)

groups = await _get_id_token_groups(
oidc_config=OIDC_CONFIG,
service=service,
payload=token_response,
client_id=client_id,
)

assert groups == []


class OIDCTestHelper:
def __init__(self) -> None:
self.key: jwk.JWK = jwk.JWK.generate(kty="RSA", size=2048)
self.kid = str(uuid.uuid4())

self.jwks_payload = {
"keys": [
{
**json.loads(self.key.export_public()),
"kid": self.kid,
}
]
}

def generate_token_response(self, username: str, groups: list[str], client_id: str, issuer: str) -> dict[str, Any]:
current_time = int(time.time())
expiration_time = current_time + 600

id_token = jwt.JWT(
header={"alg": "RS256", "kid": self.kid},
claims={
"sub": str(uuid.uuid4()),
"aud": client_id,
"iss": issuer,
"exp": expiration_time,
"iat": current_time,
"auth_time": current_time,
"name": username,
"groups": groups,
},
)
id_token.make_signed_token(self.key)

return {
"access_token": id_token.serialize(),
"expires_in": 600,
"refresh_expires_in": 1800,
"id_token": id_token.serialize(),
"token_type": "Bearer",
"scope": "openid profile email",
}


OIDC_CONFIG = OIDCDiscoveryConfig(
issuer=Url("https://oidc.example.com/realms/infrahub-oidc"),
authorization_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/auth"),
token_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token"),
userinfo_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/userinfo"),
jwks_uri=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/certs"),
revocation_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/revoke"),
registration_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/clients-registrations/openid-connect"),
introspection_endpoint=Url(
"https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token/introspect"
),
end_session_endpoint=Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/logout"),
frontchannel_logout_supported=True,
frontchannel_logout_session_supported=True,
grant_types_supported=["authorization_code", "implicit"],
response_types_supported=["code", "id_token", "token"],
subject_types_supported=["public"],
id_token_signing_alg_values_supported=["RS256"],
scopes_supported=["openid", "profile", "email"],
token_endpoint_auth_methods_supported=["client_secret_basic"],
claims_supported=["sub", "name", "email"],
acr_values_supported=["1"],
request_parameter_supported=True,
request_uri_parameter_supported=True,
require_request_uri_registration=True,
code_challenge_methods_supported=["S256"],
tls_client_certificate_bound_access_tokens=True,
mtls_endpoint_aliases={
"token_endpoint": Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token"),
"revocation_endpoint": Url("https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/revoke"),
"introspection_endpoint": Url(
"https://oidc.example.com/realms/infrahub-oidc/protocol/openid-connect/token/introspect"
),
},
)
1 change: 1 addition & 0 deletions changelog/5464.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow OIDC providers to fall back to id_token for group membership reports if they are not provided within the userinfo URL. This allows for group support using Azure.
19 changes: 17 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ ruamel-yaml = "^0.18.6"
pytest-httpx = ">=0.30"
docker = "^7.1.0"
psutil = "^6.1.0"
jwcrypto = "1.5.6"

[tool.poetry.scripts]
infrahub = "infrahub.cli:app"
Expand Down

0 comments on commit cdcf22d

Please sign in to comment.