"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""

import json
import time
from contextlib import asynccontextmanager
from unittest.mock import patch
from uuid import uuid4

import respx
from httpx import Response
from jose import jwt
from sqlmodel import func, select

from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
from wiregui.config import get_settings
from wiregui.models.configuration import Configuration
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.utils.time import utcnow


# --- Helpers ---


def _patch_async_session(monkeypatch, session) -> None:
    """Make ``wiregui.auth.oidc.async_session`` yield the test's ``session``.

    Code in ``wiregui.auth.oidc`` that opens a session via ``async_session()``
    will then read the rows the calling test inserted into ``session``.
    """

    @asynccontextmanager
    async def _session_override():
        yield session

    monkeypatch.setattr("wiregui.auth.oidc.async_session", _session_override)


async def _setup_oidc_config(session) -> Configuration:
    """Insert and commit a Configuration holding a single test OIDC provider.

    The provider has id ``"test-idp"`` and ``auto_create_users`` enabled.
    Returns the committed Configuration row.
    """
    config = Configuration(
        openid_connect_providers=[
            {
                "id": "test-idp",
                "label": "Test IdP",
                "scope": "openid email profile",
                "response_type": "code",
                "client_id": "test-client-id",
                "client_secret": "test-client-secret",
                "discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
                "auto_create_users": True,
            }
        ],
    )
    session.add(config)
    await session.commit()
    return config


def _mock_discovery() -> dict:
    """Return a fake OIDC discovery document for ``https://idp.example.com``."""
    return {
        "issuer": "https://idp.example.com",
        "authorization_endpoint": "https://idp.example.com/authorize",
        "token_endpoint": "https://idp.example.com/token",
        "userinfo_endpoint": "https://idp.example.com/userinfo",
        "jwks_uri": "https://idp.example.com/.well-known/jwks.json",
    }


def _mock_token_response(email: str = "oidc-user@example.com") -> dict:
    """Return a fake token-endpoint response whose ID token carries ``email``.

    The ID token is valid for one hour from now and has subject
    ``"oidc-subject-123"`` and audience ``"test-client-id"``.
    """
    now = int(time.time())
    id_token_payload = {
        "iss": "https://idp.example.com",
        "sub": "oidc-subject-123",
        "aud": "test-client-id",
        "email": email,
        "name": "OIDC User",
        "iat": now,
        "exp": now + 3600,
        "nonce": "test-nonce",
    }
    # Sign with a simple shared secret (a real IdP would sign with RSA/EC keys).
    id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
    return {
        "access_token": "mock-access-token",
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": "mock-refresh-token",
        "id_token": id_token,
    }


# --- Provider config loading ---


async def test_load_providers_from_config(session, monkeypatch):
    """Providers should be loaded from the Configuration table."""
    _patch_async_session(monkeypatch, session)
    await _setup_oidc_config(session)

    providers = await load_providers()
    assert len(providers) == 1
    assert providers[0]["id"] == "test-idp"
    assert providers[0]["client_id"] == "test-client-id"


async def test_load_providers_empty_when_no_config(session, monkeypatch):
    """Should return empty list when no Configuration exists."""
    _patch_async_session(monkeypatch, session)

    providers = await load_providers()
    assert providers == []


async def test_get_provider_config_by_id(session, monkeypatch):
    """Should find a specific provider by ID, and return None for unknown IDs."""
    _patch_async_session(monkeypatch, session)
    await _setup_oidc_config(session)

    config = await get_provider_config("test-idp")
    assert config is not None
    assert config["label"] == "Test IdP"

    config_missing = await get_provider_config("nonexistent")
    assert config_missing is None


# --- OIDC connection storage ---


async def test_oidc_connection_created_on_login(session):
    """Simulates what the callback route does: create user + OIDC connection."""
    user = User(email="oidc-new@example.com", role="unprivileged")
    session.add(user)
    await session.flush()

    token_data = _mock_token_response("oidc-new@example.com")
    conn = OIDCConnection(
        provider="test-idp",
        refresh_token=token_data["refresh_token"],
        refresh_response=token_data,
        user_id=user.id,
    )
    session.add(conn)
    await session.flush()

    # Verify it was stored.
    fetched = (
        await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))
    ).scalar_one()
    assert fetched.provider == "test-idp"
    assert fetched.refresh_token == "mock-refresh-token"
    assert fetched.refresh_response["access_token"] == "mock-access-token"


async def test_oidc_connection_updated_on_re_login(session):
    """Re-login should update the existing OIDC connection, not create a duplicate."""
    user = User(email="oidc-relogin@example.com")
    session.add(user)
    await session.flush()

    # First login.
    conn = OIDCConnection(
        provider="test-idp",
        refresh_token="old-refresh-token",
        user_id=user.id,
    )
    session.add(conn)
    await session.flush()

    # Re-login — update the existing connection (as the callback route does).
    existing = (
        await session.execute(
            select(OIDCConnection).where(
                OIDCConnection.user_id == user.id,
                OIDCConnection.provider == "test-idp",
            )
        )
    ).scalar_one()
    existing.refresh_token = "new-refresh-token"
    existing.refreshed_at = utcnow()
    session.add(existing)
    await session.flush()

    # Should still be exactly one connection for this user.
    count = (
        await session.execute(
            select(func.count())
            .select_from(OIDCConnection)
            .where(OIDCConnection.user_id == user.id)
        )
    ).scalar_one()
    assert count == 1

    fetched = (
        await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))
    ).scalar_one()
    assert fetched.refresh_token == "new-refresh-token"


async def test_oidc_auto_create_user(session):
    """When auto_create_users is True, a new user should be created from OIDC email."""
    email = "auto-created@example.com"

    # Verify the user doesn't exist yet.
    existing = (
        await session.execute(select(User).where(User.email == email))
    ).scalar_one_or_none()
    assert existing is None

    # Simulate what the callback does with auto_create.
    user = User(email=email, role="unprivileged")
    session.add(user)
    await session.flush()

    user.last_signed_in_at = utcnow()
    user.last_signed_in_method = "oidc:test-idp"
    session.add(user)
    await session.flush()

    created = (await session.execute(select(User).where(User.email == email))).scalar_one()
    assert created.role == "unprivileged"
    assert created.last_signed_in_method == "oidc:test-idp"


async def test_oidc_disabled_user_rejected(session):
    """Precondition check: a disabled user is persisted with ``disabled_at`` set.

    NOTE(review): this does not exercise the callback route itself; it only
    verifies the state the route is expected to check before creating a session.
    """
    user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
    session.add(user)
    await session.flush()

    # The callback route checks disabled_at before creating a session
    # and would redirect to /login.
    assert user.disabled_at is not None


async def test_oidc_user_without_auto_create_rejected(session):
    """Precondition check: no user exists yet for an email seen only via OIDC.

    NOTE(review): this does not exercise the callback route itself. With
    ``auto_create_users`` False and no existing user, the route is expected to
    redirect to /login; only the "no existing user" precondition is verified here.
    """
    email = "no-auto-create@example.com"
    existing = (
        await session.execute(select(User).where(User.email == email))
    ).scalar_one_or_none()
    assert existing is None


# --- OIDC refresh token flow ---


async def test_oidc_refresh_stores_new_token(session):
    """Simulates a successful token refresh updating the connection."""
    user = User(email="oidc-refresh-test@example.com")
    session.add(user)
    await session.flush()

    conn = OIDCConnection(
        provider="test-idp",
        refresh_token="old-refresh",
        user_id=user.id,
    )
    session.add(conn)
    await session.flush()

    # Simulate the refresh result; keep the old refresh token if the IdP omits one.
    new_token = {
        "access_token": "new-access",
        "refresh_token": "new-refresh",
        "expires_in": 3600,
    }
    conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
    conn.refresh_response = new_token
    conn.refreshed_at = utcnow()
    session.add(conn)
    await session.flush()

    fetched = await session.get(OIDCConnection, conn.id)
    assert fetched.refresh_token == "new-refresh"
    assert fetched.refresh_response["access_token"] == "new-access"
    assert fetched.refreshed_at is not None


async def test_oidc_multiple_providers_per_user(session):
    """User can have connections to multiple OIDC providers."""
    user = User(email="multi-provider@example.com")
    session.add(user)
    await session.flush()

    for provider in ["google", "okta", "azure-ad"]:
        session.add(
            OIDCConnection(
                provider=provider,
                refresh_token=f"token-{provider}",
                user_id=user.id,
            )
        )
    await session.flush()

    conns = (
        await session.execute(
            select(OIDCConnection)
            .where(OIDCConnection.user_id == user.id)
            .order_by(OIDCConnection.provider)
        )
    ).scalars().all()
    assert len(conns) == 3
    assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]