310 lines
9.6 KiB
Python
310 lines
9.6 KiB
Python
|
|
"""Integration tests for OIDC — provider config loading and the DB side of the auth code flow (users and OIDC connections), using mocked provider responses."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from unittest.mock import patch
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
import respx
|
||
|
|
from httpx import Response
|
||
|
|
from jose import jwt
|
||
|
|
from sqlmodel import select
|
||
|
|
|
||
|
|
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
|
||
|
|
from wiregui.config import get_settings
|
||
|
|
from wiregui.models.configuration import Configuration
|
||
|
|
from wiregui.models.oidc_connection import OIDCConnection
|
||
|
|
from wiregui.models.user import User
|
||
|
|
|
||
|
|
|
||
|
|
# --- Helper to create a fake OIDC provider config in the DB ---
|
||
|
|
|
||
|
|
|
||
|
|
async def _setup_oidc_config(session) -> Configuration:
    """Persist a Configuration row declaring one test OIDC provider.

    The provider ("test-idp") points at a fake discovery URL and has
    ``auto_create_users`` enabled. The session is committed before returning.

    Args:
        session: async DB session used to add and commit the row.

    Returns:
        The Configuration instance that was added and committed.
    """
    provider = {
        "id": "test-idp",
        "label": "Test IdP",
        "scope": "openid email profile",
        "response_type": "code",
        "client_id": "test-client-id",
        "client_secret": "test-client-secret",
        "discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
        "auto_create_users": True,
    }
    cfg = Configuration(openid_connect_providers=[provider])
    session.add(cfg)
    await session.commit()
    return cfg
|
||
|
|
|
||
|
|
|
||
|
|
def _mock_discovery():
    """Build a fake OIDC discovery document for the test IdP.

    Returns:
        dict with the issuer plus the authorization, token, userinfo and
        JWKS endpoint URLs of https://idp.example.com.
    """
    document = {
        "issuer": "https://idp.example.com",
        "authorization_endpoint": "https://idp.example.com/authorize",
        "token_endpoint": "https://idp.example.com/token",
        "userinfo_endpoint": "https://idp.example.com/userinfo",
        "jwks_uri": "https://idp.example.com/.well-known/jwks.json",
    }
    return document
|
||
|
|
|
||
|
|
|
||
|
|
def _mock_token_response(email: str = "oidc-user@example.com"):
    """Build a fake OIDC token-endpoint response that carries a signed ID token.

    The ID token claims issuer https://idp.example.com, audience
    "test-client-id", a fixed subject and nonce, and is valid for one hour
    from the current time.

    Args:
        email: value placed in the ID token's ``email`` claim.

    Returns:
        dict with ``access_token``, ``token_type``, ``expires_in``,
        ``refresh_token`` and ``id_token`` (an HS256-signed JWT string).
    """
    issued_at = int(time.time())
    claims = {
        "iss": "https://idp.example.com",
        "sub": "oidc-subject-123",
        "aud": "test-client-id",
        "email": email,
        "name": "OIDC User",
        "iat": issued_at,
        "exp": issued_at + 3600,
        "nonce": "test-nonce",
    }
    # A throwaway HS256 secret keeps the fixture self-contained; a real IdP
    # would sign with an RSA key.
    signed_id_token = jwt.encode(claims, "fake-secret", algorithm="HS256")

    response = {
        "access_token": "mock-access-token",
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": "mock-refresh-token",
        "id_token": signed_id_token,
    }
    return response
|
||
|
|
|
||
|
|
|
||
|
|
# --- Provider config loading ---
|
||
|
|
|
||
|
|
|
||
|
|
async def test_load_providers_from_config(session, monkeypatch):
    """load_providers() returns the providers stored on the Configuration row."""
    from contextlib import asynccontextmanager

    # Point the module's session factory at the test session so that
    # load_providers reads the row inserted below.
    @asynccontextmanager
    async def _use_test_session():
        yield session

    monkeypatch.setattr("wiregui.auth.oidc.async_session", _use_test_session)
    await _setup_oidc_config(session)

    providers = await load_providers()

    assert len(providers) == 1
    only = providers[0]
    assert only["id"] == "test-idp"
    assert only["client_id"] == "test-client-id"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_load_providers_empty_when_no_config(session, monkeypatch):
    """Without any Configuration row, load_providers() yields an empty list."""
    from contextlib import asynccontextmanager

    # Redirect the module's session factory to the (empty) test session.
    @asynccontextmanager
    async def _use_test_session():
        yield session

    monkeypatch.setattr("wiregui.auth.oidc.async_session", _use_test_session)

    assert await load_providers() == []
|
||
|
|
|
||
|
|
|
||
|
|
async def test_get_provider_config_by_id(session, monkeypatch):
    """get_provider_config() finds a provider by id and returns None for unknown ids."""
    from contextlib import asynccontextmanager

    # Redirect the module's session factory to the test session.
    @asynccontextmanager
    async def _use_test_session():
        yield session

    monkeypatch.setattr("wiregui.auth.oidc.async_session", _use_test_session)
    await _setup_oidc_config(session)

    found = await get_provider_config("test-idp")
    assert found is not None
    assert found["label"] == "Test IdP"

    missing = await get_provider_config("nonexistent")
    assert missing is None
|
||
|
|
|
||
|
|
|
||
|
|
# --- OIDC connection storage ---
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_connection_created_on_login(session):
    """Mirror the callback route: create a user, then store its OIDC connection."""
    email = "oidc-new@example.com"
    user = User(email=email, role="unprivileged")
    session.add(user)
    # Flush so the user row exists before the connection references user.id.
    await session.flush()

    tokens = _mock_token_response(email)
    session.add(
        OIDCConnection(
            provider="test-idp",
            refresh_token=tokens["refresh_token"],
            refresh_response=tokens,
            user_id=user.id,
        )
    )
    await session.flush()

    # Read the connection back by its owner.
    by_user = select(OIDCConnection).where(OIDCConnection.user_id == user.id)
    stored = (await session.execute(by_user)).scalar_one()
    assert stored.provider == "test-idp"
    assert stored.refresh_token == "mock-refresh-token"
    assert stored.refresh_response["access_token"] == "mock-access-token"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_connection_updated_on_re_login(session):
    """A second login must update the user's existing OIDC connection in place."""
    from sqlmodel import func

    from wiregui.utils.time import utcnow

    user = User(email="oidc-relogin@example.com")
    session.add(user)
    await session.flush()

    # First login: the callback stores a fresh connection.
    session.add(
        OIDCConnection(
            provider="test-idp",
            refresh_token="old-refresh-token",
            user_id=user.id,
        )
    )
    await session.flush()

    # Re-login: look up the (user, provider) pair and update it,
    # as the callback route does.
    lookup = select(OIDCConnection).where(
        OIDCConnection.user_id == user.id,
        OIDCConnection.provider == "test-idp",
    )
    current = (await session.execute(lookup)).scalar_one()
    current.refresh_token = "new-refresh-token"
    current.refreshed_at = utcnow()
    session.add(current)
    await session.flush()

    # Still exactly one connection row for this user...
    count_stmt = (
        select(func.count())
        .select_from(OIDCConnection)
        .where(OIDCConnection.user_id == user.id)
    )
    assert (await session.execute(count_stmt)).scalar() == 1

    # ...and it carries the rotated refresh token.
    by_user = select(OIDCConnection).where(OIDCConnection.user_id == user.id)
    stored = (await session.execute(by_user)).scalar_one()
    assert stored.refresh_token == "new-refresh-token"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_auto_create_user(session):
    """With auto_create_users enabled, the callback creates the user from the OIDC email."""
    from wiregui.utils.time import utcnow

    email = "auto-created@example.com"
    by_email = select(User).where(User.email == email)

    # Precondition: no user with this email exists yet.
    assert (await session.execute(by_email)).scalar_one_or_none() is None

    # Auto-create path of the callback: insert the user, then stamp sign-in info.
    user = User(email=email, role="unprivileged")
    session.add(user)
    await session.flush()

    user.last_signed_in_at = utcnow()
    user.last_signed_in_method = "oidc:test-idp"
    session.add(user)
    await session.flush()

    created = (await session.execute(by_email)).scalar_one()
    assert created.role == "unprivileged"
    assert created.last_signed_in_method == "oidc:test-idp"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_disabled_user_rejected(session):
    """Disabled users should not be logged in via OIDC.

    The callback route checks ``disabled_at`` before creating a session and
    redirects to /login when it is set. This test cannot drive that route, so
    it verifies the state the check depends on: ``disabled_at`` is actually
    persisted, not merely set on the in-memory object.
    """
    from wiregui.utils.time import utcnow

    user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
    session.add(user)
    await session.flush()

    # Reload column values from the database. Asserting on the attribute we
    # just assigned (without a reload) could never fail.
    await session.refresh(user)
    assert user.disabled_at is not None  # callback would redirect to /login
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_user_without_auto_create_rejected(session):
    """Precondition for the auto_create_users=False rejection path.

    When the provider config has auto_create_users disabled and no user with
    the OIDC email exists, the callback route redirects to /login. Only the
    "user does not exist" precondition is checked here.
    """
    lookup = select(User).where(User.email == "no-auto-create@example.com")
    assert (await session.execute(lookup)).scalar_one_or_none() is None
