fix: remove unit tests redundant with e2e, fix test DB isolation
Some checks failed
Dev / test (push) Failing after 7m41s
Dev / docker (push) Has been skipped

Remove 7 test files fully covered by e2e tests (admin, account, models,
API routes, integration MFA/OIDC, notifications). Trim 5 more files to
keep only edge cases not reachable via e2e.

Fix conftest to replace wiregui.db engine/session at import time so all
code uses the test database. Use session-scoped tables with per-test
savepoint isolation to prevent data leaking between tests.
This commit is contained in:
Stefano Bertelli 2026-03-31 21:27:46 -05:00
parent a9f62d5caf
commit a012635dff
15 changed files with 153 additions and 2006 deletions

View file

@ -1,4 +1,12 @@
"""Shared test fixtures — async DB session using a test database."""
"""Shared test fixtures — async DB session using a test database.
The module-level code below replaces ``wiregui.db.engine`` and
``wiregui.db.async_session`` with instances pointing at the **test** database
*before* any test (or other module) can grab a reference to the originals.
This means every ``from wiregui.db import async_session`` whether in test
files or in production code like ``wiregui.utils.server_key`` will get the
test-database session maker.
"""
import os
from collections.abc import AsyncGenerator
@ -8,6 +16,7 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel
import wiregui.db as _db_module
from wiregui.config import get_settings
# All models must be imported so SQLModel.metadata knows about them
@ -51,19 +60,41 @@ def _ensure_test_db_sync():
_ensure_test_db_sync()
# ---------------------------------------------------------------------------
# Replace the production engine/session in wiregui.db at import time so that
# every module that does ``from wiregui.db import async_session`` picks up the
# test database. This MUST happen before test modules are collected (which
# triggers their top-level imports).
# ---------------------------------------------------------------------------
_test_engine = create_async_engine(TEST_DATABASE_URL)
_test_session_factory = async_sessionmaker(_test_engine, expire_on_commit=False)
_db_module.engine = _test_engine
_db_module.async_session = _test_session_factory
@pytest_asyncio.fixture(scope="session", autouse=True)
async def _setup_test_tables():
"""Create all tables once at the start of the test session, drop at end."""
async with _test_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
yield
async with _test_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
await _test_engine.dispose()
@pytest_asyncio.fixture
async def session() -> AsyncGenerator[AsyncSession]:
"""Fresh engine + session per test, with table setup/teardown."""
engine = create_async_engine(TEST_DATABASE_URL)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
"""Per-test session with transaction isolation.
factory = async_sessionmaker(engine, expire_on_commit=False)
async with factory() as sess:
The session is bound to a connection-level transaction that is always
rolled back at teardown. When tested code calls ``session.commit()``,
SQLAlchemy only releases a SAVEPOINT the outer transaction is never
committed, so no test data persists between tests.
"""
async with _test_engine.connect() as conn:
txn = await conn.begin()
sess = AsyncSession(bind=conn, expire_on_commit=False, join_transaction_mode="create_savepoint")
yield sess
await sess.rollback()
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
await engine.dispose()
await sess.close()
await txn.rollback()

View file

