diff --git a/tests/conftest.py b/tests/conftest.py index ad85276..e0ac10e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,12 @@ -"""Shared test fixtures — async DB session using a test database.""" +"""Shared test fixtures — async DB session using a test database. + +The module-level code below replaces ``wiregui.db.engine`` and +``wiregui.db.async_session`` with instances pointing at the **test** database +*before* any test (or other module) can grab a reference to the originals. +This means every ``from wiregui.db import async_session`` — whether in test +files or in production code like ``wiregui.utils.server_key`` — will get the +test-database session maker. +""" import os from collections.abc import AsyncGenerator @@ -8,6 +16,7 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlmodel import SQLModel +import wiregui.db as _db_module from wiregui.config import get_settings # All models must be imported so SQLModel.metadata knows about them @@ -51,19 +60,41 @@ def _ensure_test_db_sync(): _ensure_test_db_sync() +# --------------------------------------------------------------------------- +# Replace the production engine/session in wiregui.db at import time so that +# every module that does ``from wiregui.db import async_session`` picks up the +# test database. This MUST happen before test modules are collected (which +# triggers their top-level imports). +# --------------------------------------------------------------------------- +_test_engine = create_async_engine(TEST_DATABASE_URL) +_test_session_factory = async_sessionmaker(_test_engine, expire_on_commit=False) +_db_module.engine = _test_engine +_db_module.async_session = _test_session_factory + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def _setup_test_tables(): + """Create all tables once at the start of the test session, drop at end.""" + async with _test_engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + yield + async with _test_engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.drop_all) + await _test_engine.dispose() + @pytest_asyncio.fixture async def session() -> AsyncGenerator[AsyncSession]: - """Fresh engine + session per test, with table setup/teardown.""" - engine = create_async_engine(TEST_DATABASE_URL) - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) + """Per-test session with transaction isolation. - factory = async_sessionmaker(engine, expire_on_commit=False) - async with factory() as sess: + The session is bound to a connection-level transaction that is always + rolled back at teardown. When tested code calls ``session.commit()``, + SQLAlchemy only releases a SAVEPOINT — the outer transaction is never + committed, so no test data persists between tests. + """ + async with _test_engine.connect() as conn: + txn = await conn.begin() + sess = AsyncSession(bind=conn, expire_on_commit=False, join_transaction_mode="create_savepoint") yield sess - await sess.rollback() - - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.drop_all) - await engine.dispose() + await sess.close() + await txn.rollback() diff --git a/tests/test_account.py b/tests/test_account.py deleted file mode 100644 index 067c088..0000000 --- a/tests/test_account.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Tests for account functionality — password changes, API tokens, OIDC connections.""" - -import hashlib -from datetime import timedelta - -from sqlmodel import func, select - -from wiregui.auth.api_token import generate_api_token -from wiregui.auth.passwords import hash_password, verify_password -from wiregui.models.api_token import ApiToken -from wiregui.models.oidc_connection import OIDCConnection -from wiregui.models.user import User -from wiregui.utils.time import utcnow - - -# --- Password change --- - - -async def test_password_change_flow(session): - """Simulate the password change flow: verify old, set new.""" - user = User(email="pw-change@example.com", password_hash=hash_password("old-password")) - session.add(user) - await session.flush() - - # Verify old password - assert verify_password("old-password", user.password_hash) is True - - # Change password - user.password_hash = hash_password("new-password") - session.add(user) - await session.flush() - - fetched = await session.get(User, user.id) - assert verify_password("new-password", fetched.password_hash) is True - assert verify_password("old-password", fetched.password_hash) is False - - -async def test_password_change_wrong_current(session): - """Wrong current password should not allow change.""" - user = User(email="pw-wrong@example.com", password_hash=hash_password("correct")) - session.add(user) - await session.flush() - - # Simulate check - assert verify_password("wrong", user.password_hash) is False - - -# --- API token management --- - - -async def test_create_multiple_tokens(session): - user = User(email="multi-token@example.com") - session.add(user) - await session.flush() - - for _ in range(3): - _, token_hash = generate_api_token() - session.add(ApiToken(token_hash=token_hash, user_id=user.id)) - await session.flush() - - count = (await session.execute( - select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id) - )).scalar() - assert count == 3 - - -async def test_token_with_expiry(session): - user = User(email="expiry-token@example.com") - session.add(user) - await session.flush() - - _, token_hash = generate_api_token() - expires = utcnow() + timedelta(days=30) - token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id) - session.add(token) - await session.flush() - - fetched = await session.get(ApiToken, token.id) - assert fetched.expires_at is not None - assert fetched.expires_at > utcnow() - - -async def test_delete_token(session): - user = User(email="del-token@example.com") - session.add(user) - await session.flush() - - _, token_hash = generate_api_token() - token = ApiToken(token_hash=token_hash, user_id=user.id) - session.add(token) - await session.flush() - - await session.delete(token) - await session.flush() - - assert await session.get(ApiToken, token.id) is None - - -# --- OIDC connections --- - - -async def test_oidc_connection_create(session): - user = User(email="oidc-conn@example.com") - session.add(user) - await session.flush() - - conn = OIDCConnection( - provider="google", - refresh_token="refresh-tok-123", - refresh_response={"access_token": "at", "token_type": "Bearer"}, - refreshed_at=utcnow(), - user_id=user.id, - ) - session.add(conn) - await session.flush() - - fetched = (await session.execute( - select(OIDCConnection).where(OIDCConnection.user_id == user.id) - )).scalar_one() - assert fetched.provider == "google" - assert fetched.refresh_token == "refresh-tok-123" - assert fetched.refresh_response["access_token"] == "at" - - -async def test_multiple_oidc_providers(session): - user = User(email="multi-oidc@example.com") - session.add(user) - await session.flush() - - for provider in ["google", "okta", "azure"]: - conn = OIDCConnection(provider=provider, user_id=user.id) - session.add(conn) - await session.flush() - - count = (await session.execute( - select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id) - )).scalar() - assert count == 3 - - -async def test_oidc_connection_update_refresh_token(session): - user = User(email="oidc-refresh@example.com") - session.add(user) - await session.flush() - - conn = OIDCConnection( - provider="google", - refresh_token="old-token", - user_id=user.id, - ) - session.add(conn) - await session.flush() - - conn.refresh_token = "new-token" - conn.refreshed_at = utcnow() - session.add(conn) - await session.flush() - - fetched = await session.get(OIDCConnection, conn.id) - assert fetched.refresh_token == "new-token" - assert fetched.refreshed_at is not None diff --git a/tests/test_admin.py b/tests/test_admin.py deleted file mode 100644 index 714b814..0000000 --- a/tests/test_admin.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Tests for admin functionality — user management, configuration, cascading deletes.""" - -import pytest -from sqlmodel import func, select - -from wiregui.auth.passwords import hash_password, verify_password -from wiregui.models.api_token import ApiToken -from wiregui.models.configuration import Configuration -from wiregui.models.device import Device -from wiregui.models.mfa_method import MFAMethod -from wiregui.models.rule import Rule -from wiregui.models.user import User -from wiregui.utils.time import utcnow - - -# --- User CRUD --- - - -async def test_create_user_with_role(session): - user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin") - session.add(user) - await session.flush() - - fetched = await session.get(User, user.id) - assert fetched.role == "admin" - assert verify_password("secret", fetched.password_hash) - - -async def test_update_user_email(session): - user = User(email="old@test.com", password_hash=hash_password("pw")) - session.add(user) - await session.flush() - - user.email = "new@test.com" - session.add(user) - await session.flush() - - fetched = await session.get(User, user.id) - assert fetched.email == "new@test.com" - - -async def test_disable_user(session): - user = User(email="active@test.com", password_hash=hash_password("pw")) - session.add(user) - await session.flush() - assert user.disabled_at is None - - user.disabled_at = utcnow() - session.add(user) - await session.flush() - - fetched = await session.get(User, user.id) - assert fetched.disabled_at is not None - - -async def test_promote_demote_user(session): - user = User(email="user@test.com", role="unprivileged") - session.add(user) - await session.flush() - assert user.role == "unprivileged" - - user.role = "admin" - session.add(user) - await session.flush() - - fetched = await session.get(User, user.id) - assert fetched.role == "admin" - - user.role = "unprivileged" - session.add(user) - await session.flush() - assert (await session.get(User, user.id)).role == "unprivileged" - - -# --- Cascading delete (manual, as we do it in the admin page) --- - - -async def test_delete_user_cascades_devices(session): - user = User(email="cascade@test.com") - session.add(user) - await session.flush() - - d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id) - d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id) - session.add_all([d1, d2]) - await session.flush() - - # Manually delete devices then user (matching admin page behavior) - devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all() - for d in devices: - await session.delete(d) - await session.delete(user) - await session.flush() - - assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0 - assert await session.get(User, user.id) is None - - -async def test_delete_user_cascades_rules(session): - user = User(email="rule-cascade@test.com") - session.add(user) - await session.flush() - - rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id) - session.add(rule) - await session.flush() - - # Delete rules then user - rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all() - for r in rules: - await session.delete(r) - await session.delete(user) - await session.flush() - - assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0 - - -# --- Configuration singleton --- - - -async def test_configuration_create_and_update(session): - config = Configuration() - session.add(config) - await session.flush() - - assert config.default_client_mtu == 1280 - assert config.local_auth_enabled is True - - config.default_client_mtu = 1400 - config.local_auth_enabled = False - config.vpn_session_duration = 3600 - session.add(config) - await session.flush() - - fetched = await session.get(Configuration, config.id) - assert fetched.default_client_mtu == 1400 - assert fetched.local_auth_enabled is False - assert fetched.vpn_session_duration == 3600 - - -async def test_configuration_oidc_providers(session): - config = Configuration() - session.add(config) - await session.flush() - - assert config.openid_connect_providers == [] - - providers = [ - { - "id": "google", - "label": "Sign in with Google", - "scope": "openid email profile", - "response_type": "code", - "client_id": "google-client-id", - "client_secret": "google-secret", - "discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration", - "auto_create_users": True, - }, - { - "id": "okta", - "label": "Okta SSO", - "scope": "openid email profile", - "response_type": "code", - "client_id": "okta-client-id", - "client_secret": "okta-secret", - "discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration", - "auto_create_users": False, - }, - ] - config.openid_connect_providers = providers - session.add(config) - await session.flush() - - fetched = await session.get(Configuration, config.id) - assert len(fetched.openid_connect_providers) == 2 - assert fetched.openid_connect_providers[0]["id"] == "google" - assert fetched.openid_connect_providers[1]["auto_create_users"] is False - - -async def test_configuration_update_client_defaults(session): - config = Configuration() - session.add(config) - await session.flush() - - config.default_client_endpoint = "vpn.example.com" - config.default_client_dns = ["8.8.8.8", "8.8.4.4"] - config.default_client_allowed_ips = ["10.0.0.0/8"] - config.default_client_persistent_keepalive = 30 - session.add(config) - await session.flush() - - fetched = await session.get(Configuration, config.id) - assert fetched.default_client_endpoint == "vpn.example.com" - assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"] - assert fetched.default_client_allowed_ips == ["10.0.0.0/8"] - assert fetched.default_client_persistent_keepalive == 30 - - -async def test_configuration_security_toggles(session): - config = Configuration() - session.add(config) - await session.flush() - - config.allow_unprivileged_device_management = False - config.allow_unprivileged_device_configuration = False - config.disable_vpn_on_oidc_error = True - session.add(config) - await session.flush() - - fetched = await session.get(Configuration, config.id) - assert fetched.allow_unprivileged_device_management is False - assert fetched.allow_unprivileged_device_configuration is False - assert fetched.disable_vpn_on_oidc_error is True - - -# --- Device config overrides --- - - -async def test_device_with_custom_config(session): - user = User(email="config-user@test.com") - session.add(user) - await session.flush() - - device = Device( - name="custom-config", - public_key="pk-custom-config", - user_id=user.id, - use_default_dns=False, - use_default_endpoint=False, - use_default_mtu=False, - use_default_persistent_keepalive=False, - use_default_allowed_ips=False, - dns=["8.8.8.8"], - endpoint="custom-vpn.example.com", - mtu=1400, - persistent_keepalive=15, - allowed_ips=["10.0.0.0/8", "172.16.0.0/12"], - ) - session.add(device) - await session.flush() - - fetched = await session.get(Device, device.id) - assert fetched.use_default_dns is False - assert fetched.dns == ["8.8.8.8"] - assert fetched.endpoint == "custom-vpn.example.com" - assert fetched.mtu == 1400 - assert fetched.persistent_keepalive == 15 - assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"] - - -async def test_device_default_flags_are_true(session): - user = User(email="defaults@test.com") - session.add(user) - await session.flush() - - device = Device(name="defaults", public_key="pk-defaults", user_id=user.id) - session.add(device) - await session.flush() - - fetched = await session.get(Device, device.id) - assert fetched.use_default_allowed_ips is True - assert fetched.use_default_dns is True - assert fetched.use_default_endpoint is True - assert fetched.use_default_mtu is True - assert fetched.use_default_persistent_keepalive is True - - -# --- User device count --- - - -async def test_user_device_count_query(session): - user = User(email="count-user@test.com") - session.add(user) - await session.flush() - - for i in range(3): - session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id)) - await session.flush() - - count = (await session.execute( - select(func.count()).select_from(Device).where(Device.user_id == user.id) - )).scalar() - assert count == 3 diff --git a/tests/test_api_deps.py b/tests/test_api_deps.py index 64d8a32..f5a7455 100644 --- a/tests/test_api_deps.py +++ b/tests/test_api_deps.py @@ -1,15 +1,12 @@ """Tests for API dependency injection — Bearer token auth and admin guard.""" -import hashlib from datetime import timedelta -from uuid import uuid4 import pytest from unittest.mock import AsyncMock, MagicMock -from wiregui.auth.api_token import generate_api_token +from wiregui.auth.api_token import generate_api_token, resolve_bearer_token from wiregui.auth.passwords import hash_password -from wiregui.db import async_session from wiregui.models.api_token import ApiToken from wiregui.models.user import User from wiregui.utils.time import utcnow @@ -18,143 +15,80 @@ from wiregui.utils.time import utcnow # ========== resolve_bearer_token ========== -async def test_resolve_valid_token(): +async def test_resolve_valid_token(session): """Valid, non-expired token resolves to user.""" - from wiregui.auth.api_token import resolve_bearer_token - plaintext, token_hash = generate_api_token() - async with async_session() as session: - user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin") - session.add(user) - await session.commit() - await session.refresh(user) + user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.flush() - api_token = ApiToken( - token_hash=token_hash, - user_id=user.id, - expires_at=utcnow() + timedelta(hours=1), - ) - session.add(api_token) - await session.commit() + api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1)) + session.add(api_token) + await session.flush() - try: - async with async_session() as session: - resolved = await resolve_bearer_token(session, plaintext) - assert resolved is not None - assert resolved.id == user.id - assert resolved.email == "api-test@test.com" - finally: - async with async_session() as session: - await session.delete(await session.get(ApiToken, api_token.id)) - await session.delete(await session.get(User, user.id)) - await session.commit() + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is not None + assert resolved.id == user.id + assert resolved.email == "api-test@test.com" -async def test_resolve_expired_token(): +async def test_resolve_expired_token(session): """Expired token returns None.""" - from wiregui.auth.api_token import resolve_bearer_token - plaintext, token_hash = generate_api_token() - async with async_session() as session: - user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin") - session.add(user) - await session.commit() - await session.refresh(user) + user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.flush() - api_token = ApiToken( - token_hash=token_hash, - user_id=user.id, - expires_at=utcnow() - timedelta(hours=1), # already expired - ) - session.add(api_token) - await session.commit() + api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() - timedelta(hours=1)) + session.add(api_token) + await session.flush() - try: - async with async_session() as session: - resolved = await resolve_bearer_token(session, plaintext) - assert resolved is None - finally: - async with async_session() as session: - await session.delete(await session.get(ApiToken, api_token.id)) - await session.delete(await session.get(User, user.id)) - await session.commit() + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is None -async def test_resolve_invalid_token(): +async def test_resolve_invalid_token(session): """Nonexistent token returns None.""" - from wiregui.auth.api_token import resolve_bearer_token - - async with async_session() as session: - resolved = await resolve_bearer_token(session, "totally-bogus-token") - assert resolved is None + resolved = await resolve_bearer_token(session, "totally-bogus-token") + assert resolved is None -async def test_resolve_token_disabled_user(): +async def test_resolve_token_disabled_user(session): """Token for disabled user returns None.""" - from wiregui.auth.api_token import resolve_bearer_token - plaintext, token_hash = generate_api_token() - async with async_session() as session: - user = User( - email="api-disabled@test.com", password_hash=hash_password("x"), - role="admin", disabled_at=utcnow(), - ) - session.add(user) - await session.commit() - await session.refresh(user) + user = User( + email="api-disabled@test.com", password_hash=hash_password("x"), + role="admin", disabled_at=utcnow(), + ) + session.add(user) + await session.flush() - api_token = ApiToken( - token_hash=token_hash, - user_id=user.id, - expires_at=utcnow() + timedelta(hours=1), - ) - session.add(api_token) - await session.commit() + api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1)) + session.add(api_token) + await session.flush() - try: - async with async_session() as session: - resolved = await resolve_bearer_token(session, plaintext) - assert resolved is None - finally: - async with async_session() as session: - await session.delete(await session.get(ApiToken, api_token.id)) - await session.delete(await session.get(User, user.id)) - await session.commit() + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is None -async def test_resolve_token_no_expiry(): +async def test_resolve_token_no_expiry(session): """Token without expires_at (never expires) resolves successfully.""" - from wiregui.auth.api_token import resolve_bearer_token - plaintext, token_hash = generate_api_token() - async with async_session() as session: - user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin") - session.add(user) - await session.commit() - await session.refresh(user) + user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.flush() - api_token = ApiToken( - token_hash=token_hash, - user_id=user.id, - expires_at=None, - ) - session.add(api_token) - await session.commit() + api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=None) + session.add(api_token) + await session.flush() - try: - async with async_session() as session: - resolved = await resolve_bearer_token(session, plaintext) - assert resolved is not None - assert resolved.id == user.id - finally: - async with async_session() as session: - await session.delete(await session.get(ApiToken, api_token.id)) - await session.delete(await session.get(User, user.id)) - await session.commit() + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is not None + assert resolved.id == user.id # ========== get_current_api_user (via FastAPI deps) ========== @@ -187,7 +121,7 @@ async def test_get_current_api_user_bad_scheme(): assert exc_info.value.status_code == 401 -async def test_get_current_api_user_invalid_token(): +async def test_get_current_api_user_invalid_token(session): """Valid Bearer scheme but bogus token raises 401.""" from fastapi import HTTPException from wiregui.api.deps import get_current_api_user @@ -195,45 +129,31 @@ async def test_get_current_api_user_invalid_token(): request = MagicMock() request.headers = {"Authorization": "Bearer bogus-token-value"} - async with async_session() as session: - with pytest.raises(HTTPException) as exc_info: - await get_current_api_user(request, session=session) - assert exc_info.value.status_code == 401 - assert "Invalid" in exc_info.value.detail + with pytest.raises(HTTPException) as exc_info: + await get_current_api_user(request, session=session) + assert exc_info.value.status_code == 401 + assert "Invalid" in exc_info.value.detail -async def test_get_current_api_user_valid_token(): +async def test_get_current_api_user_valid_token(session): """Valid Bearer token resolves to user.""" from wiregui.api.deps import get_current_api_user plaintext, token_hash = generate_api_token() - async with async_session() as session: - user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin") - session.add(user) - await session.commit() - await session.refresh(user) + user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.flush() - api_token = ApiToken( - token_hash=token_hash, - user_id=user.id, - expires_at=utcnow() + timedelta(hours=1), - ) - session.add(api_token) - await session.commit() + api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1)) + session.add(api_token) + await session.flush() - try: - request = MagicMock() - request.headers = {"Authorization": f"Bearer {plaintext}"} + request = MagicMock() + request.headers = {"Authorization": f"Bearer {plaintext}"} - async with async_session() as session: - resolved = await get_current_api_user(request, session=session) - assert resolved.id == user.id - finally: - async with async_session() as session: - await session.delete(await session.get(ApiToken, api_token.id)) - await session.delete(await session.get(User, user.id)) - await session.commit() + resolved = await get_current_api_user(request, session=session) + assert resolved.id == user.id # ========== require_admin ========== @@ -260,4 +180,4 @@ async def test_require_admin_rejects_unprivileged(): with pytest.raises(HTTPException) as exc_info: await require_admin(user=regular_user) assert exc_info.value.status_code == 403 - assert "Admin" in exc_info.value.detail \ No newline at end of file + assert "Admin" in exc_info.value.detail diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py deleted file mode 100644 index a926c63..0000000 --- a/tests/test_api_routes.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Tests for REST API routes via httpx AsyncClient against the FastAPI app.""" - -import hashlib -from uuid import UUID, uuid4 - -from fastapi import FastAPI -from fastapi.testclient import TestClient -from httpx import ASGITransport, AsyncClient -from sqlmodel import select - -from wiregui.api.deps import get_current_api_user, get_db, require_admin -from wiregui.api.v0 import router as api_router -from wiregui.auth.api_token import generate_api_token -from wiregui.auth.passwords import hash_password -from wiregui.models.api_token import ApiToken -from wiregui.models.configuration import Configuration -from wiregui.models.device import Device -from wiregui.models.rule import Rule -from wiregui.models.user import User - - -def _build_app(session, admin_user=None, regular_user=None): - """Build a test FastAPI app with overridden dependencies.""" - test_app = FastAPI() - test_app.include_router(api_router, prefix="/api") - - async def override_get_db(): - yield session - - test_app.dependency_overrides[get_db] = override_get_db - - if admin_user: - test_app.dependency_overrides[get_current_api_user] = lambda: admin_user - test_app.dependency_overrides[require_admin] = lambda: admin_user - - return test_app - - -async def _make_admin(session) -> User: - user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin") - session.add(user) - await session.flush() - return user - - -async def _make_user(session, email="api-user@test.com") -> User: - user = User(email=email, password_hash=hash_password("pw"), role="unprivileged") - session.add(user) - await session.flush() - return user - - -# ========== Users API ========== - - -async def test_list_users(session): - admin = await _make_admin(session) - await _make_user(session, "user1@test.com") - await _make_user(session, "user2@test.com") - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get("/api/v0/users/") - assert resp.status_code == 200 - data = resp.json() - assert len(data) >= 3 # admin + 2 users - - -async def test_get_user(session): - admin = await _make_admin(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get(f"/api/v0/users/{admin.id}") - assert resp.status_code == 200 - assert resp.json()["email"] == "api-admin@test.com" - - -async def test_get_user_not_found(session): - admin = await _make_admin(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get(f"/api/v0/users/{uuid4()}") - assert resp.status_code == 404 - - -async def test_create_user(session): - admin = await _make_admin(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.post("/api/v0/users/", json={ - "email": "new-api-user@test.com", - "password": "secret123", - "role": "unprivileged", - }) - assert resp.status_code == 201 - data = resp.json() - assert data["email"] == "new-api-user@test.com" - assert data["role"] == "unprivileged" - assert "id" in data - - -async def test_update_user(session): - admin = await _make_admin(session) - user = await _make_user(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.put(f"/api/v0/users/{user.id}", json={ - "role": "admin", - }) - assert resp.status_code == 200 - assert resp.json()["role"] == "admin" - - -async def test_update_user_password(session): - admin = await _make_admin(session) - user = await _make_user(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.put(f"/api/v0/users/{user.id}", json={ - "password": "new-password-123", - }) - assert resp.status_code == 200 - - from wiregui.auth.passwords import verify_password - refreshed = await session.get(User, user.id) - assert verify_password("new-password-123", refreshed.password_hash) - - -async def test_delete_user(session): - admin = await _make_admin(session) - user = await _make_user(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.delete(f"/api/v0/users/{user.id}") - assert resp.status_code == 204 - - assert await session.get(User, user.id) is None - - -# ========== Devices API ========== - - -async def test_list_devices_admin_sees_all(session): - admin = await _make_admin(session) - user = await _make_user(session) - session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id)) - session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id)) - await session.flush() - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get("/api/v0/devices/") - assert resp.status_code == 200 - assert len(resp.json()) >= 2 - - -async def test_list_devices_user_sees_own(session): - admin = await _make_admin(session) - user = await _make_user(session, "own-devices@test.com") - session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id)) - session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id)) - await session.flush() - - # Override to be the regular user - test_app = _build_app(session) - test_app.dependency_overrides[get_current_api_user] = lambda: user - async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: - resp = await client.get("/api/v0/devices/") - assert resp.status_code == 200 - names = [d["name"] for d in resp.json()] - assert "mine" in names - assert "not-mine" not in names - - -async def test_get_device(session): - admin = await _make_admin(session) - device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5") - session.add(device) - await session.flush() - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get(f"/api/v0/devices/{device.id}") - assert resp.status_code == 200 - assert resp.json()["name"] == "detail" - assert resp.json()["ipv4"] == "10.0.0.5" - - -async def test_get_device_forbidden_for_other_user(session): - admin = await _make_admin(session) - user = await _make_user(session, "other-dev@test.com") - device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id) - session.add(device) - await session.flush() - - test_app = _build_app(session) - test_app.dependency_overrides[get_current_api_user] = lambda: user - async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client: - resp = await client.get(f"/api/v0/devices/{device.id}") - assert resp.status_code == 403 - - -async def test_update_device(session): - admin = await _make_admin(session) - device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id) - session.add(device) - await session.flush() - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"}) - assert resp.status_code == 200 - assert resp.json()["name"] == "new-name" - - -async def test_delete_device(session): - admin = await _make_admin(session) - device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id) - session.add(device) - await session.flush() - did = device.id - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.delete(f"/api/v0/devices/{did}") - assert resp.status_code == 204 - - assert await session.get(Device, did) is None - - -# ========== Rules API ========== - - -async def test_list_rules(session): - admin = await _make_admin(session) - session.add(Rule(action="accept", destination="10.0.0.0/8")) - session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id)) - await session.flush() - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get("/api/v0/rules/") - assert resp.status_code == 200 - assert len(resp.json()) >= 2 - - -async def test_create_rule(session): - admin = await _make_admin(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.post("/api/v0/rules/", json={ - "action": "accept", - "destination": "172.16.0.0/12", - "port_type": "tcp", - "port_range": "443", - }) - assert resp.status_code == 201 - data = resp.json() - assert data["action"] == "accept" - assert data["destination"] == "172.16.0.0/12" - assert data["port_type"] == "tcp" - assert data["port_range"] == "443" - - -async def test_update_rule(session): - admin = await _make_admin(session) - rule = Rule(action="accept", destination="10.0.0.0/8") - session.add(rule) - await session.flush() - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"}) - assert resp.status_code == 200 - assert resp.json()["action"] == "drop" - - -async def test_delete_rule(session): - admin = await _make_admin(session) - rule = Rule(action="drop", destination="0.0.0.0/0") - session.add(rule) - await session.flush() - rid = rule.id - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.delete(f"/api/v0/rules/{rid}") - assert resp.status_code == 204 - - assert await session.get(Rule, rid) is None - - -# ========== Configuration API ========== - - -async def test_get_configuration_auto_creates(session): - admin = await _make_admin(session) - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.get("/api/v0/configuration/") - assert resp.status_code == 200 - data = resp.json() - assert data["default_client_mtu"] == 1280 - assert data["local_auth_enabled"] is True - - -async def test_update_configuration(session): - admin = await _make_admin(session) - # Pre-create config - config = Configuration() - session.add(config) - await session.flush() - - app = _build_app(session, admin_user=admin) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: - resp = await client.put("/api/v0/configuration/", json={ - "default_client_mtu": 1400, - "vpn_session_duration": 3600, - "default_client_dns": ["8.8.8.8"], - }) - assert resp.status_code == 200 - data = resp.json() - assert data["default_client_mtu"] == 1400 - assert data["vpn_session_duration"] == 3600 - assert data["default_client_dns"] == ["8.8.8.8"] diff --git a/tests/test_auth.py b/tests/test_auth.py index 08f35dc..52c3c1e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,4 +1,4 @@ -"""Tests for authentication modules.""" +"""Tests for authentication modules — seed logic and JWT edge cases.""" from sqlmodel import select @@ -8,17 +8,7 @@ from wiregui.auth.seed import seed_admin from wiregui.models.user import User -# --- Password hashing --- - - -def test_hash_and_verify(): - hashed = hash_password("my-secret") - assert verify_password("my-secret", hashed) is True - - -def test_verify_wrong_password(): - hashed = hash_password("correct") - assert verify_password("wrong", hashed) is False +# --- Password hashing (format guard) --- def test_hash_is_not_plaintext(): @@ -27,16 +17,7 @@ def test_hash_is_not_plaintext(): assert hashed.startswith("$2b$") -# --- JWT --- - - -def test_create_and_decode_token(): - token = create_access_token(user_id="user-123", role="admin") - payload = decode_access_token(token) - assert payload is not None - assert payload["sub"] == "user-123" - assert payload["role"] == "admin" - assert "exp" in payload +# --- JWT edge cases --- def test_decode_invalid_token(): @@ -54,8 +35,6 @@ def test_decode_tampered_token(): async def test_seed_admin_creates_user(session, monkeypatch): """seed_admin should create an admin when no users exist.""" - # Patch async_session to use our test session - from unittest.mock import AsyncMock from contextlib import asynccontextmanager @asynccontextmanager diff --git a/tests/test_auth_extended.py b/tests/test_auth_extended.py index f6f296b..28ef53a 100644 --- a/tests/test_auth_extended.py +++ b/tests/test_auth_extended.py @@ -1,65 +1,9 @@ -"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases.""" +"""Extended auth tests — OIDC registration, WebAuthn options, rule event handlers.""" from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 -from wiregui.auth.passwords import hash_password -from wiregui.auth.session import authenticate_user from wiregui.models.user import User -from wiregui.utils.time import utcnow - - -# ========== Session / authenticate_user edge cases ========== - - -async def test_authenticate_user_no_password_hash(session, monkeypatch): - """Users without a password (OIDC-only) should not authenticate via password.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.session.async_session", mock_session) - - user = User(email="no-pw@test.com", password_hash=None) - session.add(user) - await session.flush() - - result = await authenticate_user("no-pw@test.com", "anything") - assert result is None - - -async def test_authenticate_user_disabled(session, monkeypatch): - """Disabled users should not authenticate.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.session.async_session", mock_session) - - user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow()) - session.add(user) - await session.flush() - - result = await authenticate_user("disabled-auth@test.com", "pw") - assert result is None - - -async def test_authenticate_user_nonexistent(session, monkeypatch): - """Nonexistent email should return None.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.session.async_session", mock_session) - - result = await authenticate_user("ghost@nowhere.com", "pw") - assert result is None # ========== OIDC provider registration ========== @@ -163,13 +107,11 @@ async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings): from wiregui.models.rule import Rule from wiregui.services.events import on_rule_updated - # Need to mock the DB call inside _rebuild_user_chain with patch("wiregui.services.events.async_session") as mock_session_factory: mock_session = AsyncMock() mock_session.__aenter__ = AsyncMock(return_value=mock_session) mock_session.__aexit__ = AsyncMock(return_value=False) - # Mock the select results mock_rules_result = MagicMock() mock_rules_result.scalars.return_value.all.return_value = [] mock_devices_result = MagicMock() diff --git a/tests/test_integration_mfa.py b/tests/test_integration_mfa.py deleted file mode 100644 index 6a4ae62..0000000 --- a/tests/test_integration_mfa.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Integration tests for MFA — full registration and authentication flows through the database.""" - -import pyotp -from sqlmodel import func, select - -from wiregui.auth.mfa import generate_totp_secret, verify_totp_code -from wiregui.auth.passwords import hash_password, verify_password -from wiregui.auth.session import authenticate_user -from wiregui.models.mfa_method import MFAMethod -from wiregui.models.user import User -from wiregui.utils.time import utcnow - - -async def test_full_totp_registration_flow(session, monkeypatch): - """End-to-end: create user → generate secret → verify code → store method → re-verify from DB.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - # Create user with password - user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123")) - session.add(user) - await session.flush() - - # Step 1: Generate TOTP secret (happens in account page) - secret = generate_totp_secret() - - # Step 2: User scans QR, enters code from their authenticator - totp = pyotp.TOTP(secret) - code = totp.now() - - # Step 3: Verify the code is correct before saving - assert verify_totp_code(secret, code) is True - - # Step 4: Save the MFA method to DB - method = MFAMethod( - name="My Authenticator", - type="totp", - payload={"secret": secret}, - user_id=user.id, - ) - session.add(method) - await session.flush() - - # Step 5: Simulate future login — load method from DB and verify a fresh code - fetched_methods = (await session.execute( - select(MFAMethod).where(MFAMethod.user_id == user.id) - )).scalars().all() - - assert len(fetched_methods) == 1 - stored_secret = fetched_methods[0].payload["secret"] - fresh_code = pyotp.TOTP(stored_secret).now() - assert verify_totp_code(stored_secret, fresh_code) is True - - -async def test_mfa_blocks_login_without_code(session, monkeypatch): - """User with MFA should not be fully authenticated without completing MFA challenge.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.session.async_session", mock_session) - - # Create user with MFA - user = User(email="mfa-block@example.com", password_hash=hash_password("password1")) - session.add(user) - await session.flush() - - secret = generate_totp_secret() - method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id) - session.add(method) - await session.flush() - - # Password auth succeeds - authed_user = await authenticate_user("mfa-block@example.com", "password1") - assert authed_user is not None - - # But MFA methods exist — login page would redirect to /mfa instead of completing login - mfa_methods = (await session.execute( - select(MFAMethod).where(MFAMethod.user_id == authed_user.id) - )).scalars().all() - assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa - - -async def test_mfa_wrong_code_rejected(session): - """Wrong TOTP code should be rejected even if method is valid.""" - user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw")) - session.add(user) - await session.flush() - - secret = generate_totp_secret() - method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id) - session.add(method) - await session.flush() - - # Load from DB and try wrong code - fetched = (await session.execute( - select(MFAMethod).where(MFAMethod.user_id == user.id) - )).scalar_one() - - assert verify_totp_code(fetched.payload["secret"], "000000") is False - assert verify_totp_code(fetched.payload["secret"], "123456") is False - - -async def test_mfa_multiple_methods_any_valid_code_works(session): - """If user has multiple TOTP methods, a valid code from any should work.""" - user = User(email="mfa-multi@example.com") - session.add(user) - await session.flush() - - secret1 = generate_totp_secret() - secret2 = generate_totp_secret() - - session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id)) - session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id)) - await session.flush() - - methods = (await session.execute( - select(MFAMethod).where(MFAMethod.user_id == user.id) - )).scalars().all() - - # Code from method 1 should verify against method 1's secret - code1 = pyotp.TOTP(secret1).now() - verified = False - for m in methods: - if verify_totp_code(m.payload["secret"], code1): - verified = True - break - assert verified is True - - # Code from method 2 should also work - code2 = pyotp.TOTP(secret2).now() - verified2 = False - for m in methods: - if verify_totp_code(m.payload["secret"], code2): - verified2 = True - break - assert verified2 is True - - -async def test_mfa_method_last_used_tracking(session): - """Verifying MFA should update last_used_at timestamp.""" - user = User(email="mfa-tracking@example.com") - session.add(user) - await session.flush() - - secret = generate_totp_secret() - method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id) - session.add(method) - await session.flush() - - assert method.last_used_at is None - - # Simulate successful verification and update - code = pyotp.TOTP(secret).now() - assert verify_totp_code(secret, code) is True - - method.last_used_at = utcnow() - session.add(method) - await session.flush() - - fetched = await session.get(MFAMethod, method.id) - assert fetched.last_used_at is not None - - -async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch): - """After removing all MFA methods, user should not be redirected to MFA challenge.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.session.async_session", mock_session) - - user = User(email="mfa-remove@example.com", password_hash=hash_password("pw")) - session.add(user) - await session.flush() - - secret = generate_totp_secret() - method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id) - session.add(method) - await session.flush() - - # MFA exists - count = (await session.execute( - select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id) - )).scalar() - assert count == 1 - - # Delete it - await session.delete(method) - await session.flush() - - count = (await session.execute( - select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id) - )).scalar() - assert count == 0 - - # Password auth still works - authed = await authenticate_user("mfa-remove@example.com", "pw") - assert authed is not None - - # No MFA methods — login flow would skip MFA challenge - mfa_check = (await session.execute( - select(MFAMethod).where(MFAMethod.user_id == authed.id) - )).scalars().all() - assert len(mfa_check) == 0 - - -async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch): - """Disabled user should be rejected at password stage, never reaching MFA.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.session.async_session", mock_session) - - user = User( - email="mfa-disabled@example.com", - password_hash=hash_password("pw"), - disabled_at=utcnow(), - ) - session.add(user) - await session.flush() - - secret = generate_totp_secret() - session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)) - await session.flush() - - # Password auth rejects disabled user before MFA is ever checked - result = await authenticate_user("mfa-disabled@example.com", "pw") - assert result is None diff --git a/tests/test_integration_oidc.py b/tests/test_integration_oidc.py deleted file mode 100644 index 3ecd07f..0000000 --- a/tests/test_integration_oidc.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Integration tests for OIDC — mock provider endpoints, test full auth code flow.""" - -import json -import time -from unittest.mock import patch -from uuid import uuid4 - -import respx -from httpx import Response -from jose import jwt -from sqlmodel import select - -from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers -from wiregui.config import get_settings -from wiregui.models.configuration import Configuration -from wiregui.models.oidc_connection import OIDCConnection -from wiregui.models.user import User - - -# --- Helper to create a fake OIDC provider config in the DB --- - - -async def _setup_oidc_config(session) -> Configuration: - """Insert a Configuration with a test OIDC provider.""" - config = Configuration( - openid_connect_providers=[ - { - "id": "test-idp", - "label": "Test IdP", - "scope": "openid email profile", - "response_type": "code", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration", - "auto_create_users": True, - } - ], - ) - session.add(config) - await session.commit() - return config - - -def _mock_discovery(): - """Mock OIDC discovery document response.""" - return { - "issuer": "https://idp.example.com", - "authorization_endpoint": "https://idp.example.com/authorize", - "token_endpoint": "https://idp.example.com/token", - "userinfo_endpoint": "https://idp.example.com/userinfo", - "jwks_uri": "https://idp.example.com/.well-known/jwks.json", - } - - -def _mock_token_response(email: str = "oidc-user@example.com"): - """Mock OIDC token endpoint response with ID token.""" - now = int(time.time()) - id_token_payload = { - "iss": "https://idp.example.com", - "sub": "oidc-subject-123", - "aud": "test-client-id", - "email": email, - "name": "OIDC User", - "iat": now, - "exp": now + 3600, - "nonce": "test-nonce", - } - # Sign with a simple secret (in real life this would be RSA) - id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256") - - return { - "access_token": "mock-access-token", - "token_type": "Bearer", - "expires_in": 3600, - "refresh_token": "mock-refresh-token", - "id_token": id_token, - } - - -# --- Provider config loading --- - - -async def test_load_providers_from_config(session, monkeypatch): - """Providers should be loaded from the Configuration table.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session) - - await _setup_oidc_config(session) - - providers = await load_providers() - assert len(providers) == 1 - assert providers[0]["id"] == "test-idp" - assert providers[0]["client_id"] == "test-client-id" - - -async def test_load_providers_empty_when_no_config(session, monkeypatch): - """Should return empty list when no Configuration exists.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session) - - providers = await load_providers() - assert providers == [] - - -async def test_get_provider_config_by_id(session, monkeypatch): - """Should find a specific provider by ID.""" - from contextlib import asynccontextmanager - - @asynccontextmanager - async def mock_session(): - yield session - - monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session) - - await _setup_oidc_config(session) - - config = await get_provider_config("test-idp") - assert config is not None - assert config["label"] == "Test IdP" - - config_missing = await get_provider_config("nonexistent") - assert config_missing is None - - -# --- OIDC connection storage --- - - -async def test_oidc_connection_created_on_login(session): - """Simulates what the callback route does: create user + OIDC connection.""" - user = User(email="oidc-new@example.com", role="unprivileged") - session.add(user) - await session.flush() - - token_data = _mock_token_response("oidc-new@example.com") - conn = OIDCConnection( - provider="test-idp", - refresh_token=token_data["refresh_token"], - refresh_response=token_data, - user_id=user.id, - ) - session.add(conn) - await session.flush() - - # Verify it was stored - fetched = (await session.execute( - select(OIDCConnection).where(OIDCConnection.user_id == user.id) - )).scalar_one() - assert fetched.provider == "test-idp" - assert fetched.refresh_token == "mock-refresh-token" - assert fetched.refresh_response["access_token"] == "mock-access-token" - - -async def test_oidc_connection_updated_on_re_login(session): - """Re-login should update the existing OIDC connection, not create a duplicate.""" - user = User(email="oidc-relogin@example.com") - session.add(user) - await session.flush() - - # First login - conn = OIDCConnection( - provider="test-idp", - refresh_token="old-refresh-token", - user_id=user.id, - ) - session.add(conn) - await session.flush() - - # Re-login — update existing connection (as the callback route does) - existing = (await session.execute( - select(OIDCConnection).where( - OIDCConnection.user_id == user.id, - OIDCConnection.provider == "test-idp", - ) - )).scalar_one() - - existing.refresh_token = "new-refresh-token" - from wiregui.utils.time import utcnow - existing.refreshed_at = utcnow() - session.add(existing) - await session.flush() - - # Should still be one connection - from sqlmodel import func - count = (await session.execute( - select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id) - )).scalar() - assert count == 1 - - fetched = (await session.execute( - select(OIDCConnection).where(OIDCConnection.user_id == user.id) - )).scalar_one() - assert fetched.refresh_token == "new-refresh-token" - - -async def test_oidc_auto_create_user(session): - """When auto_create_users is True, a new user should be created from OIDC email.""" - email = "auto-created@example.com" - - # Verify user doesn't exist - existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none() - assert existing is None - - # Simulate what callback does with auto_create - user = User(email=email, role="unprivileged") - session.add(user) - await session.flush() - - from wiregui.utils.time import utcnow - user.last_signed_in_at = utcnow() - user.last_signed_in_method = "oidc:test-idp" - session.add(user) - await session.flush() - - created = (await session.execute(select(User).where(User.email == email))).scalar_one() - assert created.role == "unprivileged" - assert created.last_signed_in_method == "oidc:test-idp" - - -async def test_oidc_disabled_user_rejected(session): - """Disabled users should not be logged in via OIDC.""" - from wiregui.utils.time import utcnow - - user = User(email="oidc-disabled@example.com", disabled_at=utcnow()) - session.add(user) - await session.flush() - - # The callback route checks disabled_at before creating session - assert user.disabled_at is not None # Would redirect to /login - - -async def test_oidc_user_without_auto_create_rejected(session): - """When auto_create is False and user doesn't exist, login should fail.""" - email = "no-auto-create@example.com" - - existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none() - assert existing is None - - # The callback route checks auto_create_users from provider config - # With auto_create=False and no existing user, it would redirect to /login - # This verifies the precondition - - -# --- OIDC refresh token flow --- - - -async def test_oidc_refresh_stores_new_token(session): - """Simulates a successful token refresh updating the connection.""" - user = User(email="oidc-refresh-test@example.com") - session.add(user) - await session.flush() - - conn = OIDCConnection( - provider="test-idp", - refresh_token="old-refresh", - user_id=user.id, - ) - session.add(conn) - await session.flush() - - # Simulate refresh result - new_token = { - "access_token": "new-access", - "refresh_token": "new-refresh", - "expires_in": 3600, - } - - conn.refresh_token = new_token.get("refresh_token", conn.refresh_token) - conn.refresh_response = new_token - from wiregui.utils.time import utcnow - conn.refreshed_at = utcnow() - session.add(conn) - await session.flush() - - fetched = await session.get(OIDCConnection, conn.id) - assert fetched.refresh_token == "new-refresh" - assert fetched.refresh_response["access_token"] == "new-access" - assert fetched.refreshed_at is not None - - -async def test_oidc_multiple_providers_per_user(session): - """User can have connections to multiple OIDC providers.""" - user = User(email="multi-provider@example.com") - session.add(user) - await session.flush() - - for provider in ["google", "okta", "azure-ad"]: - session.add(OIDCConnection( - provider=provider, - refresh_token=f"token-{provider}", - user_id=user.id, - )) - await session.flush() - - conns = (await session.execute( - select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider) - )).scalars().all() - - assert len(conns) == 3 - assert [c.provider for c in conns] == ["azure-ad", "google", "okta"] diff --git a/tests/test_magic_link.py b/tests/test_magic_link.py index 0975c54..1d97eef 100644 --- a/tests/test_magic_link.py +++ b/tests/test_magic_link.py @@ -1,34 +1,6 @@ -"""Tests for magic link authentication flow.""" - -from datetime import timedelta +"""Tests for magic link authentication — token subject validation.""" from wiregui.auth.jwt import create_access_token, decode_access_token -from wiregui.auth.passwords import hash_password -from wiregui.models.user import User - - -def test_magic_link_token_creation(): - """Magic link token should be a valid JWT with short expiry.""" - token = create_access_token( - user_id="user-123", - role="unprivileged", - expires_delta=timedelta(minutes=15), - ) - payload = decode_access_token(token) - assert payload is not None - assert payload["sub"] == "user-123" - assert payload["role"] == "unprivileged" - - -def test_magic_link_token_expired(): - """Expired magic link token should be rejected.""" - token = create_access_token( - user_id="user-123", - role="admin", - expires_delta=timedelta(minutes=-1), # Already expired - ) - payload = decode_access_token(token) - assert payload is None def test_magic_link_token_wrong_user(): @@ -37,22 +9,3 @@ def test_magic_link_token_wrong_user(): payload = decode_access_token(token) assert payload["sub"] == "user-A" # Caller is responsible for checking sub matches the URL user_id - - -async def test_magic_link_disabled_user_rejected(session): - """Disabled users should not be able to use magic links.""" - from wiregui.utils.time import utcnow - - user = User( - email="disabled-magic@example.com", - password_hash=hash_password("pw"), - disabled_at=utcnow(), - ) - session.add(user) - await session.flush() - - # The token would be valid but the page handler checks disabled_at - token = create_access_token(user_id=str(user.id), role="unprivileged") - payload = decode_access_token(token) - assert payload is not None # Token itself is valid - assert user.disabled_at is not None # But user is disabled — handler would reject diff --git a/tests/test_mfa.py b/tests/test_mfa.py index 48b4eee..0028f59 100644 --- a/tests/test_mfa.py +++ b/tests/test_mfa.py @@ -1,4 +1,4 @@ -"""Tests for TOTP MFA functionality.""" +"""Tests for TOTP MFA — URI format, edge cases, QR generation, DB relationships.""" import pyotp @@ -12,22 +12,7 @@ from wiregui.models.mfa_method import MFAMethod from wiregui.models.user import User -# --- TOTP secret generation --- - - -def test_generate_secret(): - secret = generate_totp_secret() - assert len(secret) == 32 # base32 encoded - assert secret.isalpha() or any(c.isdigit() for c in secret) - - -def test_generate_secret_unique(): - s1 = generate_totp_secret() - s2 = generate_totp_secret() - assert s1 != s2 - - -# --- TOTP URI --- +# --- TOTP URI format --- def test_get_totp_uri(): @@ -43,19 +28,7 @@ def test_get_totp_uri_custom_issuer(): assert "issuer=MyVPN" in uri -# --- TOTP verification --- - - -def test_verify_valid_code(): - secret = generate_totp_secret() - totp = pyotp.TOTP(secret) - code = totp.now() - assert verify_totp_code(secret, code) is True - - -def test_verify_invalid_code(): - secret = generate_totp_secret() - assert verify_totp_code(secret, "000000") is False +# --- TOTP verification edge cases --- def test_verify_wrong_secret(): @@ -80,34 +53,7 @@ def test_generate_qr_svg(): assert "" in svg -# --- MFA method model integration --- - - -async def test_create_totp_method(session): - user = User(email="mfa-test@example.com") - session.add(user) - await session.flush() - - secret = generate_totp_secret() - method = MFAMethod( - name="My Phone", - type="totp", - payload={"secret": secret}, - user_id=user.id, - ) - session.add(method) - await session.flush() - - from sqlmodel import select - fetched = (await session.execute( - select(MFAMethod).where(MFAMethod.user_id == user.id) - )).scalar_one() - - assert fetched.name == "My Phone" - assert fetched.type == "totp" - stored_secret = fetched.payload["secret"] - code = pyotp.TOTP(stored_secret).now() - assert verify_totp_code(stored_secret, code) is True +# --- MFA method DB relationships --- async def test_user_multiple_mfa_methods(session): diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index ffeaa67..0000000 --- a/tests/test_models.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Tests for SQLModel table definitions.""" - -import pytest # noqa: F401 — needed for pytest.raises -from sqlmodel import select - -from wiregui.models.api_token import ApiToken -from wiregui.models.configuration import Configuration -from wiregui.models.connectivity_check import ConnectivityCheck -from wiregui.models.device import Device -from wiregui.models.mfa_method import MFAMethod -from wiregui.models.oidc_connection import OIDCConnection -from wiregui.models.rule import Rule -from wiregui.models.user import User - - -async def test_create_user(session): - user = User(email="alice@example.com", role="admin") - session.add(user) - await session.flush() - - result = await session.execute(select(User).where(User.email == "alice@example.com")) - fetched = result.scalar_one() - assert fetched.id == user.id - assert fetched.role == "admin" - assert fetched.disabled_at is None - - -async def test_create_device_with_user(session): - user = User(email="bob@example.com") - session.add(user) - await session.flush() - - device = Device( - name="laptop", - public_key="pk-test-device-001", - user_id=user.id, - ) - session.add(device) - await session.flush() - - result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001")) - fetched = result.scalar_one() - assert fetched.name == "laptop" - assert fetched.user_id == user.id - assert fetched.use_default_dns is True - assert fetched.use_default_allowed_ips is True - assert fetched.rx_bytes is None - - -async def test_device_unique_public_key(session): - user = User(email="carol@example.com") - session.add(user) - await session.flush() - - d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id) - session.add(d1) - await session.flush() - - d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id) - session.add(d2) - with pytest.raises(Exception): # IntegrityError - await session.flush() - - -async def test_create_rule(session): - user = User(email="dave@example.com") - session.add(user) - await session.flush() - - rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id) - session.add(rule) - await session.flush() - - result = await session.execute(select(Rule).where(Rule.user_id == user.id)) - fetched = result.scalar_one() - assert fetched.action == "accept" - assert fetched.destination == "10.0.0.0/8" - assert fetched.port_type is None - assert fetched.port_range is None - - -async def test_create_rule_with_port(session): - rule = Rule( - action="drop", - destination="192.168.0.0/16", - port_type="tcp", - port_range="80-443", - ) - session.add(rule) - await session.flush() - - fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one() - assert fetched.port_type == "tcp" - assert fetched.port_range == "80-443" - assert fetched.user_id is None # global rule - - -async def test_create_mfa_method(session): - user = User(email="eve@example.com") - session.add(user) - await session.flush() - - mfa = MFAMethod( - name="My Authenticator", - type="totp", - payload={"secret": "JBSWY3DPEHPK3PXP"}, - user_id=user.id, - ) - session.add(mfa) - await session.flush() - - fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one() - assert fetched.type == "totp" - assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP" - - -async def test_create_oidc_connection(session): - user = User(email="frank@example.com") - session.add(user) - await session.flush() - - conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id) - session.add(conn) - await session.flush() - - fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one() - assert fetched.provider == "google" - assert fetched.refresh_token == "tok_abc" - - -async def test_create_api_token(session): - user = User(email="grace@example.com") - session.add(user) - await session.flush() - - token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id) - session.add(token) - await session.flush() - - fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one() - assert fetched.token_hash == "sha256_fake_hash" - assert fetched.expires_at is None - - -async def test_create_connectivity_check(session): - check = ConnectivityCheck(url="https://example.com", response_code=200) - session.add(check) - await session.flush() - - fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one() - assert fetched.response_code == 200 - - -async def test_configuration_defaults(session): - config = Configuration() - session.add(config) - await session.flush() - - fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one() - assert fetched.allow_unprivileged_device_management is True - assert fetched.local_auth_enabled is True - assert fetched.default_client_mtu == 1280 - assert fetched.default_client_persistent_keepalive == 25 - assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"] - assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"] - assert fetched.vpn_session_duration == 0 - assert fetched.openid_connect_providers == [] - assert fetched.saml_identity_providers == [] diff --git a/tests/test_notifications.py b/tests/test_notifications.py deleted file mode 100644 index 2b764a3..0000000 --- a/tests/test_notifications.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Tests for the notification service.""" - -from wiregui.services import notifications - - -def setup_function(): - """Clear notifications before each test.""" - notifications.clear_all() - - -def test_add_notification(): - n = notifications.add("info", "Test message") - assert n.severity == "info" - assert n.message == "Test message" - assert n.user is None - assert n.id is not None - assert n.timestamp is not None - - -def test_add_notification_with_user(): - n = notifications.add("error", "Something broke", user="admin@example.com") - assert n.user == "admin@example.com" - assert n.severity == "error" - - -def test_current_returns_newest_first(): - notifications.add("info", "First") - notifications.add("warning", "Second") - notifications.add("error", "Third") - - current = notifications.current() - assert len(current) == 3 - assert current[0].message == "Third" - assert current[1].message == "Second" - assert current[2].message == "First" - - -def test_count(): - assert notifications.count() == 0 - notifications.add("info", "One") - notifications.add("info", "Two") - assert notifications.count() == 2 - - -def test_clear_specific(): - n1 = notifications.add("info", "Keep this") - n2 = notifications.add("error", "Remove this") - - notifications.clear(n2.id) - current = notifications.current() - assert len(current) == 1 - assert current[0].id == n1.id - - -def test_clear_nonexistent_id_is_noop(): - notifications.add("info", "Test") - notifications.clear("nonexistent-id") - assert notifications.count() == 1 - - -def test_clear_all(): - notifications.add("info", "One") - notifications.add("info", "Two") - notifications.add("info", "Three") - assert notifications.count() == 3 - - notifications.clear_all() - assert notifications.count() == 0 - assert notifications.current() == [] - - -def test_to_dict(): - n = notifications.add("warning", "Test dict", user="someone@example.com") - d = n.to_dict() - assert d["severity"] == "warning" - assert d["message"] == "Test dict" - assert d["user"] == "someone@example.com" - assert "id" in d - assert "timestamp" in d - - -def test_max_notifications(): - """Deque should cap at MAX_NOTIFICATIONS.""" - for i in range(notifications.MAX_NOTIFICATIONS + 10): - notifications.add("info", f"Notification {i}") - - assert notifications.count() == notifications.MAX_NOTIFICATIONS - # Newest should be the last one added - assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}" diff --git a/tests/test_server_key.py b/tests/test_server_key.py index b325d45..52e2e32 100644 --- a/tests/test_server_key.py +++ b/tests/test_server_key.py @@ -2,62 +2,59 @@ import pytest -from wiregui.db import async_session from wiregui.models.configuration import Configuration from wiregui.utils.server_key import get_server_public_key -from sqlmodel import select -@pytest.fixture(autouse=True) -async def _snapshot_config(): - """Snapshot and restore server_public_key around each test.""" - async with async_session() as session: - c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() - orig = c.server_public_key if c else None - cid = c.id if c else None - - yield - - if cid: - async with async_session() as session: - c = await session.get(Configuration, cid) - if c: - c.server_public_key = orig - session.add(c) - await session.commit() - - -async def test_get_server_public_key_returns_key(): +async def test_get_server_public_key_returns_key(session, monkeypatch): """Returns the public key when configured.""" - async with async_session() as session: - c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() - c.server_public_key = "TestServerPubKey123456789012345678901234w=" - session.add(c) - await session.commit() + from contextlib import asynccontextmanager + + @asynccontextmanager + async def mock_session(): + yield session + + monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session) + + c = Configuration(server_public_key="TestServerPubKey123456789012345678901234w=") + session.add(c) + await session.flush() result = await get_server_public_key() assert result == "TestServerPubKey123456789012345678901234w=" -async def test_get_server_public_key_raises_when_missing(): +async def test_get_server_public_key_raises_when_missing(session, monkeypatch): """Raises RuntimeError when server_public_key is None.""" - async with async_session() as session: - c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() - c.server_public_key = None - session.add(c) - await session.commit() + from contextlib import asynccontextmanager + + @asynccontextmanager + async def mock_session(): + yield session + + monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session) + + c = Configuration(server_public_key=None) + session.add(c) + await session.flush() with pytest.raises(RuntimeError, match="not configured"): await get_server_public_key() -async def test_get_server_public_key_raises_when_empty_string(): +async def test_get_server_public_key_raises_when_empty_string(session, monkeypatch): """Raises RuntimeError when server_public_key is empty string.""" - async with async_session() as session: - c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() - c.server_public_key = "" - session.add(c) - await session.commit() + from contextlib import asynccontextmanager + + @asynccontextmanager + async def mock_session(): + yield session + + monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session) + + c = Configuration(server_public_key="") + session.add(c) + await session.flush() with pytest.raises(RuntimeError, match="not configured"): - await get_server_public_key() \ No newline at end of file + await get_server_public_key() diff --git a/tests/test_services.py b/tests/test_services.py index 1c32f0e..f74a4e9 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1,4 +1,4 @@ -"""Tests for services — WireGuard and events.""" +"""Tests for services — WireGuard event error handling and rule events.""" from unittest.mock import AsyncMock, patch @@ -20,53 +20,6 @@ def _make_device(**kwargs) -> Device: return Device(**defaults) -# --- Events (with WG enabled) --- - - -@patch("wiregui.services.events.get_settings") -@patch("wiregui.services.events.firewall") -@patch("wiregui.services.events.wireguard") -async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings): - mock_settings.return_value.wg_enabled = True - mock_wg.add_peer = AsyncMock() - mock_fw.add_user_chain = AsyncMock() - mock_fw.add_device_jump_rule = AsyncMock() - - device = _make_device() - await on_device_created(device) - - mock_wg.add_peer.assert_awaited_once_with( - public_key="pk-test", - allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"], - preshared_key="psk-test", - ) - mock_fw.add_device_jump_rule.assert_awaited_once() - - -@patch("wiregui.services.events.get_settings") -@patch("wiregui.services.events.wireguard") -async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings): - mock_settings.return_value.wg_enabled = True - mock_wg.remove_peer = AsyncMock() - - device = _make_device() - await on_device_deleted(device) - - mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test") - - -@patch("wiregui.services.events.get_settings") -@patch("wiregui.services.events.wireguard") -async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings): - mock_settings.return_value.wg_enabled = True - mock_wg.add_peer = AsyncMock() - - device = _make_device() - await on_device_updated(device) - - mock_wg.add_peer.assert_awaited_once() - - # --- Events (WG disabled) ---