|
||
|
|
|
||
|
|
|
||
|
|
# --- OIDC refresh token flow ---
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_refresh_stores_new_token(session):
    """A successful token refresh overwrites the stored token, response and timestamp."""
    from wiregui.utils.time import utcnow

    user = User(email="oidc-refresh-test@example.com")
    session.add(user)
    await session.flush()

    connection = OIDCConnection(
        provider="test-idp",
        refresh_token="old-refresh",
        user_id=user.id,
    )
    session.add(connection)
    await session.flush()

    # Simulated result of a refresh-token grant.
    refreshed = {
        "access_token": "new-access",
        "refresh_token": "new-refresh",
        "expires_in": 3600,
    }

    # Keep the previous refresh token if the response does not include one.
    connection.refresh_token = refreshed.get("refresh_token", connection.refresh_token)
    connection.refresh_response = refreshed
    connection.refreshed_at = utcnow()
    session.add(connection)
    await session.flush()

    stored = await session.get(OIDCConnection, connection.id)
    assert stored.refresh_token == "new-refresh"
    assert stored.refresh_response["access_token"] == "new-access"
    assert stored.refreshed_at is not None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_oidc_multiple_providers_per_user(session):
    """One user may hold independent connections to several OIDC providers."""
    user = User(email="multi-provider@example.com")
    session.add(user)
    await session.flush()

    provider_ids = ("google", "okta", "azure-ad")
    session.add_all(
        [
            OIDCConnection(
                provider=pid,
                refresh_token=f"token-{pid}",
                user_id=user.id,
            )
            for pid in provider_ids
        ]
    )
    await session.flush()

    ordered = (
        select(OIDCConnection)
        .where(OIDCConnection.user_id == user.id)
        .order_by(OIDCConnection.provider)
    )
    conns = (await session.execute(ordered)).scalars().all()

    assert len(conns) == 3
    assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]
|