fix: remove unit tests redundant with e2e, fix test DB isolation
Remove 7 test files fully covered by e2e tests (admin, account, models, API routes, integration MFA/OIDC, notifications). Trim 5 more files to keep only edge cases not reachable via e2e. Fix conftest to replace wiregui.db engine/session at import time so all code uses the test database. Use session-scoped tables with per-test savepoint isolation to prevent data leaking between tests.
This commit is contained in:
parent
a9f62d5caf
commit
a012635dff
15 changed files with 153 additions and 2006 deletions
|
|
@ -1,4 +1,12 @@
|
|||
"""Shared test fixtures — async DB session using a test database."""
|
||||
"""Shared test fixtures — async DB session using a test database.
|
||||
|
||||
The module-level code below replaces ``wiregui.db.engine`` and
|
||||
``wiregui.db.async_session`` with instances pointing at the **test** database
|
||||
*before* any test (or other module) can grab a reference to the originals.
|
||||
This means every ``from wiregui.db import async_session`` — whether in test
|
||||
files or in production code like ``wiregui.utils.server_key`` — will get the
|
||||
test-database session maker.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
|
@ -8,6 +16,7 @@ from sqlalchemy import text
|
|||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
import wiregui.db as _db_module
|
||||
from wiregui.config import get_settings
|
||||
|
||||
# All models must be imported so SQLModel.metadata knows about them
|
||||
|
|
@ -51,19 +60,41 @@ def _ensure_test_db_sync():
|
|||
|
||||
_ensure_test_db_sync()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Replace the production engine/session in wiregui.db at import time so that
|
||||
# every module that does ``from wiregui.db import async_session`` picks up the
|
||||
# test database. This MUST happen before test modules are collected (which
|
||||
# triggers their top-level imports).
|
||||
# ---------------------------------------------------------------------------
|
||||
_test_engine = create_async_engine(TEST_DATABASE_URL)
|
||||
_test_session_factory = async_sessionmaker(_test_engine, expire_on_commit=False)
|
||||
_db_module.engine = _test_engine
|
||||
_db_module.async_session = _test_session_factory
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def _setup_test_tables():
|
||||
"""Create all tables once at the start of the test session, drop at end."""
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
yield
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await _test_engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def session() -> AsyncGenerator[AsyncSession]:
|
||||
"""Fresh engine + session per test, with table setup/teardown."""
|
||||
engine = create_async_engine(TEST_DATABASE_URL)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
"""Per-test session with transaction isolation.
|
||||
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as sess:
|
||||
The session is bound to a connection-level transaction that is always
|
||||
rolled back at teardown. When tested code calls ``session.commit()``,
|
||||
SQLAlchemy only releases a SAVEPOINT — the outer transaction is never
|
||||
committed, so no test data persists between tests.
|
||||
"""
|
||||
async with _test_engine.connect() as conn:
|
||||
txn = await conn.begin()
|
||||
sess = AsyncSession(bind=conn, expire_on_commit=False, join_transaction_mode="create_savepoint")
|
||||
yield sess
|
||||
await sess.rollback()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
await sess.close()
|
||||
await txn.rollback()
|
||||
|
|
|
|||
|
|
@ -1,161 +0,0 @@
|
|||
"""Tests for account functionality — password changes, API tokens, OIDC connections."""
|
||||
|
||||
import hashlib
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- Password change ---
|
||||
|
||||
|
||||
async def test_password_change_flow(session):
|
||||
"""Simulate the password change flow: verify old, set new."""
|
||||
user = User(email="pw-change@example.com", password_hash=hash_password("old-password"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Verify old password
|
||||
assert verify_password("old-password", user.password_hash) is True
|
||||
|
||||
# Change password
|
||||
user.password_hash = hash_password("new-password")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert verify_password("new-password", fetched.password_hash) is True
|
||||
assert verify_password("old-password", fetched.password_hash) is False
|
||||
|
||||
|
||||
async def test_password_change_wrong_current(session):
|
||||
"""Wrong current password should not allow change."""
|
||||
user = User(email="pw-wrong@example.com", password_hash=hash_password("correct"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Simulate check
|
||||
assert verify_password("wrong", user.password_hash) is False
|
||||
|
||||
|
||||
# --- API token management ---
|
||||
|
||||
|
||||
async def test_create_multiple_tokens(session):
|
||||
user = User(email="multi-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for _ in range(3):
|
||||
_, token_hash = generate_api_token()
|
||||
session.add(ApiToken(token_hash=token_hash, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
||||
|
||||
async def test_token_with_expiry(session):
|
||||
user = User(email="expiry-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
_, token_hash = generate_api_token()
|
||||
expires = utcnow() + timedelta(days=30)
|
||||
token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(ApiToken, token.id)
|
||||
assert fetched.expires_at is not None
|
||||
assert fetched.expires_at > utcnow()
|
||||
|
||||
|
||||
async def test_delete_token(session):
|
||||
user = User(email="del-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
_, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
await session.delete(token)
|
||||
await session.flush()
|
||||
|
||||
assert await session.get(ApiToken, token.id) is None
|
||||
|
||||
|
||||
# --- OIDC connections ---
|
||||
|
||||
|
||||
async def test_oidc_connection_create(session):
|
||||
user = User(email="oidc-conn@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="google",
|
||||
refresh_token="refresh-tok-123",
|
||||
refresh_response={"access_token": "at", "token_type": "Bearer"},
|
||||
refreshed_at=utcnow(),
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.provider == "google"
|
||||
assert fetched.refresh_token == "refresh-tok-123"
|
||||
assert fetched.refresh_response["access_token"] == "at"
|
||||
|
||||
|
||||
async def test_multiple_oidc_providers(session):
|
||||
user = User(email="multi-oidc@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for provider in ["google", "okta", "azure"]:
|
||||
conn = OIDCConnection(provider=provider, user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
||||
|
||||
async def test_oidc_connection_update_refresh_token(session):
|
||||
user = User(email="oidc-refresh@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="google",
|
||||
refresh_token="old-token",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
conn.refresh_token = "new-token"
|
||||
conn.refreshed_at = utcnow()
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(OIDCConnection, conn.id)
|
||||
assert fetched.refresh_token == "new-token"
|
||||
assert fetched.refreshed_at is not None
|
||||
|
|
@ -1,283 +0,0 @@
|
|||
"""Tests for admin functionality — user management, configuration, cascading deletes."""
|
||||
|
||||
import pytest
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- User CRUD ---
|
||||
|
||||
|
||||
async def test_create_user_with_role(session):
|
||||
user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.role == "admin"
|
||||
assert verify_password("secret", fetched.password_hash)
|
||||
|
||||
|
||||
async def test_update_user_email(session):
|
||||
user = User(email="old@test.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
user.email = "new@test.com"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.email == "new@test.com"
|
||||
|
||||
|
||||
async def test_disable_user(session):
|
||||
user = User(email="active@test.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert user.disabled_at is None
|
||||
|
||||
user.disabled_at = utcnow()
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.disabled_at is not None
|
||||
|
||||
|
||||
async def test_promote_demote_user(session):
|
||||
user = User(email="user@test.com", role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert user.role == "unprivileged"
|
||||
|
||||
user.role = "admin"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.role == "admin"
|
||||
|
||||
user.role = "unprivileged"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert (await session.get(User, user.id)).role == "unprivileged"
|
||||
|
||||
|
||||
# --- Cascading delete (manual, as we do it in the admin page) ---
|
||||
|
||||
|
||||
async def test_delete_user_cascades_devices(session):
|
||||
user = User(email="cascade@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id)
|
||||
d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id)
|
||||
session.add_all([d1, d2])
|
||||
await session.flush()
|
||||
|
||||
# Manually delete devices then user (matching admin page behavior)
|
||||
devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all()
|
||||
for d in devices:
|
||||
await session.delete(d)
|
||||
await session.delete(user)
|
||||
await session.flush()
|
||||
|
||||
assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0
|
||||
assert await session.get(User, user.id) is None
|
||||
|
||||
|
||||
async def test_delete_user_cascades_rules(session):
|
||||
user = User(email="rule-cascade@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
# Delete rules then user
|
||||
rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all()
|
||||
for r in rules:
|
||||
await session.delete(r)
|
||||
await session.delete(user)
|
||||
await session.flush()
|
||||
|
||||
assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0
|
||||
|
||||
|
||||
# --- Configuration singleton ---
|
||||
|
||||
|
||||
async def test_configuration_create_and_update(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
assert config.default_client_mtu == 1280
|
||||
assert config.local_auth_enabled is True
|
||||
|
||||
config.default_client_mtu = 1400
|
||||
config.local_auth_enabled = False
|
||||
config.vpn_session_duration = 3600
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.default_client_mtu == 1400
|
||||
assert fetched.local_auth_enabled is False
|
||||
assert fetched.vpn_session_duration == 3600
|
||||
|
||||
|
||||
async def test_configuration_oidc_providers(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
assert config.openid_connect_providers == []
|
||||
|
||||
providers = [
|
||||
{
|
||||
"id": "google",
|
||||
"label": "Sign in with Google",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "google-client-id",
|
||||
"client_secret": "google-secret",
|
||||
"discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"auto_create_users": True,
|
||||
},
|
||||
{
|
||||
"id": "okta",
|
||||
"label": "Okta SSO",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "okta-client-id",
|
||||
"client_secret": "okta-secret",
|
||||
"discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration",
|
||||
"auto_create_users": False,
|
||||
},
|
||||
]
|
||||
config.openid_connect_providers = providers
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert len(fetched.openid_connect_providers) == 2
|
||||
assert fetched.openid_connect_providers[0]["id"] == "google"
|
||||
assert fetched.openid_connect_providers[1]["auto_create_users"] is False
|
||||
|
||||
|
||||
async def test_configuration_update_client_defaults(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
config.default_client_endpoint = "vpn.example.com"
|
||||
config.default_client_dns = ["8.8.8.8", "8.8.4.4"]
|
||||
config.default_client_allowed_ips = ["10.0.0.0/8"]
|
||||
config.default_client_persistent_keepalive = 30
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.default_client_endpoint == "vpn.example.com"
|
||||
assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"]
|
||||
assert fetched.default_client_allowed_ips == ["10.0.0.0/8"]
|
||||
assert fetched.default_client_persistent_keepalive == 30
|
||||
|
||||
|
||||
async def test_configuration_security_toggles(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
config.allow_unprivileged_device_management = False
|
||||
config.allow_unprivileged_device_configuration = False
|
||||
config.disable_vpn_on_oidc_error = True
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.allow_unprivileged_device_management is False
|
||||
assert fetched.allow_unprivileged_device_configuration is False
|
||||
assert fetched.disable_vpn_on_oidc_error is True
|
||||
|
||||
|
||||
# --- Device config overrides ---
|
||||
|
||||
|
||||
async def test_device_with_custom_config(session):
|
||||
user = User(email="config-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(
|
||||
name="custom-config",
|
||||
public_key="pk-custom-config",
|
||||
user_id=user.id,
|
||||
use_default_dns=False,
|
||||
use_default_endpoint=False,
|
||||
use_default_mtu=False,
|
||||
use_default_persistent_keepalive=False,
|
||||
use_default_allowed_ips=False,
|
||||
dns=["8.8.8.8"],
|
||||
endpoint="custom-vpn.example.com",
|
||||
mtu=1400,
|
||||
persistent_keepalive=15,
|
||||
allowed_ips=["10.0.0.0/8", "172.16.0.0/12"],
|
||||
)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Device, device.id)
|
||||
assert fetched.use_default_dns is False
|
||||
assert fetched.dns == ["8.8.8.8"]
|
||||
assert fetched.endpoint == "custom-vpn.example.com"
|
||||
assert fetched.mtu == 1400
|
||||
assert fetched.persistent_keepalive == 15
|
||||
assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"]
|
||||
|
||||
|
||||
async def test_device_default_flags_are_true(session):
|
||||
user = User(email="defaults@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="defaults", public_key="pk-defaults", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Device, device.id)
|
||||
assert fetched.use_default_allowed_ips is True
|
||||
assert fetched.use_default_dns is True
|
||||
assert fetched.use_default_endpoint is True
|
||||
assert fetched.use_default_mtu is True
|
||||
assert fetched.use_default_persistent_keepalive is True
|
||||
|
||||
|
||||
# --- User device count ---
|
||||
|
||||
|
||||
async def test_user_device_count_query(session):
|
||||
user = User(email="count-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for i in range(3):
|
||||
session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(Device).where(Device.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
|
@ -1,15 +1,12 @@
|
|||
"""Tests for API dependency injection — Bearer token auth and admin guard."""
|
||||
|
||||
import hashlib
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.api_token import generate_api_token, resolve_bearer_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
|
@ -18,143 +15,80 @@ from wiregui.utils.time import utcnow
|
|||
# ========== resolve_bearer_token ==========
|
||||
|
||||
|
||||
async def test_resolve_valid_token():
|
||||
async def test_resolve_valid_token(session):
|
||||
"""Valid, non-expired token resolves to user."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.flush()
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() + timedelta(hours=1),
|
||||
)
|
||||
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1))
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is not None
|
||||
assert resolved.id == user.id
|
||||
assert resolved.email == "api-test@test.com"
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_resolve_expired_token():
|
||||
async def test_resolve_expired_token(session):
|
||||
"""Expired token returns None."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.flush()
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() - timedelta(hours=1), # already expired
|
||||
)
|
||||
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() - timedelta(hours=1))
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_resolve_invalid_token():
|
||||
async def test_resolve_invalid_token(session):
|
||||
"""Nonexistent token returns None."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, "totally-bogus-token")
|
||||
assert resolved is None
|
||||
|
||||
|
||||
async def test_resolve_token_disabled_user():
|
||||
async def test_resolve_token_disabled_user(session):
|
||||
"""Token for disabled user returns None."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(
|
||||
email="api-disabled@test.com", password_hash=hash_password("x"),
|
||||
role="admin", disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.flush()
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() + timedelta(hours=1),
|
||||
)
|
||||
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1))
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_resolve_token_no_expiry():
|
||||
async def test_resolve_token_no_expiry(session):
|
||||
"""Token without expires_at (never expires) resolves successfully."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.flush()
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=None,
|
||||
)
|
||||
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=None)
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is not None
|
||||
assert resolved.id == user.id
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ========== get_current_api_user (via FastAPI deps) ==========
|
||||
|
|
@ -187,7 +121,7 @@ async def test_get_current_api_user_bad_scheme():
|
|||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
async def test_get_current_api_user_invalid_token():
|
||||
async def test_get_current_api_user_invalid_token(session):
|
||||
"""Valid Bearer scheme but bogus token raises 401."""
|
||||
from fastapi import HTTPException
|
||||
from wiregui.api.deps import get_current_api_user
|
||||
|
|
@ -195,45 +129,31 @@ async def test_get_current_api_user_invalid_token():
|
|||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Bearer bogus-token-value"}
|
||||
|
||||
async with async_session() as session:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_user(request, session=session)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid" in exc_info.value.detail
|
||||
|
||||
|
||||
async def test_get_current_api_user_valid_token():
|
||||
async def test_get_current_api_user_valid_token(session):
|
||||
"""Valid Bearer token resolves to user."""
|
||||
from wiregui.api.deps import get_current_api_user
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.flush()
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() + timedelta(hours=1),
|
||||
)
|
||||
api_token = ApiToken(token_hash=token_hash, user_id=user.id, expires_at=utcnow() + timedelta(hours=1))
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": f"Bearer {plaintext}"}
|
||||
|
||||
async with async_session() as session:
|
||||
resolved = await get_current_api_user(request, session=session)
|
||||
assert resolved.id == user.id
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ========== require_admin ==========
|
||||
|
|
|
|||
|
|
@ -1,325 +0,0 @@
|
|||
"""Tests for REST API routes via httpx AsyncClient against the FastAPI app."""
|
||||
|
||||
import hashlib
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_current_api_user, get_db, require_admin
|
||||
from wiregui.api.v0 import router as api_router
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
def _build_app(session, admin_user=None, regular_user=None):
|
||||
"""Build a test FastAPI app with overridden dependencies."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(api_router, prefix="/api")
|
||||
|
||||
async def override_get_db():
|
||||
yield session
|
||||
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
if admin_user:
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: admin_user
|
||||
test_app.dependency_overrides[require_admin] = lambda: admin_user
|
||||
|
||||
return test_app
|
||||
|
||||
|
||||
async def _make_admin(session) -> User:
|
||||
user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def _make_user(session, email="api-user@test.com") -> User:
|
||||
user = User(email=email, password_hash=hash_password("pw"), role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
||||
|
||||
# ========== Users API ==========
|
||||
|
||||
|
||||
async def test_list_users(session):
|
||||
admin = await _make_admin(session)
|
||||
await _make_user(session, "user1@test.com")
|
||||
await _make_user(session, "user2@test.com")
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/users/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 3 # admin + 2 users
|
||||
|
||||
|
||||
async def test_get_user(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/users/{admin.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "api-admin@test.com"
|
||||
|
||||
|
||||
async def test_get_user_not_found(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/users/{uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_create_user(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.post("/api/v0/users/", json={
|
||||
"email": "new-api-user@test.com",
|
||||
"password": "secret123",
|
||||
"role": "unprivileged",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["email"] == "new-api-user@test.com"
|
||||
assert data["role"] == "unprivileged"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
async def test_update_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/users/{user.id}", json={
|
||||
"role": "admin",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["role"] == "admin"
|
||||
|
||||
|
||||
async def test_update_user_password(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/users/{user.id}", json={
|
||||
"password": "new-password-123",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
from wiregui.auth.passwords import verify_password
|
||||
refreshed = await session.get(User, user.id)
|
||||
assert verify_password("new-password-123", refreshed.password_hash)
|
||||
|
||||
|
||||
async def test_delete_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/users/{user.id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(User, user.id) is None
|
||||
|
||||
|
||||
# ========== Devices API ==========
|
||||
|
||||
|
||||
async def test_list_devices_admin_sees_all(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id))
|
||||
session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/devices/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 2
|
||||
|
||||
|
||||
async def test_list_devices_user_sees_own(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session, "own-devices@test.com")
|
||||
session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id))
|
||||
session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id))
|
||||
await session.flush()
|
||||
|
||||
# Override to be the regular user
|
||||
test_app = _build_app(session)
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: user
|
||||
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/devices/")
|
||||
assert resp.status_code == 200
|
||||
names = [d["name"] for d in resp.json()]
|
||||
assert "mine" in names
|
||||
assert "not-mine" not in names
|
||||
|
||||
|
||||
async def test_get_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5")
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/devices/{device.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "detail"
|
||||
assert resp.json()["ipv4"] == "10.0.0.5"
|
||||
|
||||
|
||||
async def test_get_device_forbidden_for_other_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session, "other-dev@test.com")
|
||||
device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
test_app = _build_app(session)
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: user
|
||||
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/devices/{device.id}")
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
async def test_update_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "new-name"
|
||||
|
||||
|
||||
async def test_delete_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
did = device.id
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/devices/{did}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(Device, did) is None
|
||||
|
||||
|
||||
# ========== Rules API ==========
|
||||
|
||||
|
||||
async def test_list_rules(session):
|
||||
admin = await _make_admin(session)
|
||||
session.add(Rule(action="accept", destination="10.0.0.0/8"))
|
||||
session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id))
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/rules/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 2
|
||||
|
||||
|
||||
async def test_create_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.post("/api/v0/rules/", json={
|
||||
"action": "accept",
|
||||
"destination": "172.16.0.0/12",
|
||||
"port_type": "tcp",
|
||||
"port_range": "443",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["action"] == "accept"
|
||||
assert data["destination"] == "172.16.0.0/12"
|
||||
assert data["port_type"] == "tcp"
|
||||
assert data["port_range"] == "443"
|
||||
|
||||
|
||||
async def test_update_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8")
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["action"] == "drop"
|
||||
|
||||
|
||||
async def test_delete_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0")
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
rid = rule.id
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/rules/{rid}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(Rule, rid) is None
|
||||
|
||||
|
||||
# ========== Configuration API ==========
|
||||
|
||||
|
||||
async def test_get_configuration_auto_creates(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/configuration/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["default_client_mtu"] == 1280
|
||||
assert data["local_auth_enabled"] is True
|
||||
|
||||
|
||||
async def test_update_configuration(session):
|
||||
admin = await _make_admin(session)
|
||||
# Pre-create config
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put("/api/v0/configuration/", json={
|
||||
"default_client_mtu": 1400,
|
||||
"vpn_session_duration": 3600,
|
||||
"default_client_dns": ["8.8.8.8"],
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["default_client_mtu"] == 1400
|
||||
assert data["vpn_session_duration"] == 3600
|
||||
assert data["default_client_dns"] == ["8.8.8.8"]
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for authentication modules."""
|
||||
"""Tests for authentication modules — seed logic and JWT edge cases."""
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
@ -8,17 +8,7 @@ from wiregui.auth.seed import seed_admin
|
|||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- Password hashing ---
|
||||
|
||||
|
||||
def test_hash_and_verify():
|
||||
hashed = hash_password("my-secret")
|
||||
assert verify_password("my-secret", hashed) is True
|
||||
|
||||
|
||||
def test_verify_wrong_password():
|
||||
hashed = hash_password("correct")
|
||||
assert verify_password("wrong", hashed) is False
|
||||
# --- Password hashing (format guard) ---
|
||||
|
||||
|
||||
def test_hash_is_not_plaintext():
|
||||
|
|
@ -27,16 +17,7 @@ def test_hash_is_not_plaintext():
|
|||
assert hashed.startswith("$2b$")
|
||||
|
||||
|
||||
# --- JWT ---
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
token = create_access_token(user_id="user-123", role="admin")
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["role"] == "admin"
|
||||
assert "exp" in payload
|
||||
# --- JWT edge cases ---
|
||||
|
||||
|
||||
def test_decode_invalid_token():
|
||||
|
|
@ -54,8 +35,6 @@ def test_decode_tampered_token():
|
|||
|
||||
async def test_seed_admin_creates_user(session, monkeypatch):
|
||||
"""seed_admin should create an admin when no users exist."""
|
||||
# Patch async_session to use our test session
|
||||
from unittest.mock import AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
|
|||
|
|
@ -1,65 +1,9 @@
|
|||
"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases."""
|
||||
"""Extended auth tests — OIDC registration, WebAuthn options, rule event handlers."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# ========== Session / authenticate_user edge cases ==========
|
||||
|
||||
|
||||
async def test_authenticate_user_no_password_hash(session, monkeypatch):
|
||||
"""Users without a password (OIDC-only) should not authenticate via password."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="no-pw@test.com", password_hash=None)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await authenticate_user("no-pw@test.com", "anything")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_authenticate_user_disabled(session, monkeypatch):
|
||||
"""Disabled users should not authenticate."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow())
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await authenticate_user("disabled-auth@test.com", "pw")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_authenticate_user_nonexistent(session, monkeypatch):
|
||||
"""Nonexistent email should return None."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
result = await authenticate_user("ghost@nowhere.com", "pw")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ========== OIDC provider registration ==========
|
||||
|
|
@ -163,13 +107,11 @@ async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings):
|
|||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_updated
|
||||
|
||||
# Need to mock the DB call inside _rebuild_user_chain
|
||||
with patch("wiregui.services.events.async_session") as mock_session_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the select results
|
||||
mock_rules_result = MagicMock()
|
||||
mock_rules_result.scalars.return_value.all.return_value = []
|
||||
mock_devices_result = MagicMock()
|
||||
|
|
|
|||
|
|
@ -1,239 +0,0 @@
|
|||
"""Integration tests for MFA — full registration and authentication flows through the database."""
|
||||
|
||||
import pyotp
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.mfa import generate_totp_secret, verify_totp_code
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
async def test_full_totp_registration_flow(session, monkeypatch):
|
||||
"""End-to-end: create user → generate secret → verify code → store method → re-verify from DB."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
# Create user with password
|
||||
user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Step 1: Generate TOTP secret (happens in account page)
|
||||
secret = generate_totp_secret()
|
||||
|
||||
# Step 2: User scans QR, enters code from their authenticator
|
||||
totp = pyotp.TOTP(secret)
|
||||
code = totp.now()
|
||||
|
||||
# Step 3: Verify the code is correct before saving
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
# Step 4: Save the MFA method to DB
|
||||
method = MFAMethod(
|
||||
name="My Authenticator",
|
||||
type="totp",
|
||||
payload={"secret": secret},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Step 5: Simulate future login — load method from DB and verify a fresh code
|
||||
fetched_methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalars().all()
|
||||
|
||||
assert len(fetched_methods) == 1
|
||||
stored_secret = fetched_methods[0].payload["secret"]
|
||||
fresh_code = pyotp.TOTP(stored_secret).now()
|
||||
assert verify_totp_code(stored_secret, fresh_code) is True
|
||||
|
||||
|
||||
async def test_mfa_blocks_login_without_code(session, monkeypatch):
|
||||
"""User with MFA should not be fully authenticated without completing MFA challenge."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
# Create user with MFA
|
||||
user = User(email="mfa-block@example.com", password_hash=hash_password("password1"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Password auth succeeds
|
||||
authed_user = await authenticate_user("mfa-block@example.com", "password1")
|
||||
assert authed_user is not None
|
||||
|
||||
# But MFA methods exist — login page would redirect to /mfa instead of completing login
|
||||
mfa_methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == authed_user.id)
|
||||
)).scalars().all()
|
||||
assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa
|
||||
|
||||
|
||||
async def test_mfa_wrong_code_rejected(session):
|
||||
"""Wrong TOTP code should be rejected even if method is valid."""
|
||||
user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Load from DB and try wrong code
|
||||
fetched = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar_one()
|
||||
|
||||
assert verify_totp_code(fetched.payload["secret"], "000000") is False
|
||||
assert verify_totp_code(fetched.payload["secret"], "123456") is False
|
||||
|
||||
|
||||
async def test_mfa_multiple_methods_any_valid_code_works(session):
|
||||
"""If user has multiple TOTP methods, a valid code from any should work."""
|
||||
user = User(email="mfa-multi@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret1 = generate_totp_secret()
|
||||
secret2 = generate_totp_secret()
|
||||
|
||||
session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id))
|
||||
session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalars().all()
|
||||
|
||||
# Code from method 1 should verify against method 1's secret
|
||||
code1 = pyotp.TOTP(secret1).now()
|
||||
verified = False
|
||||
for m in methods:
|
||||
if verify_totp_code(m.payload["secret"], code1):
|
||||
verified = True
|
||||
break
|
||||
assert verified is True
|
||||
|
||||
# Code from method 2 should also work
|
||||
code2 = pyotp.TOTP(secret2).now()
|
||||
verified2 = False
|
||||
for m in methods:
|
||||
if verify_totp_code(m.payload["secret"], code2):
|
||||
verified2 = True
|
||||
break
|
||||
assert verified2 is True
|
||||
|
||||
|
||||
async def test_mfa_method_last_used_tracking(session):
|
||||
"""Verifying MFA should update last_used_at timestamp."""
|
||||
user = User(email="mfa-tracking@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
assert method.last_used_at is None
|
||||
|
||||
# Simulate successful verification and update
|
||||
code = pyotp.TOTP(secret).now()
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
method.last_used_at = utcnow()
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(MFAMethod, method.id)
|
||||
assert fetched.last_used_at is not None
|
||||
|
||||
|
||||
async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch):
|
||||
"""After removing all MFA methods, user should not be redirected to MFA challenge."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="mfa-remove@example.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# MFA exists
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 1
|
||||
|
||||
# Delete it
|
||||
await session.delete(method)
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 0
|
||||
|
||||
# Password auth still works
|
||||
authed = await authenticate_user("mfa-remove@example.com", "pw")
|
||||
assert authed is not None
|
||||
|
||||
# No MFA methods — login flow would skip MFA challenge
|
||||
mfa_check = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == authed.id)
|
||||
)).scalars().all()
|
||||
assert len(mfa_check) == 0
|
||||
|
||||
|
||||
async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch):
|
||||
"""Disabled user should be rejected at password stage, never reaching MFA."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(
|
||||
email="mfa-disabled@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
# Password auth rejects disabled user before MFA is ever checked
|
||||
result = await authenticate_user("mfa-disabled@example.com", "pw")
|
||||
assert result is None
|
||||
|
|
@ -1,309 +0,0 @@
|
|||
"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import respx
|
||||
from httpx import Response
|
||||
from jose import jwt
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- Helper to create a fake OIDC provider config in the DB ---
|
||||
|
||||
|
||||
async def _setup_oidc_config(session) -> Configuration:
|
||||
"""Insert a Configuration with a test OIDC provider."""
|
||||
config = Configuration(
|
||||
openid_connect_providers=[
|
||||
{
|
||||
"id": "test-idp",
|
||||
"label": "Test IdP",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
|
||||
"auto_create_users": True,
|
||||
}
|
||||
],
|
||||
)
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
return config
|
||||
|
||||
|
||||
def _mock_discovery():
|
||||
"""Mock OIDC discovery document response."""
|
||||
return {
|
||||
"issuer": "https://idp.example.com",
|
||||
"authorization_endpoint": "https://idp.example.com/authorize",
|
||||
"token_endpoint": "https://idp.example.com/token",
|
||||
"userinfo_endpoint": "https://idp.example.com/userinfo",
|
||||
"jwks_uri": "https://idp.example.com/.well-known/jwks.json",
|
||||
}
|
||||
|
||||
|
||||
def _mock_token_response(email: str = "oidc-user@example.com"):
|
||||
"""Mock OIDC token endpoint response with ID token."""
|
||||
now = int(time.time())
|
||||
id_token_payload = {
|
||||
"iss": "https://idp.example.com",
|
||||
"sub": "oidc-subject-123",
|
||||
"aud": "test-client-id",
|
||||
"email": email,
|
||||
"name": "OIDC User",
|
||||
"iat": now,
|
||||
"exp": now + 3600,
|
||||
"nonce": "test-nonce",
|
||||
}
|
||||
# Sign with a simple secret (in real life this would be RSA)
|
||||
id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
|
||||
|
||||
return {
|
||||
"access_token": "mock-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
|
||||
# --- Provider config loading ---
|
||||
|
||||
|
||||
async def test_load_providers_from_config(session, monkeypatch):
|
||||
"""Providers should be loaded from the Configuration table."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
await _setup_oidc_config(session)
|
||||
|
||||
providers = await load_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0]["id"] == "test-idp"
|
||||
assert providers[0]["client_id"] == "test-client-id"
|
||||
|
||||
|
||||
async def test_load_providers_empty_when_no_config(session, monkeypatch):
|
||||
"""Should return empty list when no Configuration exists."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
providers = await load_providers()
|
||||
assert providers == []
|
||||
|
||||
|
||||
async def test_get_provider_config_by_id(session, monkeypatch):
|
||||
"""Should find a specific provider by ID."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
await _setup_oidc_config(session)
|
||||
|
||||
config = await get_provider_config("test-idp")
|
||||
assert config is not None
|
||||
assert config["label"] == "Test IdP"
|
||||
|
||||
config_missing = await get_provider_config("nonexistent")
|
||||
assert config_missing is None
|
||||
|
||||
|
||||
# --- OIDC connection storage ---
|
||||
|
||||
|
||||
async def test_oidc_connection_created_on_login(session):
|
||||
"""Simulates what the callback route does: create user + OIDC connection."""
|
||||
user = User(email="oidc-new@example.com", role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
token_data = _mock_token_response("oidc-new@example.com")
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token=token_data["refresh_token"],
|
||||
refresh_response=token_data,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Verify it was stored
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.provider == "test-idp"
|
||||
assert fetched.refresh_token == "mock-refresh-token"
|
||||
assert fetched.refresh_response["access_token"] == "mock-access-token"
|
||||
|
||||
|
||||
async def test_oidc_connection_updated_on_re_login(session):
|
||||
"""Re-login should update the existing OIDC connection, not create a duplicate."""
|
||||
user = User(email="oidc-relogin@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# First login
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token="old-refresh-token",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Re-login — update existing connection (as the callback route does)
|
||||
existing = (await session.execute(
|
||||
select(OIDCConnection).where(
|
||||
OIDCConnection.user_id == user.id,
|
||||
OIDCConnection.provider == "test-idp",
|
||||
)
|
||||
)).scalar_one()
|
||||
|
||||
existing.refresh_token = "new-refresh-token"
|
||||
from wiregui.utils.time import utcnow
|
||||
existing.refreshed_at = utcnow()
|
||||
session.add(existing)
|
||||
await session.flush()
|
||||
|
||||
# Should still be one connection
|
||||
from sqlmodel import func
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 1
|
||||
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.refresh_token == "new-refresh-token"
|
||||
|
||||
|
||||
async def test_oidc_auto_create_user(session):
|
||||
"""When auto_create_users is True, a new user should be created from OIDC email."""
|
||||
email = "auto-created@example.com"
|
||||
|
||||
# Verify user doesn't exist
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
assert existing is None
|
||||
|
||||
# Simulate what callback does with auto_create
|
||||
user = User(email=email, role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
user.last_signed_in_at = utcnow()
|
||||
user.last_signed_in_method = "oidc:test-idp"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
created = (await session.execute(select(User).where(User.email == email))).scalar_one()
|
||||
assert created.role == "unprivileged"
|
||||
assert created.last_signed_in_method == "oidc:test-idp"
|
||||
|
||||
|
||||
async def test_oidc_disabled_user_rejected(session):
|
||||
"""Disabled users should not be logged in via OIDC."""
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# The callback route checks disabled_at before creating session
|
||||
assert user.disabled_at is not None # Would redirect to /login
|
||||
|
||||
|
||||
async def test_oidc_user_without_auto_create_rejected(session):
|
||||
"""When auto_create is False and user doesn't exist, login should fail."""
|
||||
email = "no-auto-create@example.com"
|
||||
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
assert existing is None
|
||||
|
||||
# The callback route checks auto_create_users from provider config
|
||||
# With auto_create=False and no existing user, it would redirect to /login
|
||||
# This verifies the precondition
|
||||
|
||||
|
||||
# --- OIDC refresh token flow ---
|
||||
|
||||
|
||||
async def test_oidc_refresh_stores_new_token(session):
|
||||
"""Simulates a successful token refresh updating the connection."""
|
||||
user = User(email="oidc-refresh-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token="old-refresh",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Simulate refresh result
|
||||
new_token = {
|
||||
"access_token": "new-access",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
|
||||
conn.refresh_response = new_token
|
||||
from wiregui.utils.time import utcnow
|
||||
conn.refreshed_at = utcnow()
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(OIDCConnection, conn.id)
|
||||
assert fetched.refresh_token == "new-refresh"
|
||||
assert fetched.refresh_response["access_token"] == "new-access"
|
||||
assert fetched.refreshed_at is not None
|
||||
|
||||
|
||||
async def test_oidc_multiple_providers_per_user(session):
|
||||
"""User can have connections to multiple OIDC providers."""
|
||||
user = User(email="multi-provider@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for provider in ["google", "okta", "azure-ad"]:
|
||||
session.add(OIDCConnection(
|
||||
provider=provider,
|
||||
refresh_token=f"token-{provider}",
|
||||
user_id=user.id,
|
||||
))
|
||||
await session.flush()
|
||||
|
||||
conns = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider)
|
||||
)).scalars().all()
|
||||
|
||||
assert len(conns) == 3
|
||||
assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]
|
||||
|
|
@ -1,34 +1,6 @@
|
|||
"""Tests for magic link authentication flow."""
|
||||
|
||||
from datetime import timedelta
|
||||
"""Tests for magic link authentication — token subject validation."""
|
||||
|
||||
from wiregui.auth.jwt import create_access_token, decode_access_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
def test_magic_link_token_creation():
|
||||
"""Magic link token should be a valid JWT with short expiry."""
|
||||
token = create_access_token(
|
||||
user_id="user-123",
|
||||
role="unprivileged",
|
||||
expires_delta=timedelta(minutes=15),
|
||||
)
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["role"] == "unprivileged"
|
||||
|
||||
|
||||
def test_magic_link_token_expired():
|
||||
"""Expired magic link token should be rejected."""
|
||||
token = create_access_token(
|
||||
user_id="user-123",
|
||||
role="admin",
|
||||
expires_delta=timedelta(minutes=-1), # Already expired
|
||||
)
|
||||
payload = decode_access_token(token)
|
||||
assert payload is None
|
||||
|
||||
|
||||
def test_magic_link_token_wrong_user():
|
||||
|
|
@ -37,22 +9,3 @@ def test_magic_link_token_wrong_user():
|
|||
payload = decode_access_token(token)
|
||||
assert payload["sub"] == "user-A"
|
||||
# Caller is responsible for checking sub matches the URL user_id
|
||||
|
||||
|
||||
async def test_magic_link_disabled_user_rejected(session):
|
||||
"""Disabled users should not be able to use magic links."""
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
user = User(
|
||||
email="disabled-magic@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# The token would be valid but the page handler checks disabled_at
|
||||
token = create_access_token(user_id=str(user.id), role="unprivileged")
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None # Token itself is valid
|
||||
assert user.disabled_at is not None # But user is disabled — handler would reject
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for TOTP MFA functionality."""
|
||||
"""Tests for TOTP MFA — URI format, edge cases, QR generation, DB relationships."""
|
||||
|
||||
import pyotp
|
||||
|
||||
|
|
@ -12,22 +12,7 @@ from wiregui.models.mfa_method import MFAMethod
|
|||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- TOTP secret generation ---
|
||||
|
||||
|
||||
def test_generate_secret():
|
||||
secret = generate_totp_secret()
|
||||
assert len(secret) == 32 # base32 encoded
|
||||
assert secret.isalpha() or any(c.isdigit() for c in secret)
|
||||
|
||||
|
||||
def test_generate_secret_unique():
|
||||
s1 = generate_totp_secret()
|
||||
s2 = generate_totp_secret()
|
||||
assert s1 != s2
|
||||
|
||||
|
||||
# --- TOTP URI ---
|
||||
# --- TOTP URI format ---
|
||||
|
||||
|
||||
def test_get_totp_uri():
|
||||
|
|
@ -43,19 +28,7 @@ def test_get_totp_uri_custom_issuer():
|
|||
assert "issuer=MyVPN" in uri
|
||||
|
||||
|
||||
# --- TOTP verification ---
|
||||
|
||||
|
||||
def test_verify_valid_code():
|
||||
secret = generate_totp_secret()
|
||||
totp = pyotp.TOTP(secret)
|
||||
code = totp.now()
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
|
||||
def test_verify_invalid_code():
|
||||
secret = generate_totp_secret()
|
||||
assert verify_totp_code(secret, "000000") is False
|
||||
# --- TOTP verification edge cases ---
|
||||
|
||||
|
||||
def test_verify_wrong_secret():
|
||||
|
|
@ -80,34 +53,7 @@ def test_generate_qr_svg():
|
|||
assert "</svg>" in svg
|
||||
|
||||
|
||||
# --- MFA method model integration ---
|
||||
|
||||
|
||||
async def test_create_totp_method(session):
|
||||
user = User(email="mfa-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(
|
||||
name="My Phone",
|
||||
type="totp",
|
||||
payload={"secret": secret},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
from sqlmodel import select
|
||||
fetched = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar_one()
|
||||
|
||||
assert fetched.name == "My Phone"
|
||||
assert fetched.type == "totp"
|
||||
stored_secret = fetched.payload["secret"]
|
||||
code = pyotp.TOTP(stored_secret).now()
|
||||
assert verify_totp_code(stored_secret, code) is True
|
||||
# --- MFA method DB relationships ---
|
||||
|
||||
|
||||
async def test_user_multiple_mfa_methods(session):
|
||||
|
|
|
|||
|
|
@ -1,168 +0,0 @@
|
|||
"""Tests for SQLModel table definitions."""
|
||||
|
||||
import pytest # noqa: F401 — needed for pytest.raises
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def test_create_user(session):
|
||||
user = User(email="alice@example.com", role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(User).where(User.email == "alice@example.com"))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.id == user.id
|
||||
assert fetched.role == "admin"
|
||||
assert fetched.disabled_at is None
|
||||
|
||||
|
||||
async def test_create_device_with_user(session):
|
||||
user = User(email="bob@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(
|
||||
name="laptop",
|
||||
public_key="pk-test-device-001",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001"))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.name == "laptop"
|
||||
assert fetched.user_id == user.id
|
||||
assert fetched.use_default_dns is True
|
||||
assert fetched.use_default_allowed_ips is True
|
||||
assert fetched.rx_bytes is None
|
||||
|
||||
|
||||
async def test_device_unique_public_key(session):
|
||||
user = User(email="carol@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
|
||||
session.add(d1)
|
||||
await session.flush()
|
||||
|
||||
d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
|
||||
session.add(d2)
|
||||
with pytest.raises(Exception): # IntegrityError
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def test_create_rule(session):
|
||||
user = User(email="dave@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(Rule).where(Rule.user_id == user.id))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.action == "accept"
|
||||
assert fetched.destination == "10.0.0.0/8"
|
||||
assert fetched.port_type is None
|
||||
assert fetched.port_range is None
|
||||
|
||||
|
||||
async def test_create_rule_with_port(session):
|
||||
rule = Rule(
|
||||
action="drop",
|
||||
destination="192.168.0.0/16",
|
||||
port_type="tcp",
|
||||
port_range="80-443",
|
||||
)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one()
|
||||
assert fetched.port_type == "tcp"
|
||||
assert fetched.port_range == "80-443"
|
||||
assert fetched.user_id is None # global rule
|
||||
|
||||
|
||||
async def test_create_mfa_method(session):
|
||||
user = User(email="eve@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
mfa = MFAMethod(
|
||||
name="My Authenticator",
|
||||
type="totp",
|
||||
payload={"secret": "JBSWY3DPEHPK3PXP"},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(mfa)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one()
|
||||
assert fetched.type == "totp"
|
||||
assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP"
|
||||
|
||||
|
||||
async def test_create_oidc_connection(session):
|
||||
user = User(email="frank@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one()
|
||||
assert fetched.provider == "google"
|
||||
assert fetched.refresh_token == "tok_abc"
|
||||
|
||||
|
||||
async def test_create_api_token(session):
|
||||
user = User(email="grace@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one()
|
||||
assert fetched.token_hash == "sha256_fake_hash"
|
||||
assert fetched.expires_at is None
|
||||
|
||||
|
||||
async def test_create_connectivity_check(session):
|
||||
check = ConnectivityCheck(url="https://example.com", response_code=200)
|
||||
session.add(check)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one()
|
||||
assert fetched.response_code == 200
|
||||
|
||||
|
||||
async def test_configuration_defaults(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one()
|
||||
assert fetched.allow_unprivileged_device_management is True
|
||||
assert fetched.local_auth_enabled is True
|
||||
assert fetched.default_client_mtu == 1280
|
||||
assert fetched.default_client_persistent_keepalive == 25
|
||||
assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"]
|
||||
assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"]
|
||||
assert fetched.vpn_session_duration == 0
|
||||
assert fetched.openid_connect_providers == []
|
||||
assert fetched.saml_identity_providers == []
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""Tests for the notification service."""
|
||||
|
||||
from wiregui.services import notifications
|
||||
|
||||
|
||||
def setup_function():
|
||||
"""Clear notifications before each test."""
|
||||
notifications.clear_all()
|
||||
|
||||
|
||||
def test_add_notification():
|
||||
n = notifications.add("info", "Test message")
|
||||
assert n.severity == "info"
|
||||
assert n.message == "Test message"
|
||||
assert n.user is None
|
||||
assert n.id is not None
|
||||
assert n.timestamp is not None
|
||||
|
||||
|
||||
def test_add_notification_with_user():
|
||||
n = notifications.add("error", "Something broke", user="admin@example.com")
|
||||
assert n.user == "admin@example.com"
|
||||
assert n.severity == "error"
|
||||
|
||||
|
||||
def test_current_returns_newest_first():
|
||||
notifications.add("info", "First")
|
||||
notifications.add("warning", "Second")
|
||||
notifications.add("error", "Third")
|
||||
|
||||
current = notifications.current()
|
||||
assert len(current) == 3
|
||||
assert current[0].message == "Third"
|
||||
assert current[1].message == "Second"
|
||||
assert current[2].message == "First"
|
||||
|
||||
|
||||
def test_count():
|
||||
assert notifications.count() == 0
|
||||
notifications.add("info", "One")
|
||||
notifications.add("info", "Two")
|
||||
assert notifications.count() == 2
|
||||
|
||||
|
||||
def test_clear_specific():
|
||||
n1 = notifications.add("info", "Keep this")
|
||||
n2 = notifications.add("error", "Remove this")
|
||||
|
||||
notifications.clear(n2.id)
|
||||
current = notifications.current()
|
||||
assert len(current) == 1
|
||||
assert current[0].id == n1.id
|
||||
|
||||
|
||||
def test_clear_nonexistent_id_is_noop():
|
||||
notifications.add("info", "Test")
|
||||
notifications.clear("nonexistent-id")
|
||||
assert notifications.count() == 1
|
||||
|
||||
|
||||
def test_clear_all():
|
||||
notifications.add("info", "One")
|
||||
notifications.add("info", "Two")
|
||||
notifications.add("info", "Three")
|
||||
assert notifications.count() == 3
|
||||
|
||||
notifications.clear_all()
|
||||
assert notifications.count() == 0
|
||||
assert notifications.current() == []
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
n = notifications.add("warning", "Test dict", user="someone@example.com")
|
||||
d = n.to_dict()
|
||||
assert d["severity"] == "warning"
|
||||
assert d["message"] == "Test dict"
|
||||
assert d["user"] == "someone@example.com"
|
||||
assert "id" in d
|
||||
assert "timestamp" in d
|
||||
|
||||
|
||||
def test_max_notifications():
|
||||
"""Deque should cap at MAX_NOTIFICATIONS."""
|
||||
for i in range(notifications.MAX_NOTIFICATIONS + 10):
|
||||
notifications.add("info", f"Notification {i}")
|
||||
|
||||
assert notifications.count() == notifications.MAX_NOTIFICATIONS
|
||||
# Newest should be the last one added
|
||||
assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}"
|
||||
|
|
@ -2,62 +2,59 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.utils.server_key import get_server_public_key
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _snapshot_config():
|
||||
"""Snapshot and restore server_public_key around each test."""
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
orig = c.server_public_key if c else None
|
||||
cid = c.id if c else None
|
||||
|
||||
yield
|
||||
|
||||
if cid:
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, cid)
|
||||
if c:
|
||||
c.server_public_key = orig
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_get_server_public_key_returns_key():
|
||||
async def test_get_server_public_key_returns_key(session, monkeypatch):
|
||||
"""Returns the public key when configured."""
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
c.server_public_key = "TestServerPubKey123456789012345678901234w="
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session)
|
||||
|
||||
c = Configuration(server_public_key="TestServerPubKey123456789012345678901234w=")
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
result = await get_server_public_key()
|
||||
assert result == "TestServerPubKey123456789012345678901234w="
|
||||
|
||||
|
||||
async def test_get_server_public_key_raises_when_missing():
|
||||
async def test_get_server_public_key_raises_when_missing(session, monkeypatch):
|
||||
"""Raises RuntimeError when server_public_key is None."""
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
c.server_public_key = None
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session)
|
||||
|
||||
c = Configuration(server_public_key=None)
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
with pytest.raises(RuntimeError, match="not configured"):
|
||||
await get_server_public_key()
|
||||
|
||||
|
||||
async def test_get_server_public_key_raises_when_empty_string():
|
||||
async def test_get_server_public_key_raises_when_empty_string(session, monkeypatch):
|
||||
"""Raises RuntimeError when server_public_key is empty string."""
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
c.server_public_key = ""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.utils.server_key.async_session", mock_session)
|
||||
|
||||
c = Configuration(server_public_key="")
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
||||
with pytest.raises(RuntimeError, match="not configured"):
|
||||
await get_server_public_key()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for services — WireGuard and events."""
|
||||
"""Tests for services — WireGuard event error handling and rule events."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
|
@ -20,53 +20,6 @@ def _make_device(**kwargs) -> Device:
|
|||
return Device(**defaults)
|
||||
|
||||
|
||||
# --- Events (with WG enabled) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_fw.add_user_chain = AsyncMock()
|
||||
mock_fw.add_device_jump_rule = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_created(device)
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once_with(
|
||||
public_key="pk-test",
|
||||
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"],
|
||||
preshared_key="psk-test",
|
||||
)
|
||||
mock_fw.add_device_jump_rule.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_deleted(device)
|
||||
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test")
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_updated(device)
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once()
|
||||
|
||||
|
||||
# --- Events (WG disabled) ---
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue