feat: initial WireGUI implementation — full VPN management platform
Some checks failed
CI / test (push) Failing after 26s
CI / release (push) Has been skipped
CI / docker (push) Has been skipped

Complete Python/NiceGUI rewrite of the Wirezone (Elixir/Phoenix) VPN
management platform. All 10 implementation phases delivered.

Core stack:
- NiceGUI reactive UI with SQLModel ORM on PostgreSQL (asyncpg)
- Alembic migrations, Valkey/Redis cache, pydantic-settings config
- WireGuard management via subprocess (wg/ip/nft CLIs)
- 164 tests passing, 35% code coverage

Features:
- User/device/rule CRUD with admin and unprivileged roles
- Full device config form with per-device WG overrides
- WireGuard client config generation with QR codes
- REST API (v0) with Bearer token auth for all resources
- TOTP MFA with QR registration and challenge flow
- OIDC SSO with authlib (provider registry, auto-create users)
- Magic link passwordless sign-in via email
- SAML SP-initiated SSO with IdP metadata parsing
- WebAuthn/FIDO2 security key registration
- nftables firewall with per-user chains and masquerade
- Background tasks: WG stats polling, VPN session expiry,
  OIDC token refresh, WAN connectivity checks
- Startup reconciliation (DB ↔ WireGuard state sync)
- In-memory notification system with header badge
- Admin UI: users, devices, rules, settings (3 tabs), diagnostics
- Loguru logging with optional timestamped file output

Deployment:
- Multi-stage Dockerfile (python:3.13-slim)
- Docker Compose prod stack (bridge networking, NET_ADMIN, nftables)
- Forgejo CI: tests → semantic versioning → Docker registry push
- Health endpoint at /api/health
This commit is contained in:
Stefano Bertelli 2026-03-30 16:53:46 -05:00
commit 0546b44507
109 changed files with 11793 additions and 0 deletions

0
tests/__init__.py Normal file
View file

65
tests/conftest.py Normal file
View file

@ -0,0 +1,65 @@
"""Shared test fixtures — async DB session using a test database."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel
from wiregui.config import get_settings
# All models must be imported so SQLModel.metadata knows about them
from wiregui.models import * # noqa: F401, F403
def _test_database_url() -> str:
url = get_settings().database_url
base, _dbname = url.rsplit("/", 1)
return f"{base}/wiregui_test"
TEST_DATABASE_URL = _test_database_url()
# Module-level engine creation (runs once via autouse session fixture)
_engine = None
def _ensure_test_db_sync():
"""Ensure wiregui_test database exists (called once)."""
import asyncio
async def _create():
base_url = get_settings().database_url.rsplit("/", 1)[0] + "/postgres"
admin_engine = create_async_engine(base_url, isolation_level="AUTOCOMMIT")
async with admin_engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = 'wiregui_test'")
)
if result.scalar() is None:
await conn.execute(text("CREATE DATABASE wiregui_test"))
await admin_engine.dispose()
asyncio.run(_create())
# Create test DB once at import time
_ensure_test_db_sync()
@pytest_asyncio.fixture
async def session() -> AsyncGenerator[AsyncSession]:
"""Fresh engine + session per test, with table setup/teardown."""
engine = create_async_engine(TEST_DATABASE_URL)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
factory = async_sessionmaker(engine, expire_on_commit=False)
async with factory() as sess:
yield sess
await sess.rollback()
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
await engine.dispose()

161
tests/test_account.py Normal file
View file

@ -0,0 +1,161 @@
"""Tests for account functionality — password changes, API tokens, OIDC connections."""
import hashlib
from datetime import timedelta
from sqlmodel import func, select
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.models.api_token import ApiToken
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- Password change ---
async def test_password_change_flow(session):
"""Simulate the password change flow: verify old, set new."""
user = User(email="pw-change@example.com", password_hash=hash_password("old-password"))
session.add(user)
await session.flush()
# Verify old password
assert verify_password("old-password", user.password_hash) is True
# Change password
user.password_hash = hash_password("new-password")
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert verify_password("new-password", fetched.password_hash) is True
assert verify_password("old-password", fetched.password_hash) is False
async def test_password_change_wrong_current(session):
"""Wrong current password should not allow change."""
user = User(email="pw-wrong@example.com", password_hash=hash_password("correct"))
session.add(user)
await session.flush()
# Simulate check
assert verify_password("wrong", user.password_hash) is False
# --- API token management ---
async def test_create_multiple_tokens(session):
user = User(email="multi-token@example.com")
session.add(user)
await session.flush()
for _ in range(3):
_, token_hash = generate_api_token()
session.add(ApiToken(token_hash=token_hash, user_id=user.id))
await session.flush()
count = (await session.execute(
select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id)
)).scalar()
assert count == 3
async def test_token_with_expiry(session):
user = User(email="expiry-token@example.com")
session.add(user)
await session.flush()
_, token_hash = generate_api_token()
expires = utcnow() + timedelta(days=30)
token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id)
session.add(token)
await session.flush()
fetched = await session.get(ApiToken, token.id)
assert fetched.expires_at is not None
assert fetched.expires_at > utcnow()
async def test_delete_token(session):
user = User(email="del-token@example.com")
session.add(user)
await session.flush()
_, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
await session.delete(token)
await session.flush()
assert await session.get(ApiToken, token.id) is None
# --- OIDC connections ---
async def test_oidc_connection_create(session):
user = User(email="oidc-conn@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="google",
refresh_token="refresh-tok-123",
refresh_response={"access_token": "at", "token_type": "Bearer"},
refreshed_at=utcnow(),
user_id=user.id,
)
session.add(conn)
await session.flush()
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.provider == "google"
assert fetched.refresh_token == "refresh-tok-123"
assert fetched.refresh_response["access_token"] == "at"
async def test_multiple_oidc_providers(session):
user = User(email="multi-oidc@example.com")
session.add(user)
await session.flush()
for provider in ["google", "okta", "azure"]:
conn = OIDCConnection(provider=provider, user_id=user.id)
session.add(conn)
await session.flush()
count = (await session.execute(
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar()
assert count == 3
async def test_oidc_connection_update_refresh_token(session):
user = User(email="oidc-refresh@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="google",
refresh_token="old-token",
user_id=user.id,
)
session.add(conn)
await session.flush()
conn.refresh_token = "new-token"
conn.refreshed_at = utcnow()
session.add(conn)
await session.flush()
fetched = await session.get(OIDCConnection, conn.id)
assert fetched.refresh_token == "new-token"
assert fetched.refreshed_at is not None

283
tests/test_admin.py Normal file
View file

@ -0,0 +1,283 @@
"""Tests for admin functionality — user management, configuration, cascading deletes."""
import pytest
from sqlmodel import func, select
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- User CRUD ---
async def test_create_user_with_role(session):
user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin")
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.role == "admin"
assert verify_password("secret", fetched.password_hash)
async def test_update_user_email(session):
user = User(email="old@test.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
user.email = "new@test.com"
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.email == "new@test.com"
async def test_disable_user(session):
user = User(email="active@test.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
assert user.disabled_at is None
user.disabled_at = utcnow()
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.disabled_at is not None
async def test_promote_demote_user(session):
user = User(email="user@test.com", role="unprivileged")
session.add(user)
await session.flush()
assert user.role == "unprivileged"
user.role = "admin"
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.role == "admin"
user.role = "unprivileged"
session.add(user)
await session.flush()
assert (await session.get(User, user.id)).role == "unprivileged"
# --- Cascading delete (manual, as we do it in the admin page) ---
async def test_delete_user_cascades_devices(session):
user = User(email="cascade@test.com")
session.add(user)
await session.flush()
d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id)
d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id)
session.add_all([d1, d2])
await session.flush()
# Manually delete devices then user (matching admin page behavior)
devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all()
for d in devices:
await session.delete(d)
await session.delete(user)
await session.flush()
assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0
assert await session.get(User, user.id) is None
async def test_delete_user_cascades_rules(session):
user = User(email="rule-cascade@test.com")
session.add(user)
await session.flush()
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
session.add(rule)
await session.flush()
# Delete rules then user
rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all()
for r in rules:
await session.delete(r)
await session.delete(user)
await session.flush()
assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0
# --- Configuration singleton ---
async def test_configuration_create_and_update(session):
config = Configuration()
session.add(config)
await session.flush()
assert config.default_client_mtu == 1280
assert config.local_auth_enabled is True
config.default_client_mtu = 1400
config.local_auth_enabled = False
config.vpn_session_duration = 3600
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.default_client_mtu == 1400
assert fetched.local_auth_enabled is False
assert fetched.vpn_session_duration == 3600
async def test_configuration_oidc_providers(session):
config = Configuration()
session.add(config)
await session.flush()
assert config.openid_connect_providers == []
providers = [
{
"id": "google",
"label": "Sign in with Google",
"scope": "openid email profile",
"response_type": "code",
"client_id": "google-client-id",
"client_secret": "google-secret",
"discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration",
"auto_create_users": True,
},
{
"id": "okta",
"label": "Okta SSO",
"scope": "openid email profile",
"response_type": "code",
"client_id": "okta-client-id",
"client_secret": "okta-secret",
"discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration",
"auto_create_users": False,
},
]
config.openid_connect_providers = providers
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert len(fetched.openid_connect_providers) == 2
assert fetched.openid_connect_providers[0]["id"] == "google"
assert fetched.openid_connect_providers[1]["auto_create_users"] is False
async def test_configuration_update_client_defaults(session):
config = Configuration()
session.add(config)
await session.flush()
config.default_client_endpoint = "vpn.example.com"
config.default_client_dns = ["8.8.8.8", "8.8.4.4"]
config.default_client_allowed_ips = ["10.0.0.0/8"]
config.default_client_persistent_keepalive = 30
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.default_client_endpoint == "vpn.example.com"
assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"]
assert fetched.default_client_allowed_ips == ["10.0.0.0/8"]
assert fetched.default_client_persistent_keepalive == 30
async def test_configuration_security_toggles(session):
config = Configuration()
session.add(config)
await session.flush()
config.allow_unprivileged_device_management = False
config.allow_unprivileged_device_configuration = False
config.disable_vpn_on_oidc_error = True
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.allow_unprivileged_device_management is False
assert fetched.allow_unprivileged_device_configuration is False
assert fetched.disable_vpn_on_oidc_error is True
# --- Device config overrides ---
async def test_device_with_custom_config(session):
user = User(email="config-user@test.com")
session.add(user)
await session.flush()
device = Device(
name="custom-config",
public_key="pk-custom-config",
user_id=user.id,
use_default_dns=False,
use_default_endpoint=False,
use_default_mtu=False,
use_default_persistent_keepalive=False,
use_default_allowed_ips=False,
dns=["8.8.8.8"],
endpoint="custom-vpn.example.com",
mtu=1400,
persistent_keepalive=15,
allowed_ips=["10.0.0.0/8", "172.16.0.0/12"],
)
session.add(device)
await session.flush()
fetched = await session.get(Device, device.id)
assert fetched.use_default_dns is False
assert fetched.dns == ["8.8.8.8"]
assert fetched.endpoint == "custom-vpn.example.com"
assert fetched.mtu == 1400
assert fetched.persistent_keepalive == 15
assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"]
async def test_device_default_flags_are_true(session):
user = User(email="defaults@test.com")
session.add(user)
await session.flush()
device = Device(name="defaults", public_key="pk-defaults", user_id=user.id)
session.add(device)
await session.flush()
fetched = await session.get(Device, device.id)
assert fetched.use_default_allowed_ips is True
assert fetched.use_default_dns is True
assert fetched.use_default_endpoint is True
assert fetched.use_default_mtu is True
assert fetched.use_default_persistent_keepalive is True
# --- User device count ---
async def test_user_device_count_query(session):
user = User(email="count-user@test.com")
session.add(user)
await session.flush()
for i in range(3):
session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id))
await session.flush()
count = (await session.execute(
select(func.count()).select_from(Device).where(Device.user_id == user.id)
)).scalar()
assert count == 3

86
tests/test_api.py Normal file
View file

@ -0,0 +1,86 @@
"""Tests for REST API endpoints and token auth."""
import hashlib
from wiregui.auth.api_token import generate_api_token, resolve_bearer_token
from wiregui.auth.passwords import hash_password
from wiregui.models.api_token import ApiToken
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- Token generation ---
def test_generate_api_token():
plaintext, token_hash = generate_api_token()
assert len(plaintext) > 20
assert token_hash == hashlib.sha256(plaintext.encode()).hexdigest()
def test_generate_api_token_unique():
t1, h1 = generate_api_token()
t2, h2 = generate_api_token()
assert t1 != t2
assert h1 != h2
# --- Token resolution ---
async def test_resolve_valid_token(session):
user = User(email="api-user@example.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.flush()
plaintext, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
async def test_resolve_invalid_token(session):
resolved = await resolve_bearer_token(session, "bogus-token")
assert resolved is None
async def test_resolve_expired_token(session):
from datetime import timedelta
user = User(email="expired-api@example.com", password_hash=hash_password("x"))
session.add(user)
await session.flush()
plaintext, token_hash = generate_api_token()
token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() - timedelta(hours=1),
)
session.add(token)
await session.flush()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
async def test_resolve_token_disabled_user(session):
user = User(
email="disabled-api@example.com",
password_hash=hash_password("x"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
plaintext, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None

325
tests/test_api_routes.py Normal file
View file

@ -0,0 +1,325 @@
"""Tests for REST API routes via httpx AsyncClient against the FastAPI app."""
import hashlib
from uuid import UUID, uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlmodel import select
from wiregui.api.deps import get_current_api_user, get_db, require_admin
from wiregui.api.v0 import router as api_router
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.models.user import User
def _build_app(session, admin_user=None, regular_user=None):
"""Build a test FastAPI app with overridden dependencies."""
test_app = FastAPI()
test_app.include_router(api_router, prefix="/api")
async def override_get_db():
yield session
test_app.dependency_overrides[get_db] = override_get_db
if admin_user:
test_app.dependency_overrides[get_current_api_user] = lambda: admin_user
test_app.dependency_overrides[require_admin] = lambda: admin_user
return test_app
async def _make_admin(session) -> User:
user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin")
session.add(user)
await session.flush()
return user
async def _make_user(session, email="api-user@test.com") -> User:
user = User(email=email, password_hash=hash_password("pw"), role="unprivileged")
session.add(user)
await session.flush()
return user
# ========== Users API ==========
async def test_list_users(session):
admin = await _make_admin(session)
await _make_user(session, "user1@test.com")
await _make_user(session, "user2@test.com")
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/users/")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 3 # admin + 2 users
async def test_get_user(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/users/{admin.id}")
assert resp.status_code == 200
assert resp.json()["email"] == "api-admin@test.com"
async def test_get_user_not_found(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/users/{uuid4()}")
assert resp.status_code == 404
async def test_create_user(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/api/v0/users/", json={
"email": "new-api-user@test.com",
"password": "secret123",
"role": "unprivileged",
})
assert resp.status_code == 201
data = resp.json()
assert data["email"] == "new-api-user@test.com"
assert data["role"] == "unprivileged"
assert "id" in data
async def test_update_user(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/users/{user.id}", json={
"role": "admin",
})
assert resp.status_code == 200
assert resp.json()["role"] == "admin"
async def test_update_user_password(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/users/{user.id}", json={
"password": "new-password-123",
})
assert resp.status_code == 200
from wiregui.auth.passwords import verify_password
refreshed = await session.get(User, user.id)
assert verify_password("new-password-123", refreshed.password_hash)
async def test_delete_user(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/users/{user.id}")
assert resp.status_code == 204
assert await session.get(User, user.id) is None
# ========== Devices API ==========
async def test_list_devices_admin_sees_all(session):
admin = await _make_admin(session)
user = await _make_user(session)
session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id))
session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id))
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/devices/")
assert resp.status_code == 200
assert len(resp.json()) >= 2
async def test_list_devices_user_sees_own(session):
admin = await _make_admin(session)
user = await _make_user(session, "own-devices@test.com")
session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id))
session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id))
await session.flush()
# Override to be the regular user
test_app = _build_app(session)
test_app.dependency_overrides[get_current_api_user] = lambda: user
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
resp = await client.get("/api/v0/devices/")
assert resp.status_code == 200
names = [d["name"] for d in resp.json()]
assert "mine" in names
assert "not-mine" not in names
async def test_get_device(session):
admin = await _make_admin(session)
device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5")
session.add(device)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/devices/{device.id}")
assert resp.status_code == 200
assert resp.json()["name"] == "detail"
assert resp.json()["ipv4"] == "10.0.0.5"
async def test_get_device_forbidden_for_other_user(session):
admin = await _make_admin(session)
user = await _make_user(session, "other-dev@test.com")
device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id)
session.add(device)
await session.flush()
test_app = _build_app(session)
test_app.dependency_overrides[get_current_api_user] = lambda: user
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/devices/{device.id}")
assert resp.status_code == 403
async def test_update_device(session):
admin = await _make_admin(session)
device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id)
session.add(device)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"})
assert resp.status_code == 200
assert resp.json()["name"] == "new-name"
async def test_delete_device(session):
admin = await _make_admin(session)
device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id)
session.add(device)
await session.flush()
did = device.id
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/devices/{did}")
assert resp.status_code == 204
assert await session.get(Device, did) is None
# ========== Rules API ==========
async def test_list_rules(session):
admin = await _make_admin(session)
session.add(Rule(action="accept", destination="10.0.0.0/8"))
session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id))
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/rules/")
assert resp.status_code == 200
assert len(resp.json()) >= 2
async def test_create_rule(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/api/v0/rules/", json={
"action": "accept",
"destination": "172.16.0.0/12",
"port_type": "tcp",
"port_range": "443",
})
assert resp.status_code == 201
data = resp.json()
assert data["action"] == "accept"
assert data["destination"] == "172.16.0.0/12"
assert data["port_type"] == "tcp"
assert data["port_range"] == "443"
async def test_update_rule(session):
admin = await _make_admin(session)
rule = Rule(action="accept", destination="10.0.0.0/8")
session.add(rule)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"})
assert resp.status_code == 200
assert resp.json()["action"] == "drop"
async def test_delete_rule(session):
admin = await _make_admin(session)
rule = Rule(action="drop", destination="0.0.0.0/0")
session.add(rule)
await session.flush()
rid = rule.id
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/rules/{rid}")
assert resp.status_code == 204
assert await session.get(Rule, rid) is None
# ========== Configuration API ==========
async def test_get_configuration_auto_creates(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/configuration/")
assert resp.status_code == 200
data = resp.json()
assert data["default_client_mtu"] == 1280
assert data["local_auth_enabled"] is True
async def test_update_configuration(session):
admin = await _make_admin(session)
# Pre-create config
config = Configuration()
session.add(config)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put("/api/v0/configuration/", json={
"default_client_mtu": 1400,
"vpn_session_duration": 3600,
"default_client_dns": ["8.8.8.8"],
})
assert resp.status_code == 200
data = resp.json()
assert data["default_client_mtu"] == 1400
assert data["vpn_session_duration"] == 3600
assert data["default_client_dns"] == ["8.8.8.8"]

98
tests/test_auth.py Normal file
View file

@ -0,0 +1,98 @@
"""Tests for authentication modules."""
from sqlmodel import select
from wiregui.auth.jwt import create_access_token, decode_access_token
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.auth.seed import seed_admin
from wiregui.models.user import User
# --- Password hashing ---
def test_hash_and_verify():
hashed = hash_password("my-secret")
assert verify_password("my-secret", hashed) is True
def test_verify_wrong_password():
hashed = hash_password("correct")
assert verify_password("wrong", hashed) is False
def test_hash_is_not_plaintext():
hashed = hash_password("plaintext")
assert hashed != "plaintext"
assert hashed.startswith("$2b$")
# --- JWT ---
def test_create_and_decode_token():
token = create_access_token(user_id="user-123", role="admin")
payload = decode_access_token(token)
assert payload is not None
assert payload["sub"] == "user-123"
assert payload["role"] == "admin"
assert "exp" in payload
def test_decode_invalid_token():
assert decode_access_token("garbage.token.value") is None
def test_decode_tampered_token():
token = create_access_token(user_id="user-123", role="admin")
tampered = token[:-4] + "XXXX"
assert decode_access_token(tampered) is None
# --- Admin seed ---
async def test_seed_admin_creates_user(session, monkeypatch):
"""seed_admin should create an admin when no users exist."""
# Patch async_session to use our test session
from unittest.mock import AsyncMock
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {
"admin_email": "seed-test@example.com",
"admin_password": "seed-pass-123",
})())
await seed_admin()
result = await session.execute(select(User).where(User.email == "seed-test@example.com"))
admin = result.scalar_one()
assert admin.role == "admin"
assert verify_password("seed-pass-123", admin.password_hash)
async def test_seed_admin_skips_when_users_exist(session, monkeypatch):
"""seed_admin should not create a second admin if users already exist."""
from contextlib import asynccontextmanager
existing = User(email="existing@example.com", role="unprivileged")
session.add(existing)
await session.flush()
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
await seed_admin()
result = await session.execute(select(User))
users = result.scalars().all()
assert len(users) == 1
assert users[0].email == "existing@example.com"

226
tests/test_auth_extended.py Normal file
View file

@ -0,0 +1,226 @@
"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from wiregui.auth.passwords import hash_password
from wiregui.auth.session import authenticate_user
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# ========== Session / authenticate_user edge cases ==========
async def test_authenticate_user_no_password_hash(session, monkeypatch):
"""Users without a password (OIDC-only) should not authenticate via password."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="no-pw@test.com", password_hash=None)
session.add(user)
await session.flush()
result = await authenticate_user("no-pw@test.com", "anything")
assert result is None
async def test_authenticate_user_disabled(session, monkeypatch):
"""Disabled users should not authenticate."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow())
session.add(user)
await session.flush()
result = await authenticate_user("disabled-auth@test.com", "pw")
assert result is None
async def test_authenticate_user_nonexistent(session, monkeypatch):
"""Nonexistent email should return None."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
result = await authenticate_user("ghost@nowhere.com", "pw")
assert result is None
# ========== OIDC provider registration ==========
async def test_register_providers_from_config(session, monkeypatch):
"""register_providers should register configured OIDC providers with authlib."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
from wiregui.models.configuration import Configuration
config = Configuration(openid_connect_providers=[
{
"id": "test-reg",
"label": "Test",
"scope": "openid email",
"client_id": "cid",
"client_secret": "cs",
"discovery_document_uri": "https://idp.test/.well-known/openid-configuration",
}
])
session.add(config)
await session.flush()
with patch("wiregui.auth.oidc.oauth") as mock_oauth:
from wiregui.auth.oidc import register_providers
await register_providers()
mock_oauth.register.assert_called_once()
call_kwargs = mock_oauth.register.call_args[1]
assert call_kwargs["name"] == "test-reg"
assert call_kwargs["client_id"] == "cid"
async def test_get_client_unknown_provider():
"""get_client should raise for unregistered providers."""
import pytest
from wiregui.auth.oidc import get_client
with pytest.raises(ValueError, match="not registered"):
get_client("nonexistent-provider-xyz")
# ========== WebAuthn options ==========
def test_webauthn_registration_options(monkeypatch):
"""create_registration_options should return valid options and challenge."""
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
"external_url": "https://vpn.example.com",
})())
from wiregui.auth.webauthn import create_registration_options
user_id = uuid4()
result = create_registration_options(user_id, "user@example.com")
assert "options_json" in result
assert "challenge" in result
assert len(result["challenge"]) > 10
assert "user@example.com" in result["options_json"]
def test_webauthn_registration_options_with_excludes(monkeypatch):
"""Existing credentials should be excluded from registration options."""
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
"external_url": "https://vpn.example.com",
})())
from wiregui.auth.webauthn import create_registration_options
existing = [{"credential_id": "AQIDBA"}] # base64url of bytes [1,2,3,4]
result = create_registration_options(uuid4(), "user@example.com", existing)
assert "options_json" in result
def test_webauthn_authentication_options(monkeypatch):
"""create_authentication_options should accept credential descriptors."""
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
"external_url": "https://vpn.example.com",
})())
from wiregui.auth.webauthn import create_authentication_options
credentials = [{"credential_id": "AQIDBA"}]
result = create_authentication_options(credentials)
assert "options_json" in result
assert "challenge" in result
# ========== Events — rule update/delete with rebuild ==========
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings):
"""on_rule_updated should rebuild the user's firewall chain."""
mock_settings.return_value.wg_enabled = True
mock_fw.rebuild_all_rules = AsyncMock()
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_updated
# Need to mock the DB call inside _rebuild_user_chain
with patch("wiregui.services.events.async_session") as mock_session_factory:
mock_session = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
# Mock the select results
mock_rules_result = MagicMock()
mock_rules_result.scalars.return_value.all.return_value = []
mock_devices_result = MagicMock()
mock_devices_result.scalars.return_value.all.return_value = []
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
mock_session_factory.return_value = mock_session
rule = Rule(action="accept", destination="10.0.0.0/8", user_id="a1b2c3d4-0000-0000-0000-000000000000")
await on_rule_updated(rule)
mock_fw.rebuild_all_rules.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
async def test_on_rule_deleted_triggers_rebuild(mock_fw, mock_settings):
"""on_rule_deleted should rebuild the user's firewall chain."""
mock_settings.return_value.wg_enabled = True
mock_fw.rebuild_all_rules = AsyncMock()
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_deleted
with patch("wiregui.services.events.async_session") as mock_session_factory:
mock_session = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_rules_result = MagicMock()
mock_rules_result.scalars.return_value.all.return_value = []
mock_devices_result = MagicMock()
mock_devices_result.scalars.return_value.all.return_value = []
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
mock_session_factory.return_value = mock_session
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
await on_rule_deleted(rule)
mock_fw.rebuild_all_rules.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
async def test_on_rule_deleted_skips_when_disabled(mock_settings):
"""Rule events should be no-ops when WG is disabled."""
mock_settings.return_value.wg_enabled = False
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_deleted, on_rule_updated
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
await on_rule_updated(rule) # Should not raise
await on_rule_deleted(rule) # Should not raise

40
tests/test_firewall.py Normal file
View file

@ -0,0 +1,40 @@
"""Tests for firewall service — rule expression building and chain naming."""
from wiregui.services.firewall import _build_rule_expr, _user_chain_name
def test_user_chain_name():
uid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
name = _user_chain_name(uid)
assert name == "user_a1b2c3d4e5f6"
assert len(name) <= 30
def test_user_chain_name_deterministic():
uid = "12345678-1234-1234-1234-123456789abc"
assert _user_chain_name(uid) == _user_chain_name(uid)
def test_build_rule_expr_ipv4_accept():
expr = _build_rule_expr("10.0.0.0/8", "accept")
assert expr == "ip daddr 10.0.0.0/8 accept"
def test_build_rule_expr_ipv6_drop():
expr = _build_rule_expr("fd00::/64", "drop")
assert expr == "ip6 daddr fd00::/64 drop"
def test_build_rule_expr_with_port():
expr = _build_rule_expr("192.168.0.0/16", "accept", port_type="tcp", port_range="80-443")
assert expr == "ip daddr 192.168.0.0/16 tcp dport 80-443 accept"
def test_build_rule_expr_single_port():
expr = _build_rule_expr("10.0.0.1/32", "drop", port_type="udp", port_range="53")
assert expr == "ip daddr 10.0.0.1/32 udp dport 53 drop"
def test_build_rule_expr_no_port():
expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None)
assert expr == "ip daddr 0.0.0.0/0 accept"

