feat: initial WireGUI implementation — full VPN management platform
Complete Python/NiceGUI rewrite of the Wirezone (Elixir/Phoenix) VPN management platform. All 10 implementation phases delivered. Core stack: - NiceGUI reactive UI with SQLModel ORM on PostgreSQL (asyncpg) - Alembic migrations, Valkey/Redis cache, pydantic-settings config - WireGuard management via subprocess (wg/ip/nft CLIs) - 164 tests passing, 35% code coverage Features: - User/device/rule CRUD with admin and unprivileged roles - Full device config form with per-device WG overrides - WireGuard client config generation with QR codes - REST API (v0) with Bearer token auth for all resources - TOTP MFA with QR registration and challenge flow - OIDC SSO with authlib (provider registry, auto-create users) - Magic link passwordless sign-in via email - SAML SP-initiated SSO with IdP metadata parsing - WebAuthn/FIDO2 security key registration - nftables firewall with per-user chains and masquerade - Background tasks: WG stats polling, VPN session expiry, OIDC token refresh, WAN connectivity checks - Startup reconciliation (DB ↔ WireGuard state sync) - In-memory notification system with header badge - Admin UI: users, devices, rules, settings (3 tabs), diagnostics - Loguru logging with optional timestamped file output Deployment: - Multi-stage Dockerfile (python:3.13-slim) - Docker Compose prod stack (bridge networking, NET_ADMIN, nftables) - Forgejo CI: tests → semantic versioning → Docker registry push - Health endpoint at /api/health
This commit is contained in:
commit
0546b44507
109 changed files with 11793 additions and 0 deletions
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
65
tests/conftest.py
Normal file
65
tests/conftest.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""Shared test fixtures — async DB session using a test database."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
# All models must be imported so SQLModel.metadata knows about them
|
||||
from wiregui.models import * # noqa: F401, F403
|
||||
|
||||
|
||||
def _test_database_url() -> str:
|
||||
url = get_settings().database_url
|
||||
base, _dbname = url.rsplit("/", 1)
|
||||
return f"{base}/wiregui_test"
|
||||
|
||||
|
||||
TEST_DATABASE_URL = _test_database_url()
|
||||
|
||||
# Module-level engine creation (runs once via autouse session fixture)
|
||||
_engine = None
|
||||
|
||||
|
||||
def _ensure_test_db_sync():
|
||||
"""Ensure wiregui_test database exists (called once)."""
|
||||
import asyncio
|
||||
|
||||
async def _create():
|
||||
base_url = get_settings().database_url.rsplit("/", 1)[0] + "/postgres"
|
||||
admin_engine = create_async_engine(base_url, isolation_level="AUTOCOMMIT")
|
||||
async with admin_engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = 'wiregui_test'")
|
||||
)
|
||||
if result.scalar() is None:
|
||||
await conn.execute(text("CREATE DATABASE wiregui_test"))
|
||||
await admin_engine.dispose()
|
||||
|
||||
asyncio.run(_create())
|
||||
|
||||
|
||||
# Create test DB once at import time
|
||||
_ensure_test_db_sync()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def session() -> AsyncGenerator[AsyncSession]:
|
||||
"""Fresh engine + session per test, with table setup/teardown."""
|
||||
engine = create_async_engine(TEST_DATABASE_URL)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as sess:
|
||||
yield sess
|
||||
await sess.rollback()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
161
tests/test_account.py
Normal file
161
tests/test_account.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Tests for account functionality — password changes, API tokens, OIDC connections."""
|
||||
|
||||
import hashlib
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- Password change ---
|
||||
|
||||
|
||||
async def test_password_change_flow(session):
|
||||
"""Simulate the password change flow: verify old, set new."""
|
||||
user = User(email="pw-change@example.com", password_hash=hash_password("old-password"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Verify old password
|
||||
assert verify_password("old-password", user.password_hash) is True
|
||||
|
||||
# Change password
|
||||
user.password_hash = hash_password("new-password")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert verify_password("new-password", fetched.password_hash) is True
|
||||
assert verify_password("old-password", fetched.password_hash) is False
|
||||
|
||||
|
||||
async def test_password_change_wrong_current(session):
|
||||
"""Wrong current password should not allow change."""
|
||||
user = User(email="pw-wrong@example.com", password_hash=hash_password("correct"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Simulate check
|
||||
assert verify_password("wrong", user.password_hash) is False
|
||||
|
||||
|
||||
# --- API token management ---
|
||||
|
||||
|
||||
async def test_create_multiple_tokens(session):
|
||||
user = User(email="multi-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for _ in range(3):
|
||||
_, token_hash = generate_api_token()
|
||||
session.add(ApiToken(token_hash=token_hash, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
||||
|
||||
async def test_token_with_expiry(session):
|
||||
user = User(email="expiry-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
_, token_hash = generate_api_token()
|
||||
expires = utcnow() + timedelta(days=30)
|
||||
token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(ApiToken, token.id)
|
||||
assert fetched.expires_at is not None
|
||||
assert fetched.expires_at > utcnow()
|
||||
|
||||
|
||||
async def test_delete_token(session):
|
||||
user = User(email="del-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
_, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
await session.delete(token)
|
||||
await session.flush()
|
||||
|
||||
assert await session.get(ApiToken, token.id) is None
|
||||
|
||||
|
||||
# --- OIDC connections ---
|
||||
|
||||
|
||||
async def test_oidc_connection_create(session):
|
||||
user = User(email="oidc-conn@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="google",
|
||||
refresh_token="refresh-tok-123",
|
||||
refresh_response={"access_token": "at", "token_type": "Bearer"},
|
||||
refreshed_at=utcnow(),
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.provider == "google"
|
||||
assert fetched.refresh_token == "refresh-tok-123"
|
||||
assert fetched.refresh_response["access_token"] == "at"
|
||||
|
||||
|
||||
async def test_multiple_oidc_providers(session):
|
||||
user = User(email="multi-oidc@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for provider in ["google", "okta", "azure"]:
|
||||
conn = OIDCConnection(provider=provider, user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
||||
|
||||
async def test_oidc_connection_update_refresh_token(session):
|
||||
user = User(email="oidc-refresh@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="google",
|
||||
refresh_token="old-token",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
conn.refresh_token = "new-token"
|
||||
conn.refreshed_at = utcnow()
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(OIDCConnection, conn.id)
|
||||
assert fetched.refresh_token == "new-token"
|
||||
assert fetched.refreshed_at is not None
|
||||
283
tests/test_admin.py
Normal file
283
tests/test_admin.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
"""Tests for admin functionality — user management, configuration, cascading deletes."""
|
||||
|
||||
import pytest
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- User CRUD ---
|
||||
|
||||
|
||||
async def test_create_user_with_role(session):
|
||||
user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.role == "admin"
|
||||
assert verify_password("secret", fetched.password_hash)
|
||||
|
||||
|
||||
async def test_update_user_email(session):
|
||||
user = User(email="old@test.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
user.email = "new@test.com"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.email == "new@test.com"
|
||||
|
||||
|
||||
async def test_disable_user(session):
|
||||
user = User(email="active@test.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert user.disabled_at is None
|
||||
|
||||
user.disabled_at = utcnow()
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.disabled_at is not None
|
||||
|
||||
|
||||
async def test_promote_demote_user(session):
|
||||
user = User(email="user@test.com", role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert user.role == "unprivileged"
|
||||
|
||||
user.role = "admin"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.role == "admin"
|
||||
|
||||
user.role = "unprivileged"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert (await session.get(User, user.id)).role == "unprivileged"
|
||||
|
||||
|
||||
# --- Cascading delete (manual, as we do it in the admin page) ---
|
||||
|
||||
|
||||
async def test_delete_user_cascades_devices(session):
|
||||
user = User(email="cascade@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id)
|
||||
d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id)
|
||||
session.add_all([d1, d2])
|
||||
await session.flush()
|
||||
|
||||
# Manually delete devices then user (matching admin page behavior)
|
||||
devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all()
|
||||
for d in devices:
|
||||
await session.delete(d)
|
||||
await session.delete(user)
|
||||
await session.flush()
|
||||
|
||||
assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0
|
||||
assert await session.get(User, user.id) is None
|
||||
|
||||
|
||||
async def test_delete_user_cascades_rules(session):
|
||||
user = User(email="rule-cascade@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
# Delete rules then user
|
||||
rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all()
|
||||
for r in rules:
|
||||
await session.delete(r)
|
||||
await session.delete(user)
|
||||
await session.flush()
|
||||
|
||||
assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0
|
||||
|
||||
|
||||
# --- Configuration singleton ---
|
||||
|
||||
|
||||
async def test_configuration_create_and_update(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
assert config.default_client_mtu == 1280
|
||||
assert config.local_auth_enabled is True
|
||||
|
||||
config.default_client_mtu = 1400
|
||||
config.local_auth_enabled = False
|
||||
config.vpn_session_duration = 3600
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.default_client_mtu == 1400
|
||||
assert fetched.local_auth_enabled is False
|
||||
assert fetched.vpn_session_duration == 3600
|
||||
|
||||
|
||||
async def test_configuration_oidc_providers(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
assert config.openid_connect_providers == []
|
||||
|
||||
providers = [
|
||||
{
|
||||
"id": "google",
|
||||
"label": "Sign in with Google",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "google-client-id",
|
||||
"client_secret": "google-secret",
|
||||
"discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"auto_create_users": True,
|
||||
},
|
||||
{
|
||||
"id": "okta",
|
||||
"label": "Okta SSO",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "okta-client-id",
|
||||
"client_secret": "okta-secret",
|
||||
"discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration",
|
||||
"auto_create_users": False,
|
||||
},
|
||||
]
|
||||
config.openid_connect_providers = providers
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert len(fetched.openid_connect_providers) == 2
|
||||
assert fetched.openid_connect_providers[0]["id"] == "google"
|
||||
assert fetched.openid_connect_providers[1]["auto_create_users"] is False
|
||||
|
||||
|
||||
async def test_configuration_update_client_defaults(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
config.default_client_endpoint = "vpn.example.com"
|
||||
config.default_client_dns = ["8.8.8.8", "8.8.4.4"]
|
||||
config.default_client_allowed_ips = ["10.0.0.0/8"]
|
||||
config.default_client_persistent_keepalive = 30
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.default_client_endpoint == "vpn.example.com"
|
||||
assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"]
|
||||
assert fetched.default_client_allowed_ips == ["10.0.0.0/8"]
|
||||
assert fetched.default_client_persistent_keepalive == 30
|
||||
|
||||
|
||||
async def test_configuration_security_toggles(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
config.allow_unprivileged_device_management = False
|
||||
config.allow_unprivileged_device_configuration = False
|
||||
config.disable_vpn_on_oidc_error = True
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.allow_unprivileged_device_management is False
|
||||
assert fetched.allow_unprivileged_device_configuration is False
|
||||
assert fetched.disable_vpn_on_oidc_error is True
|
||||
|
||||
|
||||
# --- Device config overrides ---
|
||||
|
||||
|
||||
async def test_device_with_custom_config(session):
|
||||
user = User(email="config-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(
|
||||
name="custom-config",
|
||||
public_key="pk-custom-config",
|
||||
user_id=user.id,
|
||||
use_default_dns=False,
|
||||
use_default_endpoint=False,
|
||||
use_default_mtu=False,
|
||||
use_default_persistent_keepalive=False,
|
||||
use_default_allowed_ips=False,
|
||||
dns=["8.8.8.8"],
|
||||
endpoint="custom-vpn.example.com",
|
||||
mtu=1400,
|
||||
persistent_keepalive=15,
|
||||
allowed_ips=["10.0.0.0/8", "172.16.0.0/12"],
|
||||
)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Device, device.id)
|
||||
assert fetched.use_default_dns is False
|
||||
assert fetched.dns == ["8.8.8.8"]
|
||||
assert fetched.endpoint == "custom-vpn.example.com"
|
||||
assert fetched.mtu == 1400
|
||||
assert fetched.persistent_keepalive == 15
|
||||
assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"]
|
||||
|
||||
|
||||
async def test_device_default_flags_are_true(session):
|
||||
user = User(email="defaults@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="defaults", public_key="pk-defaults", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Device, device.id)
|
||||
assert fetched.use_default_allowed_ips is True
|
||||
assert fetched.use_default_dns is True
|
||||
assert fetched.use_default_endpoint is True
|
||||
assert fetched.use_default_mtu is True
|
||||
assert fetched.use_default_persistent_keepalive is True
|
||||
|
||||
|
||||
# --- User device count ---
|
||||
|
||||
|
||||
async def test_user_device_count_query(session):
|
||||
user = User(email="count-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for i in range(3):
|
||||
session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(Device).where(Device.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
86
tests/test_api.py
Normal file
86
tests/test_api.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""Tests for REST API endpoints and token auth."""
|
||||
|
||||
import hashlib
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token, resolve_bearer_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- Token generation ---
|
||||
|
||||
|
||||
def test_generate_api_token():
|
||||
plaintext, token_hash = generate_api_token()
|
||||
assert len(plaintext) > 20
|
||||
assert token_hash == hashlib.sha256(plaintext.encode()).hexdigest()
|
||||
|
||||
|
||||
def test_generate_api_token_unique():
|
||||
t1, h1 = generate_api_token()
|
||||
t2, h2 = generate_api_token()
|
||||
assert t1 != t2
|
||||
assert h1 != h2
|
||||
|
||||
|
||||
# --- Token resolution ---
|
||||
|
||||
|
||||
async def test_resolve_valid_token(session):
|
||||
user = User(email="api-user@example.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is not None
|
||||
assert resolved.id == user.id
|
||||
|
||||
|
||||
async def test_resolve_invalid_token(session):
|
||||
resolved = await resolve_bearer_token(session, "bogus-token")
|
||||
assert resolved is None
|
||||
|
||||
|
||||
async def test_resolve_expired_token(session):
|
||||
from datetime import timedelta
|
||||
|
||||
user = User(email="expired-api@example.com", password_hash=hash_password("x"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() - timedelta(hours=1),
|
||||
)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
|
||||
|
||||
async def test_resolve_token_disabled_user(session):
|
||||
user = User(
|
||||
email="disabled-api@example.com",
|
||||
password_hash=hash_password("x"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
325
tests/test_api_routes.py
Normal file
325
tests/test_api_routes.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
"""Tests for REST API routes via httpx AsyncClient against the FastAPI app."""
|
||||
|
||||
import hashlib
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_current_api_user, get_db, require_admin
|
||||
from wiregui.api.v0 import router as api_router
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
def _build_app(session, admin_user=None, regular_user=None):
|
||||
"""Build a test FastAPI app with overridden dependencies."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(api_router, prefix="/api")
|
||||
|
||||
async def override_get_db():
|
||||
yield session
|
||||
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
if admin_user:
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: admin_user
|
||||
test_app.dependency_overrides[require_admin] = lambda: admin_user
|
||||
|
||||
return test_app
|
||||
|
||||
|
||||
async def _make_admin(session) -> User:
|
||||
user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def _make_user(session, email="api-user@test.com") -> User:
|
||||
user = User(email=email, password_hash=hash_password("pw"), role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
||||
|
||||
# ========== Users API ==========
|
||||
|
||||
|
||||
async def test_list_users(session):
|
||||
admin = await _make_admin(session)
|
||||
await _make_user(session, "user1@test.com")
|
||||
await _make_user(session, "user2@test.com")
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/users/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 3 # admin + 2 users
|
||||
|
||||
|
||||
async def test_get_user(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/users/{admin.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "api-admin@test.com"
|
||||
|
||||
|
||||
async def test_get_user_not_found(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/users/{uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_create_user(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.post("/api/v0/users/", json={
|
||||
"email": "new-api-user@test.com",
|
||||
"password": "secret123",
|
||||
"role": "unprivileged",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["email"] == "new-api-user@test.com"
|
||||
assert data["role"] == "unprivileged"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
async def test_update_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/users/{user.id}", json={
|
||||
"role": "admin",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["role"] == "admin"
|
||||
|
||||
|
||||
async def test_update_user_password(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/users/{user.id}", json={
|
||||
"password": "new-password-123",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
from wiregui.auth.passwords import verify_password
|
||||
refreshed = await session.get(User, user.id)
|
||||
assert verify_password("new-password-123", refreshed.password_hash)
|
||||
|
||||
|
||||
async def test_delete_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/users/{user.id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(User, user.id) is None
|
||||
|
||||
|
||||
# ========== Devices API ==========
|
||||
|
||||
|
||||
async def test_list_devices_admin_sees_all(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id))
|
||||
session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/devices/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 2
|
||||
|
||||
|
||||
async def test_list_devices_user_sees_own(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session, "own-devices@test.com")
|
||||
session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id))
|
||||
session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id))
|
||||
await session.flush()
|
||||
|
||||
# Override to be the regular user
|
||||
test_app = _build_app(session)
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: user
|
||||
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/devices/")
|
||||
assert resp.status_code == 200
|
||||
names = [d["name"] for d in resp.json()]
|
||||
assert "mine" in names
|
||||
assert "not-mine" not in names
|
||||
|
||||
|
||||
async def test_get_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5")
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/devices/{device.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "detail"
|
||||
assert resp.json()["ipv4"] == "10.0.0.5"
|
||||
|
||||
|
||||
async def test_get_device_forbidden_for_other_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session, "other-dev@test.com")
|
||||
device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
test_app = _build_app(session)
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: user
|
||||
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/devices/{device.id}")
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
async def test_update_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "new-name"
|
||||
|
||||
|
||||
async def test_delete_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
did = device.id
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/devices/{did}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(Device, did) is None
|
||||
|
||||
|
||||
# ========== Rules API ==========
|
||||
|
||||
|
||||
async def test_list_rules(session):
|
||||
admin = await _make_admin(session)
|
||||
session.add(Rule(action="accept", destination="10.0.0.0/8"))
|
||||
session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id))
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/rules/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 2
|
||||
|
||||
|
||||
async def test_create_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.post("/api/v0/rules/", json={
|
||||
"action": "accept",
|
||||
"destination": "172.16.0.0/12",
|
||||
"port_type": "tcp",
|
||||
"port_range": "443",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["action"] == "accept"
|
||||
assert data["destination"] == "172.16.0.0/12"
|
||||
assert data["port_type"] == "tcp"
|
||||
assert data["port_range"] == "443"
|
||||
|
||||
|
||||
async def test_update_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8")
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["action"] == "drop"
|
||||
|
||||
|
||||
async def test_delete_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0")
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
rid = rule.id
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/rules/{rid}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(Rule, rid) is None
|
||||
|
||||
|
||||
# ========== Configuration API ==========
|
||||
|
||||
|
||||
async def test_get_configuration_auto_creates(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/configuration/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["default_client_mtu"] == 1280
|
||||
assert data["local_auth_enabled"] is True
|
||||
|
||||
|
||||
async def test_update_configuration(session):
|
||||
admin = await _make_admin(session)
|
||||
# Pre-create config
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put("/api/v0/configuration/", json={
|
||||
"default_client_mtu": 1400,
|
||||
"vpn_session_duration": 3600,
|
||||
"default_client_dns": ["8.8.8.8"],
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["default_client_mtu"] == 1400
|
||||
assert data["vpn_session_duration"] == 3600
|
||||
assert data["default_client_dns"] == ["8.8.8.8"]
|
||||
98
tests/test_auth.py
Normal file
98
tests/test_auth.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Tests for authentication modules."""
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.jwt import create_access_token, decode_access_token
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.auth.seed import seed_admin
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- Password hashing ---
|
||||
|
||||
|
||||
def test_hash_and_verify():
|
||||
hashed = hash_password("my-secret")
|
||||
assert verify_password("my-secret", hashed) is True
|
||||
|
||||
|
||||
def test_verify_wrong_password():
|
||||
hashed = hash_password("correct")
|
||||
assert verify_password("wrong", hashed) is False
|
||||
|
||||
|
||||
def test_hash_is_not_plaintext():
|
||||
hashed = hash_password("plaintext")
|
||||
assert hashed != "plaintext"
|
||||
assert hashed.startswith("$2b$")
|
||||
|
||||
|
||||
# --- JWT ---
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
token = create_access_token(user_id="user-123", role="admin")
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["role"] == "admin"
|
||||
assert "exp" in payload
|
||||
|
||||
|
||||
def test_decode_invalid_token():
|
||||
assert decode_access_token("garbage.token.value") is None
|
||||
|
||||
|
||||
def test_decode_tampered_token():
|
||||
token = create_access_token(user_id="user-123", role="admin")
|
||||
tampered = token[:-4] + "XXXX"
|
||||
assert decode_access_token(tampered) is None
|
||||
|
||||
|
||||
# --- Admin seed ---
|
||||
|
||||
|
||||
async def test_seed_admin_creates_user(session, monkeypatch):
|
||||
"""seed_admin should create an admin when no users exist."""
|
||||
# Patch async_session to use our test session
|
||||
from unittest.mock import AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
|
||||
monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {
|
||||
"admin_email": "seed-test@example.com",
|
||||
"admin_password": "seed-pass-123",
|
||||
})())
|
||||
|
||||
await seed_admin()
|
||||
|
||||
result = await session.execute(select(User).where(User.email == "seed-test@example.com"))
|
||||
admin = result.scalar_one()
|
||||
assert admin.role == "admin"
|
||||
assert verify_password("seed-pass-123", admin.password_hash)
|
||||
|
||||
|
||||
async def test_seed_admin_skips_when_users_exist(session, monkeypatch):
|
||||
"""seed_admin should not create a second admin if users already exist."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
existing = User(email="existing@example.com", role="unprivileged")
|
||||
session.add(existing)
|
||||
await session.flush()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
|
||||
|
||||
await seed_admin()
|
||||
|
||||
result = await session.execute(select(User))
|
||||
users = result.scalars().all()
|
||||
assert len(users) == 1
|
||||
assert users[0].email == "existing@example.com"
|
||||
226
tests/test_auth_extended.py
Normal file
226
tests/test_auth_extended.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# ========== Session / authenticate_user edge cases ==========
|
||||
|
||||
|
||||
async def test_authenticate_user_no_password_hash(session, monkeypatch):
|
||||
"""Users without a password (OIDC-only) should not authenticate via password."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="no-pw@test.com", password_hash=None)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await authenticate_user("no-pw@test.com", "anything")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_authenticate_user_disabled(session, monkeypatch):
|
||||
"""Disabled users should not authenticate."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow())
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await authenticate_user("disabled-auth@test.com", "pw")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_authenticate_user_nonexistent(session, monkeypatch):
|
||||
"""Nonexistent email should return None."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
result = await authenticate_user("ghost@nowhere.com", "pw")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ========== OIDC provider registration ==========
|
||||
|
||||
|
||||
async def test_register_providers_from_config(session, monkeypatch):
|
||||
"""register_providers should register configured OIDC providers with authlib."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
from wiregui.models.configuration import Configuration
|
||||
config = Configuration(openid_connect_providers=[
|
||||
{
|
||||
"id": "test-reg",
|
||||
"label": "Test",
|
||||
"scope": "openid email",
|
||||
"client_id": "cid",
|
||||
"client_secret": "cs",
|
||||
"discovery_document_uri": "https://idp.test/.well-known/openid-configuration",
|
||||
}
|
||||
])
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.auth.oidc.oauth") as mock_oauth:
|
||||
from wiregui.auth.oidc import register_providers
|
||||
await register_providers()
|
||||
mock_oauth.register.assert_called_once()
|
||||
call_kwargs = mock_oauth.register.call_args[1]
|
||||
assert call_kwargs["name"] == "test-reg"
|
||||
assert call_kwargs["client_id"] == "cid"
|
||||
|
||||
|
||||
async def test_get_client_unknown_provider():
|
||||
"""get_client should raise for unregistered providers."""
|
||||
import pytest
|
||||
from wiregui.auth.oidc import get_client
|
||||
with pytest.raises(ValueError, match="not registered"):
|
||||
get_client("nonexistent-provider-xyz")
|
||||
|
||||
|
||||
# ========== WebAuthn options ==========
|
||||
|
||||
|
||||
def test_webauthn_registration_options(monkeypatch):
|
||||
"""create_registration_options should return valid options and challenge."""
|
||||
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
|
||||
"external_url": "https://vpn.example.com",
|
||||
})())
|
||||
|
||||
from wiregui.auth.webauthn import create_registration_options
|
||||
user_id = uuid4()
|
||||
result = create_registration_options(user_id, "user@example.com")
|
||||
|
||||
assert "options_json" in result
|
||||
assert "challenge" in result
|
||||
assert len(result["challenge"]) > 10
|
||||
assert "user@example.com" in result["options_json"]
|
||||
|
||||
|
||||
def test_webauthn_registration_options_with_excludes(monkeypatch):
|
||||
"""Existing credentials should be excluded from registration options."""
|
||||
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
|
||||
"external_url": "https://vpn.example.com",
|
||||
})())
|
||||
|
||||
from wiregui.auth.webauthn import create_registration_options
|
||||
existing = [{"credential_id": "AQIDBA"}] # base64url of bytes [1,2,3,4]
|
||||
result = create_registration_options(uuid4(), "user@example.com", existing)
|
||||
assert "options_json" in result
|
||||
|
||||
|
||||
def test_webauthn_authentication_options(monkeypatch):
|
||||
"""create_authentication_options should accept credential descriptors."""
|
||||
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
|
||||
"external_url": "https://vpn.example.com",
|
||||
})())
|
||||
|
||||
from wiregui.auth.webauthn import create_authentication_options
|
||||
credentials = [{"credential_id": "AQIDBA"}]
|
||||
result = create_authentication_options(credentials)
|
||||
assert "options_json" in result
|
||||
assert "challenge" in result
|
||||
|
||||
|
||||
# ========== Events — rule update/delete with rebuild ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings):
|
||||
"""on_rule_updated should rebuild the user's firewall chain."""
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_fw.rebuild_all_rules = AsyncMock()
|
||||
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_updated
|
||||
|
||||
# Need to mock the DB call inside _rebuild_user_chain
|
||||
with patch("wiregui.services.events.async_session") as mock_session_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the select results
|
||||
mock_rules_result = MagicMock()
|
||||
mock_rules_result.scalars.return_value.all.return_value = []
|
||||
mock_devices_result = MagicMock()
|
||||
mock_devices_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
|
||||
|
||||
mock_session_factory.return_value = mock_session
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id="a1b2c3d4-0000-0000-0000-000000000000")
|
||||
await on_rule_updated(rule)
|
||||
|
||||
mock_fw.rebuild_all_rules.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
async def test_on_rule_deleted_triggers_rebuild(mock_fw, mock_settings):
|
||||
"""on_rule_deleted should rebuild the user's firewall chain."""
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_fw.rebuild_all_rules = AsyncMock()
|
||||
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_deleted
|
||||
|
||||
with patch("wiregui.services.events.async_session") as mock_session_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_rules_result = MagicMock()
|
||||
mock_rules_result.scalars.return_value.all.return_value = []
|
||||
mock_devices_result = MagicMock()
|
||||
mock_devices_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
|
||||
|
||||
mock_session_factory.return_value = mock_session
|
||||
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
|
||||
await on_rule_deleted(rule)
|
||||
|
||||
mock_fw.rebuild_all_rules.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
async def test_on_rule_deleted_skips_when_disabled(mock_settings):
|
||||
"""Rule events should be no-ops when WG is disabled."""
|
||||
mock_settings.return_value.wg_enabled = False
|
||||
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_deleted, on_rule_updated
|
||||
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
|
||||
await on_rule_updated(rule) # Should not raise
|
||||
await on_rule_deleted(rule) # Should not raise
|
||||
40
tests/test_firewall.py
Normal file
40
tests/test_firewall.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Tests for firewall service — rule expression building and chain naming."""
|
||||
|
||||
from wiregui.services.firewall import _build_rule_expr, _user_chain_name
|
||||
|
||||
|
||||
def test_user_chain_name():
|
||||
uid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
name = _user_chain_name(uid)
|
||||
assert name == "user_a1b2c3d4e5f6"
|
||||
assert len(name) <= 30
|
||||
|
||||
|
||||
def test_user_chain_name_deterministic():
|
||||
uid = "12345678-1234-1234-1234-123456789abc"
|
||||
assert _user_chain_name(uid) == _user_chain_name(uid)
|
||||
|
||||
|
||||
def test_build_rule_expr_ipv4_accept():
|
||||
expr = _build_rule_expr("10.0.0.0/8", "accept")
|
||||
assert expr == "ip daddr 10.0.0.0/8 accept"
|
||||
|
||||
|
||||
def test_build_rule_expr_ipv6_drop():
|
||||
expr = _build_rule_expr("fd00::/64", "drop")
|
||||
assert expr == "ip6 daddr fd00::/64 drop"
|
||||
|
||||
|
||||
def test_build_rule_expr_with_port():
|
||||
expr = _build_rule_expr("192.168.0.0/16", "accept", port_type="tcp", port_range="80-443")
|
||||
assert expr == "ip daddr 192.168.0.0/16 tcp dport 80-443 accept"
|
||||
|
||||
|
||||
def test_build_rule_expr_single_port():
|
||||
expr = _build_rule_expr("10.0.0.1/32", "drop", port_type="udp", port_range="53")
|
||||
assert expr == "ip daddr 10.0.0.1/32 udp dport 53 drop"
|
||||
|
||||
|
||||
def test_build_rule_expr_no_port():
|
||||
expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None)
|
||||
assert expr == "ip daddr 0.0.0.0/0 accept"
|
||||
239
tests/test_integration_mfa.py
Normal file
239
tests/test_integration_mfa.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Integration tests for MFA — full registration and authentication flows through the database."""
|
||||
|
||||
import pyotp
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.mfa import generate_totp_secret, verify_totp_code
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
async def test_full_totp_registration_flow(session, monkeypatch):
|
||||
"""End-to-end: create user → generate secret → verify code → store method → re-verify from DB."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
# Create user with password
|
||||
user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Step 1: Generate TOTP secret (happens in account page)
|
||||
secret = generate_totp_secret()
|
||||
|
||||
# Step 2: User scans QR, enters code from their authenticator
|
||||
totp = pyotp.TOTP(secret)
|
||||
code = totp.now()
|
||||
|
||||
# Step 3: Verify the code is correct before saving
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
# Step 4: Save the MFA method to DB
|
||||
method = MFAMethod(
|
||||
name="My Authenticator",
|
||||
type="totp",
|
||||
payload={"secret": secret},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Step 5: Simulate future login — load method from DB and verify a fresh code
|
||||
fetched_methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalars().all()
|
||||
|
||||
assert len(fetched_methods) == 1
|
||||
stored_secret = fetched_methods[0].payload["secret"]
|
||||
fresh_code = pyotp.TOTP(stored_secret).now()
|
||||
assert verify_totp_code(stored_secret, fresh_code) is True
|
||||
|
||||
|
||||
async def test_mfa_blocks_login_without_code(session, monkeypatch):
|
||||
"""User with MFA should not be fully authenticated without completing MFA challenge."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
# Create user with MFA
|
||||
user = User(email="mfa-block@example.com", password_hash=hash_password("password1"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Password auth succeeds
|
||||
authed_user = await authenticate_user("mfa-block@example.com", "password1")
|
||||
assert authed_user is not None
|
||||
|
||||
# But MFA methods exist — login page would redirect to /mfa instead of completing login
|
||||
mfa_methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == authed_user.id)
|
||||
)).scalars().all()
|
||||
assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa
|
||||
|
||||
|
||||
async def test_mfa_wrong_code_rejected(session):
|
||||
"""Wrong TOTP code should be rejected even if method is valid."""
|
||||
user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Load from DB and try wrong code
|
||||
fetched = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar_one()
|
||||
|
||||
assert verify_totp_code(fetched.payload["secret"], "000000") is False
|
||||
assert verify_totp_code(fetched.payload["secret"], "123456") is False
|
||||
|
||||
|
||||
async def test_mfa_multiple_methods_any_valid_code_works(session):
|
||||
"""If user has multiple TOTP methods, a valid code from any should work."""
|
||||
user = User(email="mfa-multi@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret1 = generate_totp_secret()
|
||||
secret2 = generate_totp_secret()
|
||||
|
||||
session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id))
|
||||
session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalars().all()
|
||||
|
||||
# Code from method 1 should verify against method 1's secret
|
||||
code1 = pyotp.TOTP(secret1).now()
|
||||
verified = False
|
||||
for m in methods:
|
||||
if verify_totp_code(m.payload["secret"], code1):
|
||||
verified = True
|
||||
break
|
||||
assert verified is True
|
||||
|
||||
# Code from method 2 should also work
|
||||
code2 = pyotp.TOTP(secret2).now()
|
||||
verified2 = False
|
||||
for m in methods:
|
||||
if verify_totp_code(m.payload["secret"], code2):
|
||||
verified2 = True
|
||||
break
|
||||
assert verified2 is True
|
||||
|
||||
|
||||
async def test_mfa_method_last_used_tracking(session):
|
||||
"""Verifying MFA should update last_used_at timestamp."""
|
||||
user = User(email="mfa-tracking@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
assert method.last_used_at is None
|
||||
|
||||
# Simulate successful verification and update
|
||||
code = pyotp.TOTP(secret).now()
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
method.last_used_at = utcnow()
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(MFAMethod, method.id)
|
||||
assert fetched.last_used_at is not None
|
||||
|
||||
|
||||
async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch):
|
||||
"""After removing all MFA methods, user should not be redirected to MFA challenge."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="mfa-remove@example.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# MFA exists
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 1
|
||||
|
||||
# Delete it
|
||||
await session.delete(method)
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 0
|
||||
|
||||
# Password auth still works
|
||||
authed = await authenticate_user("mfa-remove@example.com", "pw")
|
||||
assert authed is not None
|
||||
|
||||
# No MFA methods — login flow would skip MFA challenge
|
||||
mfa_check = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == authed.id)
|
||||
)).scalars().all()
|
||||
assert len(mfa_check) == 0
|
||||
|
||||
|
||||
async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch):
|
||||
"""Disabled user should be rejected at password stage, never reaching MFA."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(
|
||||
email="mfa-disabled@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
# Password auth rejects disabled user before MFA is ever checked
|
||||
result = await authenticate_user("mfa-disabled@example.com", "pw")
|
||||
assert result is None
|
||||
309
tests/test_integration_oidc.py
Normal file
309
tests/test_integration_oidc.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import respx
|
||||
from httpx import Response
|
||||
from jose import jwt
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- Helper to create a fake OIDC provider config in the DB ---
|
||||
|
||||
|
||||
async def _setup_oidc_config(session) -> Configuration:
|
||||
"""Insert a Configuration with a test OIDC provider."""
|
||||
config = Configuration(
|
||||
openid_connect_providers=[
|
||||
{
|
||||
"id": "test-idp",
|
||||
"label": "Test IdP",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
|
||||
"auto_create_users": True,
|
||||
}
|
||||
],
|
||||
)
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
return config
|
||||
|
||||
|
||||
def _mock_discovery():
|
||||
"""Mock OIDC discovery document response."""
|
||||
return {
|
||||
"issuer": "https://idp.example.com",
|
||||
"authorization_endpoint": "https://idp.example.com/authorize",
|
||||
"token_endpoint": "https://idp.example.com/token",
|
||||
"userinfo_endpoint": "https://idp.example.com/userinfo",
|
||||
"jwks_uri": "https://idp.example.com/.well-known/jwks.json",
|
||||
}
|
||||
|
||||
|
||||
def _mock_token_response(email: str = "oidc-user@example.com"):
|
||||
"""Mock OIDC token endpoint response with ID token."""
|
||||
now = int(time.time())
|
||||
id_token_payload = {
|
||||
"iss": "https://idp.example.com",
|
||||
"sub": "oidc-subject-123",
|
||||
"aud": "test-client-id",
|
||||
"email": email,
|
||||
"name": "OIDC User",
|
||||
"iat": now,
|
||||
"exp": now + 3600,
|
||||
"nonce": "test-nonce",
|
||||
}
|
||||
# Sign with a simple secret (in real life this would be RSA)
|
||||
id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
|
||||
|
||||
return {
|
||||
"access_token": "mock-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
|
||||
# --- Provider config loading ---
|
||||
|
||||
|
||||
async def test_load_providers_from_config(session, monkeypatch):
|
||||
"""Providers should be loaded from the Configuration table."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
await _setup_oidc_config(session)
|
||||
|
||||
providers = await load_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0]["id"] == "test-idp"
|
||||
assert providers[0]["client_id"] == "test-client-id"
|
||||
|
||||
|
||||
async def test_load_providers_empty_when_no_config(session, monkeypatch):
|
||||
"""Should return empty list when no Configuration exists."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
providers = await load_providers()
|
||||
assert providers == []
|
||||
|
||||
|
||||
async def test_get_provider_config_by_id(session, monkeypatch):
|
||||
"""Should find a specific provider by ID."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
await _setup_oidc_config(session)
|
||||
|
||||
config = await get_provider_config("test-idp")
|
||||
assert config is not None
|
||||
assert config["label"] == "Test IdP"
|
||||
|
||||
config_missing = await get_provider_config("nonexistent")
|
||||
assert config_missing is None
|
||||
|
||||
|
||||
# --- OIDC connection storage ---
|
||||
|
||||
|
||||
async def test_oidc_connection_created_on_login(session):
|
||||
"""Simulates what the callback route does: create user + OIDC connection."""
|
||||
user = User(email="oidc-new@example.com", role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
token_data = _mock_token_response("oidc-new@example.com")
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token=token_data["refresh_token"],
|
||||
refresh_response=token_data,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Verify it was stored
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.provider == "test-idp"
|
||||
assert fetched.refresh_token == "mock-refresh-token"
|
||||
assert fetched.refresh_response["access_token"] == "mock-access-token"
|
||||
|
||||
|
||||
async def test_oidc_connection_updated_on_re_login(session):
|
||||
"""Re-login should update the existing OIDC connection, not create a duplicate."""
|
||||
user = User(email="oidc-relogin@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# First login
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token="old-refresh-token",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Re-login — update existing connection (as the callback route does)
|
||||
existing = (await session.execute(
|
||||
select(OIDCConnection).where(
|
||||
OIDCConnection.user_id == user.id,
|
||||
OIDCConnection.provider == "test-idp",
|
||||
)
|
||||
)).scalar_one()
|
||||
|
||||
existing.refresh_token = "new-refresh-token"
|
||||
from wiregui.utils.time import utcnow
|
||||
existing.refreshed_at = utcnow()
|
||||
session.add(existing)
|
||||
await session.flush()
|
||||
|
||||
# Should still be one connection
|
||||
from sqlmodel import func
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 1
|
||||
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.refresh_token == "new-refresh-token"
|
||||
|
||||
|
||||
async def test_oidc_auto_create_user(session):
|
||||
"""When auto_create_users is True, a new user should be created from OIDC email."""
|
||||
email = "auto-created@example.com"
|
||||
|
||||
# Verify user doesn't exist
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
assert existing is None
|
||||
|
||||
# Simulate what callback does with auto_create
|
||||
user = User(email=email, role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
user.last_signed_in_at = utcnow()
|
||||
user.last_signed_in_method = "oidc:test-idp"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
created = (await session.execute(select(User).where(User.email == email))).scalar_one()
|
||||
assert created.role == "unprivileged"
|
||||
assert created.last_signed_in_method == "oidc:test-idp"
|
||||
|
||||
|
||||
async def test_oidc_disabled_user_rejected(session):
|
||||
"""Disabled users should not be logged in via OIDC."""
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# The callback route checks disabled_at before creating session
|
||||
assert user.disabled_at is not None # Would redirect to /login
|
||||
|
||||
|
||||
async def test_oidc_user_without_auto_create_rejected(session):
|
||||
"""When auto_create is False and user doesn't exist, login should fail."""
|
||||
email = "no-auto-create@example.com"
|
||||
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
assert existing is None
|
||||
|
||||
# The callback route checks auto_create_users from provider config
|
||||
# With auto_create=False and no existing user, it would redirect to /login
|
||||
# This verifies the precondition
|
||||
|
||||
|
||||
# --- OIDC refresh token flow ---
|
||||
|
||||
|
||||
async def test_oidc_refresh_stores_new_token(session):
|
||||
"""Simulates a successful token refresh updating the connection."""
|
||||
user = User(email="oidc-refresh-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token="old-refresh",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Simulate refresh result
|
||||
new_token = {
|
||||
"access_token": "new-access",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
|
||||
conn.refresh_response = new_token
|
||||
from wiregui.utils.time import utcnow
|
||||
conn.refreshed_at = utcnow()
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(OIDCConnection, conn.id)
|
||||
assert fetched.refresh_token == "new-refresh"
|
||||
assert fetched.refresh_response["access_token"] == "new-access"
|
||||
assert fetched.refreshed_at is not None
|
||||
|
||||
|
||||
async def test_oidc_multiple_providers_per_user(session):
|
||||
"""User can have connections to multiple OIDC providers."""
|
||||
user = User(email="multi-provider@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for provider in ["google", "okta", "azure-ad"]:
|
||||
session.add(OIDCConnection(
|
||||
provider=provider,
|
||||
refresh_token=f"token-{provider}",
|
||||
user_id=user.id,
|
||||
))
|
||||
await session.flush()
|
||||
|
||||
conns = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider)
|
||||
)).scalars().all()
|
||||
|
||||
assert len(conns) == 3
|
||||
assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]
|
||||
58
tests/test_magic_link.py
Normal file
58
tests/test_magic_link.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Tests for magic link authentication flow."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from wiregui.auth.jwt import create_access_token, decode_access_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
def test_magic_link_token_creation():
|
||||
"""Magic link token should be a valid JWT with short expiry."""
|
||||
token = create_access_token(
|
||||
user_id="user-123",
|
||||
role="unprivileged",
|
||||
expires_delta=timedelta(minutes=15),
|
||||
)
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["role"] == "unprivileged"
|
||||
|
||||
|
||||
def test_magic_link_token_expired():
|
||||
"""Expired magic link token should be rejected."""
|
||||
token = create_access_token(
|
||||
user_id="user-123",
|
||||
role="admin",
|
||||
expires_delta=timedelta(minutes=-1), # Already expired
|
||||
)
|
||||
payload = decode_access_token(token)
|
||||
assert payload is None
|
||||
|
||||
|
||||
def test_magic_link_token_wrong_user():
|
||||
"""Token should only be valid for the intended user."""
|
||||
token = create_access_token(user_id="user-A", role="admin")
|
||||
payload = decode_access_token(token)
|
||||
assert payload["sub"] == "user-A"
|
||||
# Caller is responsible for checking sub matches the URL user_id
|
||||
|
||||
|
||||
async def test_magic_link_disabled_user_rejected(session):
|
||||
"""Disabled users should not be able to use magic links."""
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
user = User(
|
||||
email="disabled-magic@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# The token would be valid but the page handler checks disabled_at
|
||||
token = create_access_token(user_id=str(user.id), role="unprivileged")
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None # Token itself is valid
|
||||
assert user.disabled_at is not None # But user is disabled — handler would reject
|
||||
127
tests/test_mfa.py
Normal file
127
tests/test_mfa.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Tests for TOTP MFA functionality."""
|
||||
|
||||
import pyotp
|
||||
|
||||
from wiregui.auth.mfa import (
|
||||
generate_totp_qr_svg,
|
||||
generate_totp_secret,
|
||||
get_totp_uri,
|
||||
verify_totp_code,
|
||||
)
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- TOTP secret generation ---
|
||||
|
||||
|
||||
def test_generate_secret():
|
||||
secret = generate_totp_secret()
|
||||
assert len(secret) == 32 # base32 encoded
|
||||
assert secret.isalpha() or any(c.isdigit() for c in secret)
|
||||
|
||||
|
||||
def test_generate_secret_unique():
|
||||
s1 = generate_totp_secret()
|
||||
s2 = generate_totp_secret()
|
||||
assert s1 != s2
|
||||
|
||||
|
||||
# --- TOTP URI ---
|
||||
|
||||
|
||||
def test_get_totp_uri():
|
||||
uri = get_totp_uri("JBSWY3DPEHPK3PXP", "user@example.com")
|
||||
assert uri.startswith("otpauth://totp/")
|
||||
assert "user%40example.com" in uri or "user@example.com" in uri
|
||||
assert "secret=JBSWY3DPEHPK3PXP" in uri
|
||||
assert "issuer=WireGUI" in uri
|
||||
|
||||
|
||||
def test_get_totp_uri_custom_issuer():
|
||||
uri = get_totp_uri("SECRET", "test@test.com", issuer="MyVPN")
|
||||
assert "issuer=MyVPN" in uri
|
||||
|
||||
|
||||
# --- TOTP verification ---
|
||||
|
||||
|
||||
def test_verify_valid_code():
|
||||
secret = generate_totp_secret()
|
||||
totp = pyotp.TOTP(secret)
|
||||
code = totp.now()
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
|
||||
def test_verify_invalid_code():
|
||||
secret = generate_totp_secret()
|
||||
assert verify_totp_code(secret, "000000") is False
|
||||
|
||||
|
||||
def test_verify_wrong_secret():
|
||||
secret1 = generate_totp_secret()
|
||||
secret2 = generate_totp_secret()
|
||||
code = pyotp.TOTP(secret1).now()
|
||||
assert verify_totp_code(secret2, code) is False
|
||||
|
||||
|
||||
def test_verify_empty_code():
|
||||
secret = generate_totp_secret()
|
||||
assert verify_totp_code(secret, "") is False
|
||||
|
||||
|
||||
# --- QR code generation ---
|
||||
|
||||
|
||||
def test_generate_qr_svg():
|
||||
uri = get_totp_uri("SECRET", "test@test.com")
|
||||
svg = generate_totp_qr_svg(uri)
|
||||
assert "<svg" in svg
|
||||
assert "</svg>" in svg
|
||||
|
||||
|
||||
# --- MFA method model integration ---
|
||||
|
||||
|
||||
async def test_create_totp_method(session):
|
||||
user = User(email="mfa-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(
|
||||
name="My Phone",
|
||||
type="totp",
|
||||
payload={"secret": secret},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
from sqlmodel import select
|
||||
fetched = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar_one()
|
||||
|
||||
assert fetched.name == "My Phone"
|
||||
assert fetched.type == "totp"
|
||||
stored_secret = fetched.payload["secret"]
|
||||
code = pyotp.TOTP(stored_secret).now()
|
||||
assert verify_totp_code(stored_secret, code) is True
|
||||
|
||||
|
||||
async def test_user_multiple_mfa_methods(session):
|
||||
user = User(email="multi-mfa@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
m1 = MFAMethod(name="Phone", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
|
||||
m2 = MFAMethod(name="Backup", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
|
||||
session.add_all([m1, m2])
|
||||
await session.flush()
|
||||
|
||||
from sqlmodel import select, func
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 2
|
||||
168
tests/test_models.py
Normal file
168
tests/test_models.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Tests for SQLModel table definitions."""
|
||||
|
||||
import pytest # noqa: F401 — needed for pytest.raises
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def test_create_user(session):
|
||||
user = User(email="alice@example.com", role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(User).where(User.email == "alice@example.com"))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.id == user.id
|
||||
assert fetched.role == "admin"
|
||||
assert fetched.disabled_at is None
|
||||
|
||||
|
||||
async def test_create_device_with_user(session):
|
||||
user = User(email="bob@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(
|
||||
name="laptop",
|
||||
public_key="pk-test-device-001",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001"))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.name == "laptop"
|
||||
assert fetched.user_id == user.id
|
||||
assert fetched.use_default_dns is True
|
||||
assert fetched.use_default_allowed_ips is True
|
||||
assert fetched.rx_bytes is None
|
||||
|
||||
|
||||
async def test_device_unique_public_key(session):
|
||||
user = User(email="carol@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
|
||||
session.add(d1)
|
||||
await session.flush()
|
||||
|
||||
d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
|
||||
session.add(d2)
|
||||
with pytest.raises(Exception): # IntegrityError
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def test_create_rule(session):
|
||||
user = User(email="dave@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(Rule).where(Rule.user_id == user.id))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.action == "accept"
|
||||
assert fetched.destination == "10.0.0.0/8"
|
||||
assert fetched.port_type is None
|
||||
assert fetched.port_range is None
|
||||
|
||||
|
||||
async def test_create_rule_with_port(session):
|
||||
rule = Rule(
|
||||
action="drop",
|
||||
destination="192.168.0.0/16",
|
||||
port_type="tcp",
|
||||
port_range="80-443",
|
||||
)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one()
|
||||
assert fetched.port_type == "tcp"
|
||||
assert fetched.port_range == "80-443"
|
||||
assert fetched.user_id is None # global rule
|
||||
|
||||
|
||||
async def test_create_mfa_method(session):
|
||||
user = User(email="eve@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
mfa = MFAMethod(
|
||||
name="My Authenticator",
|
||||
type="totp",
|
||||
payload={"secret": "JBSWY3DPEHPK3PXP"},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(mfa)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one()
|
||||
assert fetched.type == "totp"
|
||||
assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP"
|
||||
|
||||
|
||||
async def test_create_oidc_connection(session):
|
||||
user = User(email="frank@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one()
|
||||
assert fetched.provider == "google"
|
||||
assert fetched.refresh_token == "tok_abc"
|
||||
|
||||
|
||||
async def test_create_api_token(session):
|
||||
user = User(email="grace@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one()
|
||||
assert fetched.token_hash == "sha256_fake_hash"
|
||||
assert fetched.expires_at is None
|
||||
|
||||
|
||||
async def test_create_connectivity_check(session):
|
||||
check = ConnectivityCheck(url="https://example.com", response_code=200)
|
||||
session.add(check)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one()
|
||||
assert fetched.response_code == 200
|
||||
|
||||
|
||||
async def test_configuration_defaults(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one()
|
||||
assert fetched.allow_unprivileged_device_management is True
|
||||
assert fetched.local_auth_enabled is True
|
||||
assert fetched.default_client_mtu == 1280
|
||||
assert fetched.default_client_persistent_keepalive == 25
|
||||
assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"]
|
||||
assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"]
|
||||
assert fetched.vpn_session_duration == 0
|
||||
assert fetched.openid_connect_providers == []
|
||||
assert fetched.saml_identity_providers == []
|
||||
89
tests/test_notifications.py
Normal file
89
tests/test_notifications.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Tests for the notification service."""
|
||||
|
||||
from wiregui.services import notifications
|
||||
|
||||
|
||||
def setup_function():
|
||||
"""Clear notifications before each test."""
|
||||
notifications.clear_all()
|
||||
|
||||
|
||||
def test_add_notification():
|
||||
n = notifications.add("info", "Test message")
|
||||
assert n.severity == "info"
|
||||
assert n.message == "Test message"
|
||||
assert n.user is None
|
||||
assert n.id is not None
|
||||
assert n.timestamp is not None
|
||||
|
||||
|
||||
def test_add_notification_with_user():
|
||||
n = notifications.add("error", "Something broke", user="admin@example.com")
|
||||
assert n.user == "admin@example.com"
|
||||
assert n.severity == "error"
|
||||
|
||||
|
||||
def test_current_returns_newest_first():
|
||||
notifications.add("info", "First")
|
||||
notifications.add("warning", "Second")
|
||||
notifications.add("error", "Third")
|
||||
|
||||
current = notifications.current()
|
||||
assert len(current) == 3
|
||||
assert current[0].message == "Third"
|
||||
assert current[1].message == "Second"
|
||||
assert current[2].message == "First"
|
||||
|
||||
|
||||
def test_count():
|
||||
assert notifications.count() == 0
|
||||
notifications.add("info", "One")
|
||||
notifications.add("info", "Two")
|
||||
assert notifications.count() == 2
|
||||
|
||||
|
||||
def test_clear_specific():
|
||||
n1 = notifications.add("info", "Keep this")
|
||||
n2 = notifications.add("error", "Remove this")
|
||||
|
||||
notifications.clear(n2.id)
|
||||
current = notifications.current()
|
||||
assert len(current) == 1
|
||||
assert current[0].id == n1.id
|
||||
|
||||
|
||||
def test_clear_nonexistent_id_is_noop():
|
||||
notifications.add("info", "Test")
|
||||
notifications.clear("nonexistent-id")
|
||||
assert notifications.count() == 1
|
||||
|
||||
|
||||
def test_clear_all():
|
||||
notifications.add("info", "One")
|
||||
notifications.add("info", "Two")
|
||||
notifications.add("info", "Three")
|
||||
assert notifications.count() == 3
|
||||
|
||||
notifications.clear_all()
|
||||
assert notifications.count() == 0
|
||||
assert notifications.current() == []
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
n = notifications.add("warning", "Test dict", user="someone@example.com")
|
||||
d = n.to_dict()
|
||||
assert d["severity"] == "warning"
|
||||
assert d["message"] == "Test dict"
|
||||
assert d["user"] == "someone@example.com"
|
||||
assert "id" in d
|
||||
assert "timestamp" in d
|
||||
|
||||
|
||||
def test_max_notifications():
|
||||
"""Deque should cap at MAX_NOTIFICATIONS."""
|
||||
for i in range(notifications.MAX_NOTIFICATIONS + 10):
|
||||
notifications.add("info", f"Notification {i}")
|
||||
|
||||
assert notifications.count() == notifications.MAX_NOTIFICATIONS
|
||||
# Newest should be the last one added
|
||||
assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}"
|
||||
124
tests/test_services.py
Normal file
124
tests/test_services.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Tests for services — WireGuard and events."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created
|
||||
|
||||
|
||||
def _make_device(**kwargs) -> Device:
|
||||
defaults = dict(
|
||||
name="test",
|
||||
public_key="pk-test",
|
||||
preshared_key="psk-test",
|
||||
ipv4="10.3.2.5",
|
||||
ipv6="fd00::3:2:5",
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return Device(**defaults)
|
||||
|
||||
|
||||
# --- Events (with WG enabled) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_fw.add_device_jump_rule = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_created(device)
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once_with(
|
||||
public_key="pk-test",
|
||||
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"],
|
||||
preshared_key="psk-test",
|
||||
)
|
||||
mock_fw.add_device_jump_rule.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_deleted(device)
|
||||
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test")
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_updated(device)
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once()
|
||||
|
||||
|
||||
# --- Events (WG disabled) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_events_skip_when_wg_disabled(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = False
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_created(device)
|
||||
await on_device_deleted(device)
|
||||
await on_device_updated(device)
|
||||
|
||||
mock_wg.add_peer.assert_not_awaited()
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
# --- Events (WG error handling) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock(side_effect=RuntimeError("wg failed"))
|
||||
mock_fw.add_device_jump_rule = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
# Should not raise — error is logged
|
||||
await on_device_created(device)
|
||||
|
||||
|
||||
# --- Rule events ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
async def test_on_rule_created_calls_apply_rule(mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_fw.apply_rule = AsyncMock()
|
||||
|
||||
rule = Rule(
|
||||
action="accept",
|
||||
destination="10.0.0.0/8",
|
||||
port_type="tcp",
|
||||
port_range="80",
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
await on_rule_created(rule)
|
||||
|
||||
mock_fw.apply_rule.assert_awaited_once_with(
|
||||
"00000000-0000-0000-0000-000000000000", "10.0.0.0/8", "accept", "tcp", "80",
|
||||
)
|
||||
203
tests/test_services_extended.py
Normal file
203
tests/test_services_extended.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""Extended service tests — wireguard subprocess mocking, firewall nft mocking, email."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from wiregui.services.wireguard import PeerInfo, add_peer, get_peers, remove_peer
|
||||
|
||||
|
||||
# ========== WireGuard service (mocked subprocess) ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_add_peer_without_psk(mock_run):
|
||||
mock_run.return_value = ""
|
||||
await add_peer("pubkey123", ["10.0.0.1/32", "fd00::1/128"], iface="wg-test")
|
||||
mock_run.assert_awaited_once()
|
||||
args = mock_run.call_args[0][0]
|
||||
assert "wg" in args
|
||||
assert "set" in args
|
||||
assert "pubkey123" in args
|
||||
assert "10.0.0.1/32,fd00::1/128" in args
|
||||
|
||||
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_add_peer_with_psk(mock_exec):
|
||||
"""PSK path uses subprocess directly with stdin."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"", b"")
|
||||
mock_proc.returncode = 0
|
||||
mock_exec.return_value = mock_proc
|
||||
|
||||
await add_peer("pubkey456", ["10.0.0.2/32"], preshared_key="psk-data", iface="wg-test")
|
||||
mock_exec.assert_awaited_once()
|
||||
call_args = mock_exec.call_args[0]
|
||||
assert "preshared-key" in call_args
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_remove_peer(mock_run):
|
||||
mock_run.return_value = ""
|
||||
await remove_peer("pubkey789", iface="wg-test")
|
||||
mock_run.assert_awaited_once()
|
||||
args = mock_run.call_args[0][0]
|
||||
assert "remove" in args
|
||||
assert "pubkey789" in args
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_get_peers_parses_dump(mock_run):
|
||||
dump_output = (
|
||||
"privkey\tpubkey\t51820\toff\n"
|
||||
"peerkey1\t(none)\t1.2.3.4:51820\t10.0.0.1/32\t1700000000\t12345\t67890\t25\n"
|
||||
"peerkey2\t(none)\t(none)\t10.0.0.2/32,fd00::2/128\t0\t0\t0\t0\n"
|
||||
)
|
||||
mock_run.return_value = dump_output
|
||||
|
||||
peers = await get_peers(iface="wg-test")
|
||||
assert len(peers) == 2
|
||||
|
||||
assert peers[0].public_key == "peerkey1"
|
||||
assert peers[0].endpoint == "1.2.3.4:51820"
|
||||
assert peers[0].rx_bytes == 12345
|
||||
assert peers[0].tx_bytes == 67890
|
||||
assert peers[0].latest_handshake is not None
|
||||
|
||||
assert peers[1].public_key == "peerkey2"
|
||||
assert peers[1].endpoint is None
|
||||
assert peers[1].rx_bytes == 0
|
||||
assert peers[1].latest_handshake is None
|
||||
assert len(peers[1].allowed_ips) == 2
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_get_peers_returns_empty_on_error(mock_run):
|
||||
mock_run.side_effect = RuntimeError("interface not found")
|
||||
peers = await get_peers(iface="wg-test")
|
||||
assert peers == []
|
||||
|
||||
|
||||
# ========== Firewall (mocked nft) ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_base_tables(mock_batch):
|
||||
from wiregui.services.firewall import setup_base_tables
|
||||
await setup_base_tables()
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("add table" in c for c in cmds)
|
||||
assert any("forward" in c for c in cmds)
|
||||
assert any("postrouting" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_user_chain(mock_batch):
|
||||
from wiregui.services.firewall import add_user_chain
|
||||
await add_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("user_a1b2c3d40000" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_remove_user_chain(mock_batch):
|
||||
from wiregui.services.firewall import remove_user_chain
|
||||
await remove_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("flush" in c for c in cmds)
|
||||
assert any("delete" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_device_jump_rule(mock_batch):
|
||||
from wiregui.services.firewall import add_device_jump_rule
|
||||
await add_device_jump_rule("user-id-123", "10.0.0.5", "fd00::5")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("10.0.0.5" in c and "jump" in c for c in cmds)
|
||||
assert any("fd00::5" in c and "jump" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_apply_rule(mock_batch):
|
||||
from wiregui.services.firewall import apply_rule
|
||||
await apply_rule("user-123", "10.0.0.0/8", "accept", "tcp", "80-443")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("10.0.0.0/8" in c and "accept" in c and "tcp dport 80-443" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_rebuild_all_rules(mock_batch):
|
||||
from wiregui.services.firewall import rebuild_all_rules
|
||||
await rebuild_all_rules([
|
||||
{
|
||||
"user_id": "user-1",
|
||||
"devices": [{"ipv4": "10.0.0.1", "ipv6": "fd00::1"}],
|
||||
"rules": [
|
||||
{"destination": "0.0.0.0/0", "action": "accept", "port_type": None, "port_range": None},
|
||||
{"destination": "192.168.0.0/16", "action": "drop", "port_type": "tcp", "port_range": "22"},
|
||||
],
|
||||
}
|
||||
])
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("flush chain" in c and "forward" in c for c in cmds)
|
||||
assert any("0.0.0.0/0" in c and "accept" in c for c in cmds)
|
||||
assert any("192.168.0.0/16" in c and "drop" in c for c in cmds)
|
||||
assert any("10.0.0.1" in c and "jump" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_masquerade(mock_batch):
|
||||
from wiregui.services.firewall import setup_masquerade
|
||||
await setup_masquerade(iface="wg0")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("masquerade" in c for c in cmds)
|
||||
|
||||
|
||||
# ========== Email service (mocked smtp) ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
|
||||
async def test_send_email_success(mock_send, monkeypatch):
|
||||
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
|
||||
"smtp_host": "smtp.test.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_user": "user",
|
||||
"smtp_password": "pass",
|
||||
"smtp_from": "test@test.com",
|
||||
})())
|
||||
|
||||
from wiregui.services.email import send_email
|
||||
result = await send_email("to@test.com", "Subject", "Body")
|
||||
assert result is True
|
||||
mock_send.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_send_email_no_smtp_configured(monkeypatch):
|
||||
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
|
||||
"smtp_host": None,
|
||||
})())
|
||||
|
||||
from wiregui.services.email import send_email
|
||||
result = await send_email("to@test.com", "Subject", "Body")
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
|
||||
async def test_send_magic_link(mock_send, monkeypatch):
|
||||
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
|
||||
"smtp_host": "smtp.test.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_user": "u",
|
||||
"smtp_password": "p",
|
||||
"smtp_from": "noreply@test.com",
|
||||
})())
|
||||
|
||||
from wiregui.services.email import send_magic_link
|
||||
result = await send_magic_link("user@test.com", "https://app.test/magic/123/token")
|
||||
assert result is True
|
||||
mock_send.assert_awaited_once()
|
||||
231
tests/test_tasks.py
Normal file
231
tests/test_tasks.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""Tests for background tasks — VPN session expiry and connectivity checks."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- VPN session expiry ---
|
||||
|
||||
|
||||
async def test_vpn_session_expiry_removes_expired_peers(session, monkeypatch):
|
||||
"""Users whose session expired should have their WG peers removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
# Create config with 1-hour session duration
|
||||
config = Configuration(vpn_session_duration=3600)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
# Create a user who signed in 2 hours ago (expired)
|
||||
expired_user = User(
|
||||
email="expired@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
last_signed_in_at=utcnow() - timedelta(hours=2),
|
||||
)
|
||||
session.add(expired_user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="laptop", public_key="pk-expired", user_id=expired_user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
# Create a user who signed in 30 min ago (still valid)
|
||||
active_user = User(
|
||||
email="active@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
last_signed_in_at=utcnow() - timedelta(minutes=30),
|
||||
)
|
||||
session.add(active_user)
|
||||
await session.flush()
|
||||
|
||||
active_device = Device(name="phone", public_key="pk-active", user_id=active_user.id)
|
||||
session.add(active_device)
|
||||
await session.flush()
|
||||
|
||||
# Mock WireGuard
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
# Only expired user's peer should be removed
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-expired")
|
||||
|
||||
|
||||
async def test_vpn_session_no_expiry_when_duration_zero(session, monkeypatch):
|
||||
"""When vpn_session_duration is 0 (unlimited), no peers should be removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
config = Configuration(vpn_session_duration=0)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
email="unlimited@example.com",
|
||||
last_signed_in_at=utcnow() - timedelta(days=365),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_vpn_session_no_expiry_when_no_config(session, monkeypatch):
|
||||
"""When no Configuration exists, no peers should be removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
# No Configuration row at all
|
||||
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_vpn_session_skips_disabled_users(session, monkeypatch):
|
||||
"""Disabled users should be skipped even if their session is expired."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
config = Configuration(vpn_session_duration=3600)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
email="disabled-session@example.com",
|
||||
last_signed_in_at=utcnow() - timedelta(hours=2),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="d", public_key="pk-disabled-session", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
# --- Connectivity checks ---
|
||||
|
||||
|
||||
async def test_connectivity_check_success(session, monkeypatch):
|
||||
"""Successful connectivity check should store result in DB."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
|
||||
|
||||
# Mock httpx to return a successful response
|
||||
import httpx
|
||||
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
headers = {"content-type": "text/plain"}
|
||||
text = "203.0.113.1"
|
||||
|
||||
class MockAsyncClient:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def get(self, url):
|
||||
return MockResponse()
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
|
||||
|
||||
from wiregui.tasks.connectivity import _check_connectivity
|
||||
await _check_connectivity()
|
||||
|
||||
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
|
||||
assert result.response_code == 200
|
||||
assert result.response_body == "203.0.113.1"
|
||||
|
||||
|
||||
async def test_connectivity_check_failure(session, monkeypatch):
|
||||
"""Failed connectivity check should store error and create notification."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
|
||||
|
||||
class MockAsyncClient:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def get(self, url):
|
||||
raise ConnectionError("Network unreachable")
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
|
||||
|
||||
from wiregui.services import notifications
|
||||
notifications.clear_all()
|
||||
|
||||
from wiregui.tasks.connectivity import _check_connectivity
|
||||
await _check_connectivity()
|
||||
|
||||
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
|
||||
assert result.response_code is None
|
||||
assert "Network unreachable" in result.response_body
|
||||
|
||||
assert notifications.count() > 0
|
||||
assert "connectivity" in notifications.current()[0].message.lower()
|
||||
229
tests/test_tasks_extended.py
Normal file
229
tests/test_tasks_extended.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Extended task tests — stats polling, reconciliation, OIDC refresh."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.services.wireguard import PeerInfo
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# ========== Stats task ==========
|
||||
|
||||
|
||||
async def test_stats_update_from_wg_peers(session, monkeypatch):
|
||||
"""Stats task should update device records from WireGuard peer data."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
|
||||
|
||||
user = User(email="stats-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="stats-dev", public_key="pk-stats-test", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
mock_peers = [
|
||||
PeerInfo(
|
||||
public_key="pk-stats-test",
|
||||
endpoint="1.2.3.4:51820",
|
||||
rx_bytes=123456,
|
||||
tx_bytes=789012,
|
||||
latest_handshake=utcnow(),
|
||||
)
|
||||
]
|
||||
|
||||
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
|
||||
from wiregui.tasks.stats import _update_stats
|
||||
await _update_stats()
|
||||
|
||||
refreshed = await session.get(Device, device.id)
|
||||
assert refreshed.rx_bytes == 123456
|
||||
assert refreshed.tx_bytes == 789012
|
||||
assert refreshed.remote_ip == "1.2.3.4"
|
||||
assert refreshed.latest_handshake is not None
|
||||
|
||||
|
||||
async def test_stats_no_peers_is_noop(session, monkeypatch):
|
||||
"""No WG peers should result in no DB changes."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
|
||||
|
||||
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[])
|
||||
from wiregui.tasks.stats import _update_stats
|
||||
await _update_stats() # Should not raise
|
||||
|
||||
|
||||
async def test_stats_unmatched_peer_ignored(session, monkeypatch):
|
||||
"""Peers not matching any device should be ignored."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
|
||||
|
||||
mock_peers = [
|
||||
PeerInfo(public_key="unknown-peer-key", rx_bytes=100, tx_bytes=200)
|
||||
]
|
||||
|
||||
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
|
||||
from wiregui.tasks.stats import _update_stats
|
||||
await _update_stats() # Should not raise
|
||||
|
||||
|
||||
# ========== Reconciliation task ==========
|
||||
|
||||
|
||||
async def test_reconcile_adds_missing_peers(session, monkeypatch):
|
||||
"""Devices in DB but not in WG should be added."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
|
||||
|
||||
user = User(email="reconcile@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="missing", public_key="pk-missing", ipv4="10.0.0.5", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[]) # WG has no peers
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
await reconcile()
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once()
|
||||
call_kwargs = mock_wg.add_peer.call_args[1]
|
||||
assert call_kwargs["public_key"] == "pk-missing"
|
||||
assert "10.0.0.5/32" in call_kwargs["allowed_ips"]
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_reconcile_removes_orphaned_peers(session, monkeypatch):
|
||||
"""Peers in WG but not in DB should be removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
|
||||
|
||||
# No devices in DB, but WG has a peer
|
||||
orphan = PeerInfo(public_key="pk-orphan", rx_bytes=0, tx_bytes=0)
|
||||
|
||||
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[orphan])
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
await reconcile()
|
||||
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-orphan")
|
||||
mock_wg.add_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_reconcile_in_sync(session, monkeypatch):
|
||||
"""When DB and WG match, nothing should happen."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
|
||||
|
||||
user = User(email="in-sync@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="synced", public_key="pk-synced", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
peer = PeerInfo(public_key="pk-synced", rx_bytes=0, tx_bytes=0)
|
||||
|
||||
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[peer])
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
await reconcile()
|
||||
|
||||
mock_wg.add_peer.assert_not_awaited()
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
# ========== OIDC refresh task ==========
|
||||
|
||||
|
||||
async def test_oidc_refresh_no_connections_is_noop(session, monkeypatch):
|
||||
"""No OIDC connections should result in no refresh attempts."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
|
||||
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[]))
|
||||
|
||||
from wiregui.tasks.oidc_refresh import _refresh_all
|
||||
await _refresh_all() # Should not raise
|
||||
|
||||
|
||||
async def test_oidc_refresh_skips_unknown_provider(session, monkeypatch):
|
||||
"""Connections for unknown providers should be skipped."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
|
||||
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[
|
||||
{"id": "known-provider", "client_id": "cid", "client_secret": "cs", "discovery_document_uri": "https://x"}
|
||||
]))
|
||||
|
||||
user = User(email="oidc-skip@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(provider="unknown-provider", refresh_token="tok", user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
from wiregui.tasks.oidc_refresh import _refresh_all
|
||||
await _refresh_all() # Should skip gracefully
|
||||
120
tests/test_utils.py
Normal file
120
tests/test_utils.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for utility modules."""
|
||||
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
|
||||
from wiregui.utils.wg_conf import build_client_config
|
||||
|
||||
|
||||
# --- IP allocation ---
|
||||
|
||||
|
||||
async def test_allocate_ipv4_first_device(session):
|
||||
user = User(email="net-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
ip = await allocate_ipv4(session, "10.3.2.0/24")
|
||||
assert ip.startswith("10.3.2.")
|
||||
# Should not be the network (.0) or gateway (.1)
|
||||
last_octet = int(ip.split(".")[-1])
|
||||
assert last_octet >= 2
|
||||
|
||||
|
||||
async def test_allocate_ipv4_skips_used(session):
|
||||
user = User(email="net-skip@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Exhaust a tiny /30 network (4 addresses: .0 network, .1 gateway, .2 usable, .3 broadcast)
|
||||
d1 = Device(name="d1", public_key="pk-net-1", ipv4="10.99.0.2", user_id=user.id)
|
||||
session.add(d1)
|
||||
await session.flush()
|
||||
|
||||
# Only .2 was usable in a /30 — allocation should fail
|
||||
with pytest.raises(ValueError, match="No available"):
|
||||
await allocate_ipv4(session, "10.99.0.0/30")
|
||||
|
||||
|
||||
async def test_allocate_ipv6(session):
|
||||
user = User(email="net6-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
ip = await allocate_ipv6(session, "fd00::3:2:0/120")
|
||||
assert ip.startswith("fd00::3:2:")
|
||||
|
||||
|
||||
# --- WireGuard config builder ---
|
||||
|
||||
|
||||
def test_build_client_config():
|
||||
device = Device(
|
||||
name="test-device",
|
||||
public_key="device-pub-key",
|
||||
preshared_key="device-psk",
|
||||
ipv4="10.3.2.5",
|
||||
ipv6="fd00::3:2:5",
|
||||
use_default_allowed_ips=True,
|
||||
use_default_dns=True,
|
||||
use_default_endpoint=True,
|
||||
use_default_mtu=True,
|
||||
use_default_persistent_keepalive=True,
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
|
||||
config = build_client_config(device, "PRIVATE_KEY_HERE", "SERVER_PUB_KEY")
|
||||
|
||||
assert "[Interface]" in config
|
||||
assert "PrivateKey = PRIVATE_KEY_HERE" in config
|
||||
assert "10.3.2.5/32" in config
|
||||
assert "fd00::3:2:5/128" in config
|
||||
assert "[Peer]" in config
|
||||
assert "PublicKey = SERVER_PUB_KEY" in config
|
||||
assert "PresharedKey = device-psk" in config
|
||||
assert "Endpoint = " in config
|
||||
|
||||
|
||||
def test_build_client_config_no_psk():
|
||||
device = Device(
|
||||
name="no-psk",
|
||||
public_key="pub",
|
||||
preshared_key=None,
|
||||
ipv4="10.3.2.6",
|
||||
ipv6=None,
|
||||
use_default_allowed_ips=True,
|
||||
use_default_dns=True,
|
||||
use_default_endpoint=True,
|
||||
use_default_mtu=True,
|
||||
use_default_persistent_keepalive=True,
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
|
||||
config = build_client_config(device, "PRIV", "SERVPUB")
|
||||
assert "PresharedKey" not in config
|
||||
assert "fd00::" not in config # no ipv6
|
||||
|
||||
|
||||
# --- Crypto (only if wg is installed) ---
|
||||
|
||||
|
||||
def test_generate_keypair():
|
||||
"""Test keypair generation — requires `wg` CLI to be installed."""
|
||||
try:
|
||||
subprocess.run(["wg", "--version"], capture_output=True, check=True)
|
||||
except FileNotFoundError:
|
||||
pytest.skip("wg CLI not installed")
|
||||
|
||||
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||
|
||||
priv, pub = generate_keypair()
|
||||
assert len(priv) == 44 # base64-encoded 32 bytes
|
||||
assert len(pub) == 44
|
||||
|
||||
psk = generate_preshared_key()
|
||||
assert len(psk) == 44
|
||||
Loading…
Add table
Add a link
Reference in a new issue