@ -1,161 +0,0 @@
"""Tests for account functionality — password changes, API tokens, OIDC connections."""
import hashlib
from datetime import timedelta
from sqlmodel import func, select
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.models.api_token import ApiToken
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- Password change ---
async def test_password_change_flow(session):
"""Simulate the password change flow: verify old, set new."""
user = User(email="pw-change@example.com", password_hash=hash_password("old-password"))
session.add(user)
await session.flush()
# Verify old password
assert verify_password("old-password", user.password_hash) is True
# Change password
user.password_hash = hash_password("new-password")
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert verify_password("new-password", fetched.password_hash) is True
assert verify_password("old-password", fetched.password_hash) is False
async def test_password_change_wrong_current(session):
"""Wrong current password should not allow change."""
user = User(email="pw-wrong@example.com", password_hash=hash_password("correct"))
session.add(user)
await session.flush()
# Simulate check
assert verify_password("wrong", user.password_hash) is False
# --- API token management ---
async def test_create_multiple_tokens(session):
user = User(email="multi-token@example.com")
session.add(user)
await session.flush()
for _ in range(3):
_, token_hash = generate_api_token()
session.add(ApiToken(token_hash=token_hash, user_id=user.id))
await session.flush()
count = (await session.execute(
select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id)
)).scalar()
assert count == 3
async def test_token_with_expiry(session):
user = User(email="expiry-token@example.com")
session.add(user)
await session.flush()
_, token_hash = generate_api_token()
expires = utcnow() + timedelta(days=30)
token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id)
session.add(token)
await session.flush()
fetched = await session.get(ApiToken, token.id)
assert fetched.expires_at is not None
assert fetched.expires_at > utcnow()
async def test_delete_token(session):
user = User(email="del-token@example.com")
session.add(user)
await session.flush()
_, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
await session.delete(token)
await session.flush()
assert await session.get(ApiToken, token.id) is None
# --- OIDC connections ---
async def test_oidc_connection_create(session):
user = User(email="oidc-conn@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="google",
refresh_token="refresh-tok-123",
refresh_response={"access_token": "at", "token_type": "Bearer"},
refreshed_at=utcnow(),
user_id=user.id,
)
session.add(conn)
await session.flush()
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.provider == "google"
assert fetched.refresh_token == "refresh-tok-123"
assert fetched.refresh_response["access_token"] == "at"
async def test_multiple_oidc_providers(session):
user = User(email="multi-oidc@example.com")
session.add(user)
await session.flush()
for provider in ["google", "okta", "azure"]:
conn = OIDCConnection(provider=provider, user_id=user.id)
session.add(conn)
await session.flush()
count = (await session.execute(
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar()
assert count == 3
async def test_oidc_connection_update_refresh_token(session):
user = User(email="oidc-refresh@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="google",
refresh_token="old-token",
user_id=user.id,
)
session.add(conn)
await session.flush()
conn.refresh_token = "new-token"
conn.refreshed_at = utcnow()
session.add(conn)
await session.flush()
fetched = await session.get(OIDCConnection, conn.id)
assert fetched.refresh_token == "new-token"
assert fetched.refreshed_at is not None

View file

@ -1,283 +0,0 @@
"""Tests for admin functionality — user management, configuration, cascading deletes."""
import pytest
from sqlmodel import func, select
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- User CRUD ---
async def test_create_user_with_role(session):
user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin")
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.role == "admin"
assert verify_password("secret", fetched.password_hash)
async def test_update_user_email(session):
user = User(email="old@test.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
user.email = "new@test.com"
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.email == "new@test.com"
async def test_disable_user(session):
user = User(email="active@test.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
assert user.disabled_at is None
user.disabled_at = utcnow()
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.disabled_at is not None
async def test_promote_demote_user(session):
user = User(email="user@test.com", role="unprivileged")
session.add(user)
await session.flush()
assert user.role == "unprivileged"
user.role = "admin"
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.role == "admin"
user.role = "unprivileged"
session.add(user)
await session.flush()
assert (await session.get(User, user.id)).role == "unprivileged"
# --- Cascading delete (manual, as we do it in the admin page) ---
async def test_delete_user_cascades_devices(session):
user = User(email="cascade@test.com")
session.add(user)
await session.flush()
d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id)
d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id)
session.add_all([d1, d2])
await session.flush()
# Manually delete devices then user (matching admin page behavior)
devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all()
for d in devices:
await session.delete(d)
await session.delete(user)
await session.flush()
assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0
assert await session.get(User, user.id) is None
async def test_delete_user_cascades_rules(session):
user = User(email="rule-cascade@test.com")
session.add(user)
await session.flush()
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
session.add(rule)
await session.flush()
# Delete rules then user
rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all()
for r in rules:
await session.delete(r)
await session.delete(user)
await session.flush()
assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0
# --- Configuration singleton ---
async def test_configuration_create_and_update(session):
config = Configuration()
session.add(config)
await session.flush()
assert config.default_client_mtu == 1280
assert config.local_auth_enabled is True
config.default_client_mtu = 1400
config.local_auth_enabled = False
config.vpn_session_duration = 3600
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.default_client_mtu == 1400
assert fetched.local_auth_enabled is False
assert fetched.vpn_session_duration == 3600
async def test_configuration_oidc_providers(session):
config = Configuration()
session.add(config)
await session.flush()
assert config.openid_connect_providers == []
providers = [
{
"id": "google",
"label": "Sign in with Google",
"scope": "openid email profile",
"response_type": "code",
"client_id": "google-client-id",
"client_secret": "google-secret",
"discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration",
"auto_create_users": True,
},
{
"id": "okta",
"label": "Okta SSO",
"scope": "openid email profile",
"response_type": "code",
"client_id": "okta-client-id",
"client_secret": "okta-secret",
"discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration",
"auto_create_users": False,
},
]
config.openid_connect_providers = providers
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert len(fetched.openid_connect_providers) == 2
assert fetched.openid_connect_providers[0]["id"] == "google"
assert fetched.openid_connect_providers[1]["auto_create_users"] is False
async def test_configuration_update_client_defaults(session):
config = Configuration()
session.add(config)
await session.flush()
config.default_client_endpoint = "vpn.example.com"
config.default_client_dns = ["8.8.8.8", "8.8.4.4"]
config.default_client_allowed_ips = ["10.0.0.0/8"]
config.default_client_persistent_keepalive = 30
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.default_client_endpoint == "vpn.example.com"
assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"]
assert fetched.default_client_allowed_ips == ["10.0.0.0/8"]
assert fetched.default_client_persistent_keepalive == 30
async def test_configuration_security_toggles(session):
config = Configuration()
session.add(config)
await session.flush()
config.allow_unprivileged_device_management = False
config.allow_unprivileged_device_configuration = False
config.disable_vpn_on_oidc_error = True
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.allow_unprivileged_device_management is False
assert fetched.allow_unprivileged_device_configuration is False
assert fetched.disable_vpn_on_oidc_error is True
# --- Device config overrides ---
async def test_device_with_custom_config(session):
user = User(email="config-user@test.com")
session.add(user)
await session.flush()
device = Device(
name="custom-config",
public_key="pk-custom-config",
user_id=user.id,
use_default_dns=False,
use_default_endpoint=False,
use_default_mtu=False,
use_default_persistent_keepalive=False,
use_default_allowed_ips=False,
dns=["8.8.8.8"],
endpoint="custom-vpn.example.com",
mtu=1400,
persistent_keepalive=15,
allowed_ips=["10.0.0.0/8", "172.16.0.0/12"],
)
session.add(device)
await session.flush()
fetched = await session.get(Device, device.id)
assert fetched.use_default_dns is False
assert fetched.dns == ["8.8.8.8"]
assert fetched.endpoint == "custom-vpn.example.com"
assert fetched.mtu == 1400
assert fetched.persistent_keepalive == 15
assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"]
async def test_device_default_flags_are_true(session):
user = User(email="defaults@test.com")
session.add(user)
await session.flush()
device = Device(name="defaults", public_key="pk-defaults", user_id=user.id)
session.add(device)
await session.flush()
fetched = await session.get(Device, device.id)
assert fetched.use_default_allowed_ips is True
assert fetched.use_default_dns is True
assert fetched.use_default_endpoint is True
assert fetched.use_default_mtu is True
assert fetched.use_default_persistent_keepalive is True
# --- User device count ---
async def test_user_device_count_query(session):
user = User(email="count-user@test.com")
session.add(user)
await session.flush()
for i in range(3):
session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id))
await session.flush()
count = (await session.execute(
select(func.count()).select_from(Device).where(Device.user_id == user.id)
)).scalar()
assert count == 3

View file

@ -1,15 +1,12 @@
"""Tests for API dependency injection — Bearer token auth and admin guard."""
import hashlib
from datetime import timedelta
from uuid import uuid4
import pytest
from unittest.mock import AsyncMock, MagicMock
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.api_token import generate_api_token, resolve_bearer_token
from wiregui.auth.passwords import hash_password
from wiregui.db import async_session
from wiregui.models.api_token import ApiToken
from wiregui.models.user import User
from wiregui.utils.time import utcnow
@ -18,143 +15,80 @@ from wiregui.utils.time import utcnow
# ========== resolve_bearer_token ==========
async def test_resolve_valid_token():
async def test_resolve_valid_token(session):
"""Valid, non-expired token resolves to user."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.flush()
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() + timedelta(hours=1),
)
session.add(api_token)
await session.commit()
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1))
session.add(api_token)
await session.flush()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
assert resolved.email == "api-test@test.com"
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
assert resolved.email == "api-test@test.com"
async def test_resolve_expired_token():
async def test_resolve_expired_token(session):
"""Expired token returns None."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.flush()
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() - timedelta(hours=1), # already expired
)
session.add(api_token)
await session.commit()
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() - timedelta(hours=1))
session.add(api_token)
await session.flush()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
async def test_resolve_invalid_token():
async def test_resolve_invalid_token(session):
"""Nonexistent token returns None."""
from wiregui.auth.api_token import resolve_bearer_token
async with async_session() as session:
resolved = await resolve_bearer_token(session, "totally-bogus-token")
assert resolved is None
resolved = await resolve_bearer_token(session, "totally-bogus-token")
assert resolved is None
async def test_resolve_token_disabled_user():
async def test_resolve_token_disabled_user(session):
"""Token for disabled user returns None."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(
email="api-disabled@test.com", password_hash=hash_password("x"),
role="admin", disabled_at=utcnow(),
)
session.add(user)
await session.commit()
await session.refresh(user)
user = User(
email="api-disabled@test.com", password_hash=hash_password("x"),
role="admin", disabled_at=utcnow(),
)
session.add(user)
await session.flush()
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() + timedelta(hours=1),
)
session.add(api_token)
await session.commit()
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1))
session.add(api_token)
await session.flush()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
async def test_resolve_token_no_expiry():
async def test_resolve_token_no_expiry(session):
"""Token without expires_at (never expires) resolves successfully."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.flush()
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=None,
)
session.add(api_token)
await session.commit()
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=None)
session.add(api_token)
await session.flush()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
# ========== get_current_api_user (via FastAPI deps) ==========
@ -187,7 +121,7 @@ async def test_get_current_api_user_bad_scheme():
assert exc_info.value.status_code == 401
async def test_get_current_api_user_invalid_token():
async def test_get_current_api_user_invalid_token(session):
"""Valid Bearer scheme but bogus token raises 401."""
from fastapi import HTTPException
from wiregui.api.deps import get_current_api_user
@ -195,45 +129,31 @@ async def test_get_current_api_user_invalid_token():
request = MagicMock()
request.headers = {"Authorization": "Bearer bogus-token-value"}
async with async_session() as session:
with pytest.raises(HTTPException) as exc_info:
await get_current_api_user(request, session=session)
assert exc_info.value.status_code == 401
assert "Invalid" in exc_info.value.detail
with pytest.raises(HTTPException) as exc_info:
await get_current_api_user(request, session=session)
assert exc_info.value.status_code == 401
assert "Invalid" in exc_info.value.detail
async def test_get_current_api_user_valid_token():
async def test_get_current_api_user_valid_token(session):
"""Valid Bearer token resolves to user."""
from wiregui.api.deps import get_current_api_user
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.flush()
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() + timedelta(hours=1),
)
session.add(api_token)
await session.commit()
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1))
session.add(api_token)
await session.flush()
try:
request = MagicMock()
request.headers = {"Authorization": f"Bearer {plaintext}"}
request = MagicMock()
request.headers = {"Authorization": f"Bearer {plaintext}"}
async with async_session() as session:
resolved = await get_current_api_user(request, session=session)
assert resolved.id == user.id
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
resolved = await get_current_api_user(request, session=session)
assert resolved.id == user.id
# ========== require_admin ==========
@ -260,4 +180,4 @@ async def test_require_admin_rejects_unprivileged():
with pytest.raises(HTTPException) as exc_info:
await require_admin(user=regular_user)
assert exc_info.value.status_code == 403
assert "Admin" in exc_info.value.detail
assert "Admin" in exc_info.value.detail

View file

@ -1,325 +0,0 @@
"""Tests for REST API routes via httpx AsyncClient against the FastAPI app."""
import hashlib
from uuid import UUID, uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlmodel import select
from wiregui.api.deps import get_current_api_user, get_db, require_admin
from wiregui.api.v0 import router as api_router
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.models.user import User
def _build_app(session, admin_user=None, regular_user=None):
"""Build a test FastAPI app with overridden dependencies."""
test_app = FastAPI()
test_app.include_router(api_router, prefix="/api")
async def override_get_db():
yield session
test_app.dependency_overrides[get_db] = override_get_db
if admin_user:
test_app.dependency_overrides[get_current_api_user] = lambda: admin_user
test_app.dependency_overrides[require_admin] = lambda: admin_user
return test_app
async def _make_admin(session) -> User:
user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin")
session.add(user)
await session.flush()
return user
async def _make_user(session, email="api-user@test.com") -> User:
user = User(email=email, password_hash=hash_password("pw"), role="unprivileged")
session.add(user)
await session.flush()
return user
# ========== Users API ==========
async def test_list_users(session):
admin = await _make_admin(session)
await _make_user(session, "user1@test.com")
await _make_user(session, "user2@test.com")
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/users/")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 3 # admin + 2 users
async def test_get_user(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/users/{admin.id}")
assert resp.status_code == 200
assert resp.json()["email"] == "api-admin@test.com"
async def test_get_user_not_found(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/users/{uuid4()}")
assert resp.status_code == 404
async def test_create_user(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/api/v0/users/", json={
"email": "new-api-user@test.com",
"password": "secret123",
"role": "unprivileged",
})
assert resp.status_code == 201
data = resp.json()
assert data["email"] == "new-api-user@test.com"
assert data["role"] == "unprivileged"
assert "id" in data
async def test_update_user(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/users/{user.id}", json={
"role": "admin",
})
assert resp.status_code == 200
assert resp.json()["role"] == "admin"
async def test_update_user_password(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/users/{user.id}", json={
"password": "new-password-123",
})
assert resp.status_code == 200
from wiregui.auth.passwords import verify_password
refreshed = await session.get(User, user.id)
assert verify_password("new-password-123", refreshed.password_hash)
async def test_delete_user(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/users/{user.id}")
assert resp.status_code == 204
assert await session.get(User, user.id) is None
# ========== Devices API ==========
async def test_list_devices_admin_sees_all(session):
admin = await _make_admin(session)
user = await _make_user(session)
session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id))
session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id))
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/devices/")
assert resp.status_code == 200
assert len(resp.json()) >= 2
async def test_list_devices_user_sees_own(session):
admin = await _make_admin(session)
user = await _make_user(session, "own-devices@test.com")
session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id))
session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id))
await session.flush()
# Override to be the regular user
test_app = _build_app(session)
test_app.dependency_overrides[get_current_api_user] = lambda: user
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
resp = await client.get("/api/v0/devices/")
assert resp.status_code == 200
names = [d["name"] for d in resp.json()]
assert "mine" in names
assert "not-mine" not in names
async def test_get_device(session):
admin = await _make_admin(session)
device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5")
session.add(device)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/devices/{device.id}")
assert resp.status_code == 200
assert resp.json()["name"] == "detail"
assert resp.json()["ipv4"] == "10.0.0.5"
async def test_get_device_forbidden_for_other_user(session):
admin = await _make_admin(session)
user = await _make_user(session, "other-dev@test.com")
device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id)
session.add(device)
await session.flush()
test_app = _build_app(session)
test_app.dependency_overrides[get_current_api_user] = lambda: user
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/devices/{device.id}")
assert resp.status_code == 403
async def test_update_device(session):
admin = await _make_admin(session)
device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id)
session.add(device)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"})
assert resp.status_code == 200
assert resp.json()["name"] == "new-name"
async def test_delete_device(session):
admin = await _make_admin(session)
device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id)
session.add(device)
await session.flush()
did = device.id
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/devices/{did}")
assert resp.status_code == 204
assert await session.get(Device, did) is None
# ========== Rules API ==========
async def test_list_rules(session):
admin = await _make_admin(session)
session.add(Rule(action="accept", destination="10.0.0.0/8"))
session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id))
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/rules/")
assert resp.status_code == 200
assert len(resp.json()) >= 2
async def test_create_rule(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/api/v0/rules/", json={
"action": "accept",
"destination": "172.16.0.0/12",
"port_type": "tcp",
"port_range": "443",
})
assert resp.status_code == 201
data = resp.json()
assert data["action"] == "accept"
assert data["destination"] == "172.16.0.0/12"
assert data["port_type"] == "tcp"
assert data["port_range"] == "443"
async def test_update_rule(session):
admin = await _make_admin(session)
rule = Rule(action="accept", destination="10.0.0.0/8")
session.add(rule)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"})
assert resp.status_code == 200
assert resp.json()["action"] == "drop"
async def test_delete_rule(session):
admin = await _make_admin(session)
rule = Rule(action="drop", destination="0.0.0.0/0")
session.add(rule)
await session.flush()
rid = rule.id
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/rules/{rid}")
assert resp.status_code == 204
assert await session.get(Rule, rid) is None
# ========== Configuration API ==========
async def test_get_configuration_auto_creates(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/configuration/")
assert resp.status_code == 200
data = resp.json()
assert data["default_client_mtu"] == 1280
assert data["local_auth_enabled"] is True
async def test_update_configuration(session):
admin = await _make_admin(session)
# Pre-create config
config = Configuration()
session.add(config)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put("/api/v0/configuration/", json={
"default_client_mtu": 1400,
"vpn_session_duration": 3600,
"default_client_dns": ["8.8.8.8"],
})
assert resp.status_code == 200
data = resp.json()
assert data["default_client_mtu"] == 1400
assert data["vpn_session_duration"] == 3600
assert data["default_client_dns"] == ["8.8.8.8"]

View file

@ -1,4 +1,4 @@
"""Tests for authentication modules."""
"""Tests for authentication modules — seed logic and JWT edge cases."""
from sqlmodel import select
@ -8,17 +8,7 @@ from wiregui.auth.seed import seed_admin
from wiregui.models.user import User
# --- Password hashing ---
def test_hash_and_verify():
hashed = hash_password("my-secret")
assert verify_password("my-secret", hashed) is True
def test_verify_wrong_password():
hashed = hash_password("correct")
assert verify_password("wrong", hashed) is False
# --- Password hashing (format guard) ---
def test_hash_is_not_plaintext():
@ -27,16 +17,7 @@ def test_hash_is_not_plaintext():
assert hashed.startswith("$2b$")
# --- JWT ---
def test_create_and_decode_token():
token = create_access_token(user_id="user-123", role="admin")
payload = decode_access_token(token)
assert payload is not None
assert payload["sub"] == "user-123"
assert payload["role"] == "admin"
assert "exp" in payload
# --- JWT edge cases ---
def test_decode_invalid_token():
@ -54,8 +35,6 @@ def test_decode_tampered_token():
async def test_seed_admin_creates_user(session, monkeypatch):
"""seed_admin should create an admin when no users exist."""
# Patch async_session to use our test session
from unittest.mock import AsyncMock
from contextlib import asynccontextmanager
@asynccontextmanager

View file

@ -1,65 +1,9 @@
"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases."""
"""Extended auth tests — OIDC registration, WebAuthn options, rule event handlers."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from wiregui.auth.passwords import hash_password
from wiregui.auth.session import authenticate_user
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# ========== Session / authenticate_user edge cases ==========
async def test_authenticate_user_no_password_hash(session, monkeypatch):
"""Users without a password (OIDC-only) should not authenticate via password."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="no-pw@test.com", password_hash=None)
session.add(user)
await session.flush()
result = await authenticate_user("no-pw@test.com", "anything")
assert result is None
async def test_authenticate_user_disabled(session, monkeypatch):
"""Disabled users should not authenticate."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow())
session.add(user)
await session.flush()
result = await authenticate_user("disabled-auth@test.com", "pw")
assert result is None
async def test_authenticate_user_nonexistent(session, monkeypatch):
"""Nonexistent email should return None."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
result = await authenticate_user("ghost@nowhere.com", "pw")
assert result is None
# ========== OIDC provider registration ==========
@ -163,13 +107,11 @@ async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings):
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_updated
# Need to mock the DB call inside _rebuild_user_chain
with patch("wiregui.services.events.async_session") as mock_session_factory:
mock_session = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
# Mock the select results
mock_rules_result = MagicMock()
mock_rules_result.scalars.return_value.all.return_value = []
mock_devices_result = MagicMock()