View file

@ -0,0 +1,239 @@
"""Integration tests for MFA — full registration and authentication flows through the database."""
import pyotp
from sqlmodel import func, select
from wiregui.auth.mfa import generate_totp_secret, verify_totp_code
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.auth.session import authenticate_user
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
from wiregui.utils.time import utcnow
async def test_full_totp_registration_flow(session, monkeypatch):
"""End-to-end: create user → generate secret → verify code → store method → re-verify from DB."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
# Create user with password
user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123"))
session.add(user)
await session.flush()
# Step 1: Generate TOTP secret (happens in account page)
secret = generate_totp_secret()
# Step 2: User scans QR, enters code from their authenticator
totp = pyotp.TOTP(secret)
code = totp.now()
# Step 3: Verify the code is correct before saving
assert verify_totp_code(secret, code) is True
# Step 4: Save the MFA method to DB
method = MFAMethod(
name="My Authenticator",
type="totp",
payload={"secret": secret},
user_id=user.id,
)
session.add(method)
await session.flush()
# Step 5: Simulate future login — load method from DB and verify a fresh code
fetched_methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalars().all()
assert len(fetched_methods) == 1
stored_secret = fetched_methods[0].payload["secret"]
fresh_code = pyotp.TOTP(stored_secret).now()
assert verify_totp_code(stored_secret, fresh_code) is True
async def test_mfa_blocks_login_without_code(session, monkeypatch):
"""User with MFA should not be fully authenticated without completing MFA challenge."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
# Create user with MFA
user = User(email="mfa-block@example.com", password_hash=hash_password("password1"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# Password auth succeeds
authed_user = await authenticate_user("mfa-block@example.com", "password1")
assert authed_user is not None
# But MFA methods exist — login page would redirect to /mfa instead of completing login
mfa_methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == authed_user.id)
)).scalars().all()
assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa
async def test_mfa_wrong_code_rejected(session):
"""Wrong TOTP code should be rejected even if method is valid."""
user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# Load from DB and try wrong code
fetched = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar_one()
assert verify_totp_code(fetched.payload["secret"], "000000") is False
assert verify_totp_code(fetched.payload["secret"], "123456") is False
async def test_mfa_multiple_methods_any_valid_code_works(session):
"""If user has multiple TOTP methods, a valid code from any should work."""
user = User(email="mfa-multi@example.com")
session.add(user)
await session.flush()
secret1 = generate_totp_secret()
secret2 = generate_totp_secret()
session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id))
session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id))
await session.flush()
methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalars().all()
# Code from method 1 should verify against method 1's secret
code1 = pyotp.TOTP(secret1).now()
verified = False
for m in methods:
if verify_totp_code(m.payload["secret"], code1):
verified = True
break
assert verified is True
# Code from method 2 should also work
code2 = pyotp.TOTP(secret2).now()
verified2 = False
for m in methods:
if verify_totp_code(m.payload["secret"], code2):
verified2 = True
break
assert verified2 is True
async def test_mfa_method_last_used_tracking(session):
"""Verifying MFA should update last_used_at timestamp."""
user = User(email="mfa-tracking@example.com")
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
assert method.last_used_at is None
# Simulate successful verification and update
code = pyotp.TOTP(secret).now()
assert verify_totp_code(secret, code) is True
method.last_used_at = utcnow()
session.add(method)
await session.flush()
fetched = await session.get(MFAMethod, method.id)
assert fetched.last_used_at is not None
async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch):
"""After removing all MFA methods, user should not be redirected to MFA challenge."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="mfa-remove@example.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# MFA exists
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 1
# Delete it
await session.delete(method)
await session.flush()
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 0
# Password auth still works
authed = await authenticate_user("mfa-remove@example.com", "pw")
assert authed is not None
# No MFA methods — login flow would skip MFA challenge
mfa_check = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == authed.id)
)).scalars().all()
assert len(mfa_check) == 0
async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch):
"""Disabled user should be rejected at password stage, never reaching MFA."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(
email="mfa-disabled@example.com",
password_hash=hash_password("pw"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
secret = generate_totp_secret()
session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id))
await session.flush()
# Password auth rejects disabled user before MFA is ever checked
result = await authenticate_user("mfa-disabled@example.com", "pw")
assert result is None

View file

@ -0,0 +1,309 @@
"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""
import json
import time
from unittest.mock import patch
from uuid import uuid4
import respx
from httpx import Response
from jose import jwt
from sqlmodel import select
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
from wiregui.config import get_settings
from wiregui.models.configuration import Configuration
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
# --- Helper to create a fake OIDC provider config in the DB ---
async def _setup_oidc_config(session) -> Configuration:
"""Insert a Configuration with a test OIDC provider."""
config = Configuration(
openid_connect_providers=[
{
"id": "test-idp",
"label": "Test IdP",
"scope": "openid email profile",
"response_type": "code",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
"auto_create_users": True,
}
],
)
session.add(config)
await session.commit()
return config
def _mock_discovery():
"""Mock OIDC discovery document response."""
return {
"issuer": "https://idp.example.com",
"authorization_endpoint": "https://idp.example.com/authorize",
"token_endpoint": "https://idp.example.com/token",
"userinfo_endpoint": "https://idp.example.com/userinfo",
"jwks_uri": "https://idp.example.com/.well-known/jwks.json",
}
def _mock_token_response(email: str = "oidc-user@example.com"):
"""Mock OIDC token endpoint response with ID token."""
now = int(time.time())
id_token_payload = {
"iss": "https://idp.example.com",
"sub": "oidc-subject-123",
"aud": "test-client-id",
"email": email,
"name": "OIDC User",
"iat": now,
"exp": now + 3600,
"nonce": "test-nonce",
}
# Sign with a simple secret (in real life this would be RSA)
id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
return {
"access_token": "mock-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "mock-refresh-token",
"id_token": id_token,
}
# --- Provider config loading ---
async def test_load_providers_from_config(session, monkeypatch):
"""Providers should be loaded from the Configuration table."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
await _setup_oidc_config(session)
providers = await load_providers()
assert len(providers) == 1
assert providers[0]["id"] == "test-idp"
assert providers[0]["client_id"] == "test-client-id"
async def test_load_providers_empty_when_no_config(session, monkeypatch):
"""Should return empty list when no Configuration exists."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
providers = await load_providers()
assert providers == []
async def test_get_provider_config_by_id(session, monkeypatch):
"""Should find a specific provider by ID."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
await _setup_oidc_config(session)
config = await get_provider_config("test-idp")
assert config is not None
assert config["label"] == "Test IdP"
config_missing = await get_provider_config("nonexistent")
assert config_missing is None
# --- OIDC connection storage ---
async def test_oidc_connection_created_on_login(session):
"""Simulates what the callback route does: create user + OIDC connection."""
user = User(email="oidc-new@example.com", role="unprivileged")
session.add(user)
await session.flush()
token_data = _mock_token_response("oidc-new@example.com")
conn = OIDCConnection(
provider="test-idp",
refresh_token=token_data["refresh_token"],
refresh_response=token_data,
user_id=user.id,
)
session.add(conn)
await session.flush()
# Verify it was stored
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.provider == "test-idp"
assert fetched.refresh_token == "mock-refresh-token"
assert fetched.refresh_response["access_token"] == "mock-access-token"
async def test_oidc_connection_updated_on_re_login(session):
"""Re-login should update the existing OIDC connection, not create a duplicate."""
user = User(email="oidc-relogin@example.com")
session.add(user)
await session.flush()
# First login
conn = OIDCConnection(
provider="test-idp",
refresh_token="old-refresh-token",
user_id=user.id,
)
session.add(conn)
await session.flush()
# Re-login — update existing connection (as the callback route does)
existing = (await session.execute(
select(OIDCConnection).where(
OIDCConnection.user_id == user.id,
OIDCConnection.provider == "test-idp",
)
)).scalar_one()
existing.refresh_token = "new-refresh-token"
from wiregui.utils.time import utcnow
existing.refreshed_at = utcnow()
session.add(existing)
await session.flush()
# Should still be one connection
from sqlmodel import func
count = (await session.execute(
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar()
assert count == 1
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.refresh_token == "new-refresh-token"
async def test_oidc_auto_create_user(session):
"""When auto_create_users is True, a new user should be created from OIDC email."""
email = "auto-created@example.com"
# Verify user doesn't exist
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
assert existing is None
# Simulate what callback does with auto_create
user = User(email=email, role="unprivileged")
session.add(user)
await session.flush()
from wiregui.utils.time import utcnow
user.last_signed_in_at = utcnow()
user.last_signed_in_method = "oidc:test-idp"
session.add(user)
await session.flush()
created = (await session.execute(select(User).where(User.email == email))).scalar_one()
assert created.role == "unprivileged"
assert created.last_signed_in_method == "oidc:test-idp"
async def test_oidc_disabled_user_rejected(session):
"""Disabled users should not be logged in via OIDC."""
from wiregui.utils.time import utcnow
user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
session.add(user)
await session.flush()
# The callback route checks disabled_at before creating session
assert user.disabled_at is not None # Would redirect to /login
async def test_oidc_user_without_auto_create_rejected(session):
"""When auto_create is False and user doesn't exist, login should fail."""
email = "no-auto-create@example.com"
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
assert existing is None
# The callback route checks auto_create_users from provider config
# With auto_create=False and no existing user, it would redirect to /login
# This verifies the precondition
# --- OIDC refresh token flow ---
async def test_oidc_refresh_stores_new_token(session):
"""Simulates a successful token refresh updating the connection."""
user = User(email="oidc-refresh-test@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="test-idp",
refresh_token="old-refresh",
user_id=user.id,
)
session.add(conn)
await session.flush()
# Simulate refresh result
new_token = {
"access_token": "new-access",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
conn.refresh_response = new_token
from wiregui.utils.time import utcnow
conn.refreshed_at = utcnow()
session.add(conn)
await session.flush()
fetched = await session.get(OIDCConnection, conn.id)
assert fetched.refresh_token == "new-refresh"
assert fetched.refresh_response["access_token"] == "new-access"
assert fetched.refreshed_at is not None
async def test_oidc_multiple_providers_per_user(session):
"""User can have connections to multiple OIDC providers."""
user = User(email="multi-provider@example.com")
session.add(user)
await session.flush()
for provider in ["google", "okta", "azure-ad"]:
session.add(OIDCConnection(
provider=provider,
refresh_token=f"token-{provider}",
user_id=user.id,
))
await session.flush()
conns = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider)
)).scalars().all()
assert len(conns) == 3
assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]

58
tests/test_magic_link.py Normal file
View file

@ -0,0 +1,58 @@
"""Tests for magic link authentication flow."""
from datetime import timedelta
from wiregui.auth.jwt import create_access_token, decode_access_token
from wiregui.auth.passwords import hash_password
from wiregui.models.user import User
def test_magic_link_token_creation():
"""Magic link token should be a valid JWT with short expiry."""
token = create_access_token(
user_id="user-123",
role="unprivileged",
expires_delta=timedelta(minutes=15),
)
payload = decode_access_token(token)
assert payload is not None
assert payload["sub"] == "user-123"
assert payload["role"] == "unprivileged"
def test_magic_link_token_expired():
"""Expired magic link token should be rejected."""
token = create_access_token(
user_id="user-123",
role="admin",
expires_delta=timedelta(minutes=-1), # Already expired
)
payload = decode_access_token(token)
assert payload is None
def test_magic_link_token_wrong_user():
"""Token should only be valid for the intended user."""
token = create_access_token(user_id="user-A", role="admin")
payload = decode_access_token(token)
assert payload["sub"] == "user-A"
# Caller is responsible for checking sub matches the URL user_id
async def test_magic_link_disabled_user_rejected(session):
"""Disabled users should not be able to use magic links."""
from wiregui.utils.time import utcnow
user = User(
email="disabled-magic@example.com",
password_hash=hash_password("pw"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
# The token would be valid but the page handler checks disabled_at
token = create_access_token(user_id=str(user.id), role="unprivileged")
payload = decode_access_token(token)
assert payload is not None # Token itself is valid
assert user.disabled_at is not None # But user is disabled — handler would reject

127
tests/test_mfa.py Normal file
View file

@ -0,0 +1,127 @@
"""Tests for TOTP MFA functionality."""
import pyotp
from wiregui.auth.mfa import (
generate_totp_qr_svg,
generate_totp_secret,
get_totp_uri,
verify_totp_code,
)
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
# --- TOTP secret generation ---
def test_generate_secret():
secret = generate_totp_secret()
assert len(secret) == 32 # base32 encoded
assert secret.isalpha() or any(c.isdigit() for c in secret)
def test_generate_secret_unique():
s1 = generate_totp_secret()
s2 = generate_totp_secret()
assert s1 != s2
# --- TOTP URI ---
def test_get_totp_uri():
uri = get_totp_uri("JBSWY3DPEHPK3PXP", "user@example.com")
assert uri.startswith("otpauth://totp/")
assert "user%40example.com" in uri or "user@example.com" in uri
assert "secret=JBSWY3DPEHPK3PXP" in uri
assert "issuer=WireGUI" in uri
def test_get_totp_uri_custom_issuer():
uri = get_totp_uri("SECRET", "test@test.com", issuer="MyVPN")
assert "issuer=MyVPN" in uri
# --- TOTP verification ---
def test_verify_valid_code():
secret = generate_totp_secret()
totp = pyotp.TOTP(secret)
code = totp.now()
assert verify_totp_code(secret, code) is True
def test_verify_invalid_code():
secret = generate_totp_secret()
assert verify_totp_code(secret, "000000") is False
def test_verify_wrong_secret():
secret1 = generate_totp_secret()
secret2 = generate_totp_secret()
code = pyotp.TOTP(secret1).now()
assert verify_totp_code(secret2, code) is False
def test_verify_empty_code():
secret = generate_totp_secret()
assert verify_totp_code(secret, "") is False
# --- QR code generation ---
def test_generate_qr_svg():
uri = get_totp_uri("SECRET", "test@test.com")
svg = generate_totp_qr_svg(uri)
assert "<svg" in svg
assert "</svg>" in svg
# --- MFA method model integration ---
async def test_create_totp_method(session):
user = User(email="mfa-test@example.com")
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(
name="My Phone",
type="totp",
payload={"secret": secret},
user_id=user.id,
)
session.add(method)
await session.flush()
from sqlmodel import select
fetched = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar_one()
assert fetched.name == "My Phone"
assert fetched.type == "totp"
stored_secret = fetched.payload["secret"]
code = pyotp.TOTP(stored_secret).now()
assert verify_totp_code(stored_secret, code) is True
async def test_user_multiple_mfa_methods(session):
user = User(email="multi-mfa@example.com")
session.add(user)
await session.flush()
m1 = MFAMethod(name="Phone", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
m2 = MFAMethod(name="Backup", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
session.add_all([m1, m2])
await session.flush()
from sqlmodel import select, func
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 2

168
tests/test_models.py Normal file
View file

@ -0,0 +1,168 @@
"""Tests for SQLModel table definitions."""
import pytest # noqa: F401 — needed for pytest.raises
from sqlmodel import select
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.rule import Rule
from wiregui.models.user import User
async def test_create_user(session):
user = User(email="alice@example.com", role="admin")
session.add(user)
await session.flush()
result = await session.execute(select(User).where(User.email == "alice@example.com"))
fetched = result.scalar_one()
assert fetched.id == user.id
assert fetched.role == "admin"
assert fetched.disabled_at is None
async def test_create_device_with_user(session):
user = User(email="bob@example.com")
session.add(user)
await session.flush()
device = Device(
name="laptop",
public_key="pk-test-device-001",
user_id=user.id,
)
session.add(device)
await session.flush()
result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001"))
fetched = result.scalar_one()
assert fetched.name == "laptop"
assert fetched.user_id == user.id
assert fetched.use_default_dns is True
assert fetched.use_default_allowed_ips is True
assert fetched.rx_bytes is None
async def test_device_unique_public_key(session):
user = User(email="carol@example.com")
session.add(user)
await session.flush()
d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
session.add(d1)
await session.flush()
d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
session.add(d2)
with pytest.raises(Exception): # IntegrityError
await session.flush()
async def test_create_rule(session):
user = User(email="dave@example.com")
session.add(user)
await session.flush()
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
session.add(rule)
await session.flush()
result = await session.execute(select(Rule).where(Rule.user_id == user.id))
fetched = result.scalar_one()
assert fetched.action == "accept"
assert fetched.destination == "10.0.0.0/8"
assert fetched.port_type is None
assert fetched.port_range is None
async def test_create_rule_with_port(session):
rule = Rule(
action="drop",
destination="192.168.0.0/16",
port_type="tcp",
port_range="80-443",
)
session.add(rule)
await session.flush()
fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one()
assert fetched.port_type == "tcp"
assert fetched.port_range == "80-443"
assert fetched.user_id is None # global rule
async def test_create_mfa_method(session):
user = User(email="eve@example.com")
session.add(user)
await session.flush()
mfa = MFAMethod(
name="My Authenticator",
type="totp",
payload={"secret": "JBSWY3DPEHPK3PXP"},
user_id=user.id,
)
session.add(mfa)
await session.flush()
fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one()
assert fetched.type == "totp"
assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP"
async def test_create_oidc_connection(session):
user = User(email="frank@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id)
session.add(conn)
await session.flush()
fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one()
assert fetched.provider == "google"
assert fetched.refresh_token == "tok_abc"
async def test_create_api_token(session):
user = User(email="grace@example.com")
session.add(user)
await session.flush()
token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id)
session.add(token)
await session.flush()
fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one()
assert fetched.token_hash == "sha256_fake_hash"
assert fetched.expires_at is None
async def test_create_connectivity_check(session):
check = ConnectivityCheck(url="https://example.com", response_code=200)
session.add(check)
await session.flush()
fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one()
assert fetched.response_code == 200
async def test_configuration_defaults(session):
config = Configuration()
session.add(config)
await session.flush()
fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one()
assert fetched.allow_unprivileged_device_management is True
assert fetched.local_auth_enabled is True
assert fetched.default_client_mtu == 1280
assert fetched.default_client_persistent_keepalive == 25
assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"]
assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"]
assert fetched.vpn_session_duration == 0
assert fetched.openid_connect_providers == []
assert fetched.saml_identity_providers == []

View file

@ -0,0 +1,89 @@
"""Tests for the notification service."""
from wiregui.services import notifications
def setup_function():
"""Clear notifications before each test."""
notifications.clear_all()
def test_add_notification():
n = notifications.add("info", "Test message")
assert n.severity == "info"
assert n.message == "Test message"
assert n.user is None
assert n.id is not None
assert n.timestamp is not None
def test_add_notification_with_user():
n = notifications.add("error", "Something broke", user="admin@example.com")
assert n.user == "admin@example.com"
assert n.severity == "error"
def test_current_returns_newest_first():
notifications.add("info", "First")
notifications.add("warning", "Second")
notifications.add("error", "Third")
current = notifications.current()
assert len(current) == 3
assert current[0].message == "Third"
assert current[1].message == "Second"
assert current[2].message == "First"
def test_count():
assert notifications.count() == 0
notifications.add("info", "One")
notifications.add("info", "Two")
assert notifications.count() == 2
def test_clear_specific():
n1 = notifications.add("info", "Keep this")
n2 = notifications.add("error", "Remove this")
notifications.clear(n2.id)
current = notifications.current()
assert len(current) == 1
assert current[0].id == n1.id
def test_clear_nonexistent_id_is_noop():
notifications.add("info", "Test")
notifications.clear("nonexistent-id")
assert notifications.count() == 1
def test_clear_all():
notifications.add("info", "One")
notifications.add("info", "Two")
notifications.add("info", "Three")
assert notifications.count() == 3
notifications.clear_all()
assert notifications.count() == 0
assert notifications.current() == []
def test_to_dict():
n = notifications.add("warning", "Test dict", user="someone@example.com")
d = n.to_dict()
assert d["severity"] == "warning"
assert d["message"] == "Test dict"
assert d["user"] == "someone@example.com"
assert "id" in d
assert "timestamp" in d
def test_max_notifications():
"""Deque should cap at MAX_NOTIFICATIONS."""
for i in range(notifications.MAX_NOTIFICATIONS + 10):
notifications.add("info", f"Notification {i}")
assert notifications.count() == notifications.MAX_NOTIFICATIONS
# Newest should be the last one added
assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}"

124
tests/test_services.py Normal file
View file

@ -0,0 +1,124 @@
"""Tests for services — WireGuard and events."""
from unittest.mock import AsyncMock, patch
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created
def _make_device(**kwargs) -> Device:
defaults = dict(
name="test",
public_key="pk-test",
preshared_key="psk-test",
ipv4="10.3.2.5",
ipv6="fd00::3:2:5",
user_id="00000000-0000-0000-0000-000000000000",
)
defaults.update(kwargs)
return Device(**defaults)
# --- Events (with WG enabled) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
mock_fw.add_device_jump_rule = AsyncMock()
device = _make_device()
await on_device_created(device)
mock_wg.add_peer.assert_awaited_once_with(
public_key="pk-test",
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"],
preshared_key="psk-test",
)
mock_fw.add_device_jump_rule.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.remove_peer = AsyncMock()
device = _make_device()
await on_device_deleted(device)
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test")
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
device = _make_device()
await on_device_updated(device)
mock_wg.add_peer.assert_awaited_once()
# --- Events (WG disabled) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_events_skip_when_wg_disabled(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = False
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
device = _make_device()
await on_device_created(device)
await on_device_deleted(device)
await on_device_updated(device)
mock_wg.add_peer.assert_not_awaited()
mock_wg.remove_peer.assert_not_awaited()
# --- Events (WG error handling) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock(side_effect=RuntimeError("wg failed"))
mock_fw.add_device_jump_rule = AsyncMock()
device = _make_device()
# Should not raise — error is logged
await on_device_created(device)
# --- Rule events ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
async def test_on_rule_created_calls_apply_rule(mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_fw.apply_rule = AsyncMock()
rule = Rule(
action="accept",
destination="10.0.0.0/8",
port_type="tcp",
port_range="80",
user_id="00000000-0000-0000-0000-000000000000",
)
await on_rule_created(rule)
mock_fw.apply_rule.assert_awaited_once_with(
"00000000-0000-0000-0000-000000000000", "10.0.0.0/8", "accept", "tcp", "80",
)

View file

@ -0,0 +1,203 @@
"""Extended service tests — wireguard subprocess mocking, firewall nft mocking, email."""
from unittest.mock import AsyncMock, MagicMock, patch
from wiregui.services.wireguard import PeerInfo, add_peer, get_peers, remove_peer
# ========== WireGuard service (mocked subprocess) ==========
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_peer_without_psk(mock_run):
mock_run.return_value = ""
await add_peer("pubkey123", ["10.0.0.1/32", "fd00::1/128"], iface="wg-test")
mock_run.assert_awaited_once()
args = mock_run.call_args[0][0]
assert "wg" in args
assert "set" in args
assert "pubkey123" in args
assert "10.0.0.1/32,fd00::1/128" in args
@patch("asyncio.create_subprocess_exec")
async def test_add_peer_with_psk(mock_exec):
"""PSK path uses subprocess directly with stdin."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"", b"")
mock_proc.returncode = 0
mock_exec.return_value = mock_proc
await add_peer("pubkey456", ["10.0.0.2/32"], preshared_key="psk-data", iface="wg-test")
mock_exec.assert_awaited_once()
call_args = mock_exec.call_args[0]
assert "preshared-key" in call_args
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_remove_peer(mock_run):
mock_run.return_value = ""
await remove_peer("pubkey789", iface="wg-test")
mock_run.assert_awaited_once()
args = mock_run.call_args[0][0]
assert "remove" in args
assert "pubkey789" in args
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_get_peers_parses_dump(mock_run):
dump_output = (
"privkey\tpubkey\t51820\toff\n"
"peerkey1\t(none)\t1.2.3.4:51820\t10.0.0.1/32\t1700000000\t12345\t67890\t25\n"
"peerkey2\t(none)\t(none)\t10.0.0.2/32,fd00::2/128\t0\t0\t0\t0\n"
)
mock_run.return_value = dump_output
peers = await get_peers(iface="wg-test")
assert len(peers) == 2
assert peers[0].public_key == "peerkey1"
assert peers[0].endpoint == "1.2.3.4:51820"
assert peers[0].rx_bytes == 12345
assert peers[0].tx_bytes == 67890
assert peers[0].latest_handshake is not None
assert peers[1].public_key == "peerkey2"
assert peers[1].endpoint is None
assert peers[1].rx_bytes == 0
assert peers[1].latest_handshake is None
assert len(peers[1].allowed_ips) == 2
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_get_peers_returns_empty_on_error(mock_run):
mock_run.side_effect = RuntimeError("interface not found")
peers = await get_peers(iface="wg-test")
assert peers == []
# ========== Firewall (mocked nft) ==========
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_base_tables(mock_batch):
from wiregui.services.firewall import setup_base_tables
await setup_base_tables()
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("add table" in c for c in cmds)
assert any("forward" in c for c in cmds)
assert any("postrouting" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_user_chain(mock_batch):
from wiregui.services.firewall import add_user_chain
await add_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("user_a1b2c3d40000" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_remove_user_chain(mock_batch):
from wiregui.services.firewall import remove_user_chain
await remove_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("flush" in c for c in cmds)
assert any("delete" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule(mock_batch):
from wiregui.services.firewall import add_device_jump_rule
await add_device_jump_rule("user-id-123", "10.0.0.5", "fd00::5")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("10.0.0.5" in c and "jump" in c for c in cmds)
assert any("fd00::5" in c and "jump" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_apply_rule(mock_batch):
from wiregui.services.firewall import apply_rule
await apply_rule("user-123", "10.0.0.0/8", "accept", "tcp", "80-443")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("10.0.0.0/8" in c and "accept" in c and "tcp dport 80-443" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_rebuild_all_rules(mock_batch):
from wiregui.services.firewall import rebuild_all_rules
await rebuild_all_rules([
{
"user_id": "user-1",
"devices": [{"ipv4": "10.0.0.1", "ipv6": "fd00::1"}],
"rules": [
{"destination": "0.0.0.0/0", "action": "accept", "port_type": None, "port_range": None},
{"destination": "192.168.0.0/16", "action": "drop", "port_type": "tcp", "port_range": "22"},
],
}
])
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("flush chain" in c and "forward" in c for c in cmds)
assert any("0.0.0.0/0" in c and "accept" in c for c in cmds)
assert any("192.168.0.0/16" in c and "drop" in c for c in cmds)
assert any("10.0.0.1" in c and "jump" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_masquerade(mock_batch):
from wiregui.services.firewall import setup_masquerade
await setup_masquerade(iface="wg0")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("masquerade" in c for c in cmds)
# ========== Email service (mocked smtp) ==========
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
async def test_send_email_success(mock_send, monkeypatch):
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
"smtp_host": "smtp.test.com",
"smtp_port": 587,
"smtp_user": "user",
"smtp_password": "pass",
"smtp_from": "test@test.com",
})())
from wiregui.services.email import send_email
result = await send_email("to@test.com", "Subject", "Body")
assert result is True
mock_send.assert_awaited_once()
async def test_send_email_no_smtp_configured(monkeypatch):
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
"smtp_host": None,
})())
from wiregui.services.email import send_email
result = await send_email("to@test.com", "Subject", "Body")
assert result is False
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
async def test_send_magic_link(mock_send, monkeypatch):
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
"smtp_host": "smtp.test.com",
"smtp_port": 587,
"smtp_user": "u",
"smtp_password": "p",
"smtp_from": "noreply@test.com",
})())
from wiregui.services.email import send_magic_link
result = await send_magic_link("user@test.com", "https://app.test/magic/123/token")
assert result is True
mock_send.assert_awaited_once()

231
tests/test_tasks.py Normal file
View file

@ -0,0 +1,231 @@
"""Tests for background tasks — VPN session expiry and connectivity checks."""
from datetime import timedelta
from unittest.mock import AsyncMock, patch
from sqlmodel import select
from wiregui.auth.passwords import hash_password
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- VPN session expiry ---
async def test_vpn_session_expiry_removes_expired_peers(session, monkeypatch):
"""Users whose session expired should have their WG peers removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
# Create config with 1-hour session duration
config = Configuration(vpn_session_duration=3600)
session.add(config)
await session.flush()
# Create a user who signed in 2 hours ago (expired)
expired_user = User(
email="expired@example.com",
password_hash=hash_password("pw"),
last_signed_in_at=utcnow() - timedelta(hours=2),
)
session.add(expired_user)
await session.flush()
device = Device(name="laptop", public_key="pk-expired", user_id=expired_user.id)
session.add(device)
await session.flush()
# Create a user who signed in 30 min ago (still valid)
active_user = User(
email="active@example.com",
password_hash=hash_password("pw"),
last_signed_in_at=utcnow() - timedelta(minutes=30),
)
session.add(active_user)
await session.flush()
active_device = Device(name="phone", public_key="pk-active", user_id=active_user.id)
session.add(active_device)
await session.flush()
# Mock WireGuard
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
# Only expired user's peer should be removed
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-expired")
async def test_vpn_session_no_expiry_when_duration_zero(session, monkeypatch):
"""When vpn_session_duration is 0 (unlimited), no peers should be removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
config = Configuration(vpn_session_duration=0)
session.add(config)
await session.flush()
user = User(
email="unlimited@example.com",
last_signed_in_at=utcnow() - timedelta(days=365),
)
session.add(user)
await session.flush()
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
mock_wg.remove_peer.assert_not_awaited()
async def test_vpn_session_no_expiry_when_no_config(session, monkeypatch):
"""When no Configuration exists, no peers should be removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
# No Configuration row at all
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
mock_wg.remove_peer.assert_not_awaited()
async def test_vpn_session_skips_disabled_users(session, monkeypatch):
"""Disabled users should be skipped even if their session is expired."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
config = Configuration(vpn_session_duration=3600)
session.add(config)
await session.flush()
user = User(
email="disabled-session@example.com",
last_signed_in_at=utcnow() - timedelta(hours=2),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
device = Device(name="d", public_key="pk-disabled-session", user_id=user.id)
session.add(device)
await session.flush()
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
mock_wg.remove_peer.assert_not_awaited()
# --- Connectivity checks ---
async def test_connectivity_check_success(session, monkeypatch):
"""Successful connectivity check should store result in DB."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
# Mock httpx to return a successful response
import httpx
class MockResponse:
status_code = 200
headers = {"content-type": "text/plain"}
text = "203.0.113.1"
class MockAsyncClient:
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def get(self, url):
return MockResponse()
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
from wiregui.tasks.connectivity import _check_connectivity
await _check_connectivity()
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
assert result.response_code == 200
assert result.response_body == "203.0.113.1"
async def test_connectivity_check_failure(session, monkeypatch):
"""Failed connectivity check should store error and create notification."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
class MockAsyncClient:
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def get(self, url):
raise ConnectionError("Network unreachable")
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
from wiregui.services import notifications
notifications.clear_all()
from wiregui.tasks.connectivity import _check_connectivity
await _check_connectivity()
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
assert result.response_code is None
assert "Network unreachable" in result.response_body
assert notifications.count() > 0
assert "connectivity" in notifications.current()[0].message.lower()

View file

@ -0,0 +1,229 @@
"""Extended task tests — stats polling, reconciliation, OIDC refresh."""
from datetime import timedelta
from unittest.mock import AsyncMock, patch
from sqlmodel import select
from wiregui.auth.passwords import hash_password
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.services.wireguard import PeerInfo
from wiregui.utils.time import utcnow
# ========== Stats task ==========
async def test_stats_update_from_wg_peers(session, monkeypatch):
"""Stats task should update device records from WireGuard peer data."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
user = User(email="stats-user@test.com")
session.add(user)
await session.flush()
device = Device(name="stats-dev", public_key="pk-stats-test", user_id=user.id)
session.add(device)
await session.flush()
mock_peers = [
PeerInfo(
public_key="pk-stats-test",
endpoint="1.2.3.4:51820",
rx_bytes=123456,
tx_bytes=789012,
latest_handshake=utcnow(),
)
]
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
from wiregui.tasks.stats import _update_stats
await _update_stats()
refreshed = await session.get(Device, device.id)
assert refreshed.rx_bytes == 123456
assert refreshed.tx_bytes == 789012
assert refreshed.remote_ip == "1.2.3.4"
assert refreshed.latest_handshake is not None
async def test_stats_no_peers_is_noop(session, monkeypatch):
"""No WG peers should result in no DB changes."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[])
from wiregui.tasks.stats import _update_stats
await _update_stats() # Should not raise
async def test_stats_unmatched_peer_ignored(session, monkeypatch):
"""Peers not matching any device should be ignored."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
mock_peers = [
PeerInfo(public_key="unknown-peer-key", rx_bytes=100, tx_bytes=200)
]
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
from wiregui.tasks.stats import _update_stats
await _update_stats() # Should not raise
# ========== Reconciliation task ==========
async def test_reconcile_adds_missing_peers(session, monkeypatch):
"""Devices in DB but not in WG should be added."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
user = User(email="reconcile@test.com")
session.add(user)
await session.flush()
device = Device(name="missing", public_key="pk-missing", ipv4="10.0.0.5", user_id=user.id)
session.add(device)
await session.flush()
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[]) # WG has no peers
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.reconcile import reconcile
await reconcile()
mock_wg.add_peer.assert_awaited_once()
call_kwargs = mock_wg.add_peer.call_args[1]
assert call_kwargs["public_key"] == "pk-missing"
assert "10.0.0.5/32" in call_kwargs["allowed_ips"]
mock_wg.remove_peer.assert_not_awaited()
async def test_reconcile_removes_orphaned_peers(session, monkeypatch):
"""Peers in WG but not in DB should be removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
# No devices in DB, but WG has a peer
orphan = PeerInfo(public_key="pk-orphan", rx_bytes=0, tx_bytes=0)
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[orphan])
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.reconcile import reconcile
await reconcile()
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-orphan")
mock_wg.add_peer.assert_not_awaited()
async def test_reconcile_in_sync(session, monkeypatch):
"""When DB and WG match, nothing should happen."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
user = User(email="in-sync@test.com")
session.add(user)
await session.flush()
device = Device(name="synced", public_key="pk-synced", user_id=user.id)
session.add(device)
await session.flush()
peer = PeerInfo(public_key="pk-synced", rx_bytes=0, tx_bytes=0)
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[peer])
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.reconcile import reconcile
await reconcile()
mock_wg.add_peer.assert_not_awaited()
mock_wg.remove_peer.assert_not_awaited()
# ========== OIDC refresh task ==========
async def test_oidc_refresh_no_connections_is_noop(session, monkeypatch):
"""No OIDC connections should result in no refresh attempts."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[]))
from wiregui.tasks.oidc_refresh import _refresh_all
await _refresh_all() # Should not raise
async def test_oidc_refresh_skips_unknown_provider(session, monkeypatch):
"""Connections for unknown providers should be skipped."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[
{"id": "known-provider", "client_id": "cid", "client_secret": "cs", "discovery_document_uri": "https://x"}
]))
user = User(email="oidc-skip@test.com")
session.add(user)
await session.flush()
conn = OIDCConnection(provider="unknown-provider", refresh_token="tok", user_id=user.id)
session.add(conn)
await session.flush()
from wiregui.tasks.oidc_refresh import _refresh_all
await _refresh_all() # Should skip gracefully

120
tests/test_utils.py Normal file
View file

@ -0,0 +1,120 @@
"""Tests for utility modules."""
import subprocess
import pytest
from sqlmodel import select
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
from wiregui.utils.wg_conf import build_client_config
# --- IP allocation ---
async def test_allocate_ipv4_first_device(session):
user = User(email="net-test@example.com")
session.add(user)
await session.flush()
ip = await allocate_ipv4(session, "10.3.2.0/24")
assert ip.startswith("10.3.2.")
# Should not be the network (.0) or gateway (.1)
last_octet = int(ip.split(".")[-1])
assert last_octet >= 2
async def test_allocate_ipv4_skips_used(session):
user = User(email="net-skip@example.com")
session.add(user)
await session.flush()
# Exhaust a tiny /30 network (4 addresses: .0 network, .1 gateway, .2 usable, .3 broadcast)
d1 = Device(name="d1", public_key="pk-net-1", ipv4="10.99.0.2", user_id=user.id)
session.add(d1)
await session.flush()
# Only .2 was usable in a /30 — allocation should fail
with pytest.raises(ValueError, match="No available"):
await allocate_ipv4(session, "10.99.0.0/30")
async def test_allocate_ipv6(session):
user = User(email="net6-test@example.com")
session.add(user)
await session.flush()
ip = await allocate_ipv6(session, "fd00::3:2:0/120")
assert ip.startswith("fd00::3:2:")
# --- WireGuard config builder ---
def test_build_client_config():
device = Device(
name="test-device",
public_key="device-pub-key",
preshared_key="device-psk",
ipv4="10.3.2.5",
ipv6="fd00::3:2:5",
use_default_allowed_ips=True,
use_default_dns=True,
use_default_endpoint=True,
use_default_mtu=True,
use_default_persistent_keepalive=True,
user_id="00000000-0000-0000-0000-000000000000",
)
config = build_client_config(device, "PRIVATE_KEY_HERE", "SERVER_PUB_KEY")
assert "[Interface]" in config
assert "PrivateKey = PRIVATE_KEY_HERE" in config
assert "10.3.2.5/32" in config
assert "fd00::3:2:5/128" in config
assert "[Peer]" in config
assert "PublicKey = SERVER_PUB_KEY" in config
assert "PresharedKey = device-psk" in config
assert "Endpoint = " in config
def test_build_client_config_no_psk():
device = Device(
name="no-psk",
public_key="pub",
preshared_key=None,
ipv4="10.3.2.6",
ipv6=None,
use_default_allowed_ips=True,
use_default_dns=True,
use_default_endpoint=True,
use_default_mtu=True,
use_default_persistent_keepalive=True,
user_id="00000000-0000-0000-0000-000000000000",
)
config = build_client_config(device, "PRIV", "SERVPUB")
assert "PresharedKey" not in config
assert "fd00::" not in config # no ipv6
# --- Crypto (only if wg is installed) ---
def test_generate_keypair():
"""Test keypair generation — requires `wg` CLI to be installed."""
try:
subprocess.run(["wg", "--version"], capture_output=True, check=True)
except FileNotFoundError:
pytest.skip("wg CLI not installed")
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
priv, pub = generate_keypair()
assert len(priv) == 44 # base64-encoded 32 bytes
assert len(pub) == 44
psk = generate_preshared_key()
assert len(psk) == 44