View file

@ -1,239 +0,0 @@
"""Integration tests for MFA — full registration and authentication flows through the database."""
import pyotp
from sqlmodel import func, select
from wiregui.auth.mfa import generate_totp_secret, verify_totp_code
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.auth.session import authenticate_user
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
from wiregui.utils.time import utcnow
async def test_full_totp_registration_flow(session, monkeypatch):
"""End-to-end: create user → generate secret → verify code → store method → re-verify from DB."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
# Create user with password
user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123"))
session.add(user)
await session.flush()
# Step 1: Generate TOTP secret (happens in account page)
secret = generate_totp_secret()
# Step 2: User scans QR, enters code from their authenticator
totp = pyotp.TOTP(secret)
code = totp.now()
# Step 3: Verify the code is correct before saving
assert verify_totp_code(secret, code) is True
# Step 4: Save the MFA method to DB
method = MFAMethod(
name="My Authenticator",
type="totp",
payload={"secret": secret},
user_id=user.id,
)
session.add(method)
await session.flush()
# Step 5: Simulate future login — load method from DB and verify a fresh code
fetched_methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalars().all()
assert len(fetched_methods) == 1
stored_secret = fetched_methods[0].payload["secret"]
fresh_code = pyotp.TOTP(stored_secret).now()
assert verify_totp_code(stored_secret, fresh_code) is True
async def test_mfa_blocks_login_without_code(session, monkeypatch):
"""User with MFA should not be fully authenticated without completing MFA challenge."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
# Create user with MFA
user = User(email="mfa-block@example.com", password_hash=hash_password("password1"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# Password auth succeeds
authed_user = await authenticate_user("mfa-block@example.com", "password1")
assert authed_user is not None
# But MFA methods exist — login page would redirect to /mfa instead of completing login
mfa_methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == authed_user.id)
)).scalars().all()
assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa
async def test_mfa_wrong_code_rejected(session):
"""Wrong TOTP code should be rejected even if method is valid."""
user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# Load from DB and try wrong code
fetched = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar_one()
assert verify_totp_code(fetched.payload["secret"], "000000") is False
assert verify_totp_code(fetched.payload["secret"], "123456") is False
async def test_mfa_multiple_methods_any_valid_code_works(session):
"""If user has multiple TOTP methods, a valid code from any should work."""
user = User(email="mfa-multi@example.com")
session.add(user)
await session.flush()
secret1 = generate_totp_secret()
secret2 = generate_totp_secret()
session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id))
session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id))
await session.flush()
methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalars().all()
# Code from method 1 should verify against method 1's secret
code1 = pyotp.TOTP(secret1).now()
verified = False
for m in methods:
if verify_totp_code(m.payload["secret"], code1):
verified = True
break
assert verified is True
# Code from method 2 should also work
code2 = pyotp.TOTP(secret2).now()
verified2 = False
for m in methods:
if verify_totp_code(m.payload["secret"], code2):
verified2 = True
break
assert verified2 is True
async def test_mfa_method_last_used_tracking(session):
"""Verifying MFA should update last_used_at timestamp."""
user = User(email="mfa-tracking@example.com")
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
assert method.last_used_at is None
# Simulate successful verification and update
code = pyotp.TOTP(secret).now()
assert verify_totp_code(secret, code) is True
method.last_used_at = utcnow()
session.add(method)
await session.flush()
fetched = await session.get(MFAMethod, method.id)
assert fetched.last_used_at is not None
async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch):
"""After removing all MFA methods, user should not be redirected to MFA challenge."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="mfa-remove@example.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# MFA exists
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 1
# Delete it
await session.delete(method)
await session.flush()
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 0
# Password auth still works
authed = await authenticate_user("mfa-remove@example.com", "pw")
assert authed is not None
# No MFA methods — login flow would skip MFA challenge
mfa_check = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == authed.id)
)).scalars().all()
assert len(mfa_check) == 0
async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch):
"""Disabled user should be rejected at password stage, never reaching MFA."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(
email="mfa-disabled@example.com",
password_hash=hash_password("pw"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
secret = generate_totp_secret()
session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id))
await session.flush()
# Password auth rejects disabled user before MFA is ever checked
result = await authenticate_user("mfa-disabled@example.com", "pw")
assert result is None

View file

@ -1,309 +0,0 @@
"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""
import json
import time
from unittest.mock import patch
from uuid import uuid4
import respx
from httpx import Response
from jose import jwt
from sqlmodel import select
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
from wiregui.config import get_settings
from wiregui.models.configuration import Configuration
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
# --- Helper to create a fake OIDC provider config in the DB ---
async def _setup_oidc_config(session) -> Configuration:
"""Insert a Configuration with a test OIDC provider."""
config = Configuration(
openid_connect_providers=[
{
"id": "test-idp",
"label": "Test IdP",
"scope": "openid email profile",
"response_type": "code",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
"auto_create_users": True,
}
],
)
session.add(config)
await session.commit()
return config
def _mock_discovery():
"""Mock OIDC discovery document response."""
return {
"issuer": "https://idp.example.com",
"authorization_endpoint": "https://idp.example.com/authorize",
"token_endpoint": "https://idp.example.com/token",
"userinfo_endpoint": "https://idp.example.com/userinfo",
"jwks_uri": "https://idp.example.com/.well-known/jwks.json",
}
def _mock_token_response(email: str = "oidc-user@example.com"):
"""Mock OIDC token endpoint response with ID token."""
now = int(time.time())
id_token_payload = {
"iss": "https://idp.example.com",
"sub": "oidc-subject-123",
"aud": "test-client-id",
"email": email,
"name": "OIDC User",
"iat": now,
"exp": now + 3600,
"nonce": "test-nonce",
}
# Sign with a simple secret (in real life this would be RSA)
id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
return {
"access_token": "mock-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "mock-refresh-token",
"id_token": id_token,
}
# --- Provider config loading ---
async def test_load_providers_from_config(session, monkeypatch):
"""Providers should be loaded from the Configuration table."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
await _setup_oidc_config(session)
providers = await load_providers()
assert len(providers) == 1
assert providers[0]["id"] == "test-idp"
assert providers[0]["client_id"] == "test-client-id"
async def test_load_providers_empty_when_no_config(session, monkeypatch):
"""Should return empty list when no Configuration exists."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
providers = await load_providers()
assert providers == []
async def test_get_provider_config_by_id(session, monkeypatch):
"""Should find a specific provider by ID."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
await _setup_oidc_config(session)
config = await get_provider_config("test-idp")
assert config is not None
assert config["label"] == "Test IdP"
config_missing = await get_provider_config("nonexistent")
assert config_missing is None
# --- OIDC connection storage ---
async def test_oidc_connection_created_on_login(session):
"""Simulates what the callback route does: create user + OIDC connection."""
user = User(email="oidc-new@example.com", role="unprivileged")
session.add(user)
await session.flush()
token_data = _mock_token_response("oidc-new@example.com")
conn = OIDCConnection(
provider="test-idp",
refresh_token=token_data["refresh_token"],
refresh_response=token_data,
user_id=user.id,
)
session.add(conn)
await session.flush()
# Verify it was stored
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.provider == "test-idp"
assert fetched.refresh_token == "mock-refresh-token"
assert fetched.refresh_response["access_token"] == "mock-access-token"
async def test_oidc_connection_updated_on_re_login(session):
"""Re-login should update the existing OIDC connection, not create a duplicate."""
user = User(email="oidc-relogin@example.com")
session.add(user)
await session.flush()
# First login
conn = OIDCConnection(
provider="test-idp",
refresh_token="old-refresh-token",
user_id=user.id,
)
session.add(conn)
await session.flush()
# Re-login — update existing connection (as the callback route does)
existing = (await session.execute(
select(OIDCConnection).where(
OIDCConnection.user_id == user.id,
OIDCConnection.provider == "test-idp",
)
)).scalar_one()
existing.refresh_token = "new-refresh-token"
from wiregui.utils.time import utcnow
existing.refreshed_at = utcnow()
session.add(existing)
await session.flush()
# Should still be one connection
from sqlmodel import func
count = (await session.execute(
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar()
assert count == 1
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.refresh_token == "new-refresh-token"
async def test_oidc_auto_create_user(session):
"""When auto_create_users is True, a new user should be created from OIDC email."""
email = "auto-created@example.com"
# Verify user doesn't exist
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
assert existing is None
# Simulate what callback does with auto_create
user = User(email=email, role="unprivileged")
session.add(user)
await session.flush()
from wiregui.utils.time import utcnow
user.last_signed_in_at = utcnow()
user.last_signed_in_method = "oidc:test-idp"
session.add(user)
await session.flush()
created = (await session.execute(select(User).where(User.email == email))).scalar_one()
assert created.role == "unprivileged"
assert created.last_signed_in_method == "oidc:test-idp"
async def test_oidc_disabled_user_rejected(session):
"""Disabled users should not be logged in via OIDC."""
from wiregui.utils.time import utcnow
user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
session.add(user)
await session.flush()
# The callback route checks disabled_at before creating session
assert user.disabled_at is not None # Would redirect to /login
async def test_oidc_user_without_auto_create_rejected(session):
"""When auto_create is False and user doesn't exist, login should fail."""
email = "no-auto-create@example.com"
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
assert existing is None
# The callback route checks auto_create_users from provider config
# With auto_create=False and no existing user, it would redirect to /login
# This verifies the precondition
# --- OIDC refresh token flow ---
async def test_oidc_refresh_stores_new_token(session):
"""Simulates a successful token refresh updating the connection."""
user = User(email="oidc-refresh-test@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="test-idp",
refresh_token="old-refresh",
user_id=user.id,
)
session.add(conn)
await session.flush()
# Simulate refresh result
new_token = {
"access_token": "new-access",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
conn.refresh_response = new_token
from wiregui.utils.time import utcnow
conn.refreshed_at = utcnow()
session.add(conn)
await session.flush()
fetched = await session.get(OIDCConnection, conn.id)
assert fetched.refresh_token == "new-refresh"
assert fetched.refresh_response["access_token"] == "new-access"
assert fetched.refreshed_at is not None
async def test_oidc_multiple_providers_per_user(session):
"""User can have connections to multiple OIDC providers."""
user = User(email="multi-provider@example.com")
session.add(user)
await session.flush()
for provider in ["google", "okta", "azure-ad"]:
session.add(OIDCConnection(
provider=provider,
refresh_token=f"token-{provider}",
user_id=user.id,
))
await session.flush()
conns = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider)
)).scalars().all()
assert len(conns) == 3
assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]

View file

@ -1,34 +1,6 @@
"""Tests for magic link authentication flow."""
from datetime import timedelta
"""Tests for magic link authentication — token subject validation."""
from wiregui.auth.jwt import create_access_token, decode_access_token
from wiregui.auth.passwords import hash_password
from wiregui.models.user import User
def test_magic_link_token_creation():
"""Magic link token should be a valid JWT with short expiry."""
token = create_access_token(
user_id="user-123",
role="unprivileged",
expires_delta=timedelta(minutes=15),
)
payload = decode_access_token(token)
assert payload is not None
assert payload["sub"] == "user-123"
assert payload["role"] == "unprivileged"
def test_magic_link_token_expired():
"""Expired magic link token should be rejected."""
token = create_access_token(
user_id="user-123",
role="admin",
expires_delta=timedelta(minutes=-1), # Already expired
)
payload = decode_access_token(token)
assert payload is None
def test_magic_link_token_wrong_user():
@ -37,22 +9,3 @@ def test_magic_link_token_wrong_user():
payload = decode_access_token(token)
assert payload["sub"] == "user-A"
# Caller is responsible for checking sub matches the URL user_id
async def test_magic_link_disabled_user_rejected(session):
"""Disabled users should not be able to use magic links."""
from wiregui.utils.time import utcnow
user = User(
email="disabled-magic@example.com",
password_hash=hash_password("pw"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
# The token would be valid but the page handler checks disabled_at
token = create_access_token(user_id=str(user.id), role="unprivileged")
payload = decode_access_token(token)
assert payload is not None # Token itself is valid
assert user.disabled_at is not None # But user is disabled — handler would reject

View file

@ -1,4 +1,4 @@
"""Tests for TOTP MFA functionality."""
"""Tests for TOTP MFA — URI format, edge cases, QR generation, DB relationships."""
import pyotp
@ -12,22 +12,7 @@ from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
# --- TOTP secret generation ---
def test_generate_secret():
secret = generate_totp_secret()
assert len(secret) == 32 # base32 encoded
assert secret.isalpha() or any(c.isdigit() for c in secret)
def test_generate_secret_unique():
s1 = generate_totp_secret()
s2 = generate_totp_secret()
assert s1 != s2
# --- TOTP URI ---
# --- TOTP URI format ---
def test_get_totp_uri():
@ -43,19 +28,7 @@ def test_get_totp_uri_custom_issuer():
assert "issuer=MyVPN" in uri
# --- TOTP verification ---
def test_verify_valid_code():
secret = generate_totp_secret()
totp = pyotp.TOTP(secret)
code = totp.now()
assert verify_totp_code(secret, code) is True
def test_verify_invalid_code():
secret = generate_totp_secret()
assert verify_totp_code(secret, "000000") is False
# --- TOTP verification edge cases ---
def test_verify_wrong_secret():
@ -80,34 +53,7 @@ def test_generate_qr_svg():
assert "</svg>" in svg
# --- MFA method model integration ---
async def test_create_totp_method(session):
user = User(email="mfa-test@example.com")
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(
name="My Phone",
type="totp",
payload={"secret": secret},
user_id=user.id,
)
session.add(method)
await session.flush()
from sqlmodel import select
fetched = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar_one()
assert fetched.name == "My Phone"
assert fetched.type == "totp"
stored_secret = fetched.payload["secret"]
code = pyotp.TOTP(stored_secret).now()
assert verify_totp_code(stored_secret, code) is True
# --- MFA method DB relationships ---
async def test_user_multiple_mfa_methods(session):

View file

@ -1,168 +0,0 @@
"""Tests for SQLModel table definitions."""
import pytest # noqa: F401 — needed for pytest.raises
from sqlmodel import select
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.rule import Rule
from wiregui.models.user import User
async def test_create_user(session):
user = User(email="alice@example.com", role="admin")
session.add(user)
await session.flush()
result = await session.execute(select(User).where(User.email == "alice@example.com"))
fetched = result.scalar_one()
assert fetched.id == user.id
assert fetched.role == "admin"
assert fetched.disabled_at is None
async def test_create_device_with_user(session):
user = User(email="bob@example.com")
session.add(user)
await session.flush()
device = Device(
name="laptop",
public_key="pk-test-device-001",
user_id=user.id,
)
session.add(device)
await session.flush()
result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001"))
fetched = result.scalar_one()
assert fetched.name == "laptop"
assert fetched.user_id == user.id
assert fetched.use_default_dns is True
assert fetched.use_default_allowed_ips is True
assert fetched.rx_bytes is None
async def test_device_unique_public_key(session):
user = User(email="carol@example.com")
session.add(user)
await session.flush()
d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
session.add(d1)
await session.flush()
d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
session.add(d2)
with pytest.raises(Exception): # IntegrityError
await session.flush()
async def test_create_rule(session):
user = User(email="dave@example.com")
session.add(user)
await session.flush()
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
session.add(rule)
await session.flush()
result = await session.execute(select(Rule).where(Rule.user_id == user.id))
fetched = result.scalar_one()
assert fetched.action == "accept"
assert fetched.destination == "10.0.0.0/8"
assert fetched.port_type is None
assert fetched.port_range is None
async def test_create_rule_with_port(session):
rule = Rule(
action="drop",
destination="192.168.0.0/16",
port_type="tcp",
port_range="80-443",
)
session.add(rule)
await session.flush()
fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one()
assert fetched.port_type == "tcp"
assert fetched.port_range == "80-443"
assert fetched.user_id is None # global rule
async def test_create_mfa_method(session):
user = User(email="eve@example.com")
session.add(user)
await session.flush()
mfa = MFAMethod(
name="My Authenticator",
type="totp",
payload={"secret": "JBSWY3DPEHPK3PXP"},
user_id=user.id,
)
session.add(mfa)
await session.flush()
fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one()
assert fetched.type == "totp"
assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP"
async def test_create_oidc_connection(session):
user = User(email="frank@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id)
session.add(conn)
await session.flush()
fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one()
assert fetched.provider == "google"
assert fetched.refresh_token == "tok_abc"
async def test_create_api_token(session):
user = User(email="grace@example.com")
session.add(user)
await session.flush()
token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id)
session.add(token)
await session.flush()
fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one()
assert fetched.token_hash == "sha256_fake_hash"
assert fetched.expires_at is None
async def test_create_connectivity_check(session):
check = ConnectivityCheck(url="https://example.com", response_code=200)
session.add(check)
await session.flush()
fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one()
assert fetched.response_code == 200
async def test_configuration_defaults(session):
config = Configuration()
session.add(config)
await session.flush()
fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one()
assert fetched.allow_unprivileged_device_management is True
assert fetched.local_auth_enabled is True
assert fetched.default_client_mtu == 1280
assert fetched.default_client_persistent_keepalive == 25
assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"]
assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"]
assert fetched.vpn_session_duration == 0
assert fetched.openid_connect_providers == []
assert fetched.saml_identity_providers == []

View file

@ -1,89 +0,0 @@
"""Tests for the notification service."""
from wiregui.services import notifications
def setup_function():
"""Clear notifications before each test."""
notifications.clear_all()
def test_add_notification():
n = notifications.add("info", "Test message")
assert n.severity == "info"
assert n.message == "Test message"
assert n.user is None
assert n.id is not None
assert n.timestamp is not None
def test_add_notification_with_user():
n = notifications.add("error", "Something broke", user="admin@example.com")
assert n.user == "admin@example.com"
assert n.severity == "error"
def test_current_returns_newest_first():
notifications.add("info", "First")
notifications.add("warning", "Second")
notifications.add("error", "Third")
current = notifications.current()
assert len(current) == 3
assert current[0].message == "Third"
assert current[1].message == "Second"
assert current[2].message == "First"
def test_count():
assert notifications.count() == 0
notifications.add("info", "One")
notifications.add("info", "Two")
assert notifications.count() == 2
def test_clear_specific():
n1 = notifications.add("info", "Keep this")
n2 = notifications.add("error", "Remove this")
notifications.clear(n2.id)
current = notifications.current()
assert len(current) == 1
assert current[0].id == n1.id
def test_clear_nonexistent_id_is_noop():
notifications.add("info", "Test")
notifications.clear("nonexistent-id")
assert notifications.count() == 1
def test_clear_all():
notifications.add("info", "One")
notifications.add("info", "Two")
notifications.add("info", "Three")
assert notifications.count() == 3
notifications.clear_all()
assert notifications.count() == 0
assert notifications.current() == []
def test_to_dict():
n = notifications.add("warning", "Test dict", user="someone@example.com")
d = n.to_dict()
assert d["severity"] == "warning"
assert d["message"] == "Test dict"
assert d["user"] == "someone@example.com"
assert "id" in d
assert "timestamp" in d
def test_max_notifications():
"""Deque should cap at MAX_NOTIFICATIONS."""
for i in range(notifications.MAX_NOTIFICATIONS + 10):
notifications.add("info", f"Notification {i}")
assert notifications.count() == notifications.MAX_NOTIFICATIONS
# Newest should be the last one added
assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}"

View file

@ -2,62 +2,59 @@
import pytest
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.utils.server_key import get_server_public_key
from sqlmodel import select
@pytest.fixture(autouse=True)
async def _snapshot_config():
"""Snapshot and restore server_public_key around each test."""
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
orig = c.server_public_key if c else None
cid = c.id if c else None
yield
if cid:
async with async_session() as session:
c = await session.get(Configuration, cid)
if c:
c.server_public_key = orig
session.add(c)
await session.commit()
async def test_get_server_public_key_returns_key():
async def test_get_server_public_key_returns_key(session, monkeypatch):
"""Returns the public key when configured."""
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
c.server_public_key = "TestServerPubKey123456789012345678901234w="
session.add(c)
await session.commit()
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session)
c = Configuration(server_public_key="TestServerPubKey123456789012345678901234w=")
session.add(c)
await session.flush()
result = await get_server_public_key()
assert result == "TestServerPubKey123456789012345678901234w="
async def test_get_server_public_key_raises_when_missing():
async def test_get_server_public_key_raises_when_missing(session, monkeypatch):
"""Raises RuntimeError when server_public_key is None."""
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
c.server_public_key = None
session.add(c)
await session.commit()
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session)
c = Configuration(server_public_key=None)
session.add(c)
await session.flush()
with pytest.raises(RuntimeError, match="not configured"):
await get_server_public_key()
async def test_get_server_public_key_raises_when_empty_string():
async def test_get_server_public_key_raises_when_empty_string(session, monkeypatch):
"""Raises RuntimeError when server_public_key is empty string."""
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
c.server_public_key = ""
session.add(c)
await session.commit()
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session)
c = Configuration(server_public_key="")
session.add(c)
await session.flush()
with pytest.raises(RuntimeError, match="not configured"):
await get_server_public_key()
await get_server_public_key()

View file

@ -1,4 +1,4 @@
"""Tests for services — WireGuard and events."""
"""Tests for services — WireGuard event error handling and rule events."""
from unittest.mock import AsyncMock, patch
@ -20,53 +20,6 @@ def _make_device(**kwargs) -> Device:
return Device(**defaults)
# --- Events (with WG enabled) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
mock_fw.add_user_chain = AsyncMock()
mock_fw.add_device_jump_rule = AsyncMock()
device = _make_device()
await on_device_created(device)
mock_wg.add_peer.assert_awaited_once_with(
public_key="pk-test",
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"],
preshared_key="psk-test",
)
mock_fw.add_device_jump_rule.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.remove_peer = AsyncMock()
device = _make_device()
await on_device_deleted(device)
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test")
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
device = _make_device()
await on_device_updated(device)
mock_wg.add_peer.assert_awaited_once()
# --- Events (WG disabled) ---