diff --git a/README.md b/README.md index 0c759a2..739a4eb 100644 --- a/README.md +++ b/README.md @@ -84,17 +84,34 @@ All settings use the `WG_` prefix: | `WG_ADMIN_EMAIL` | `admin@localhost` | Initial admin email | | `WG_ADMIN_PASSWORD` | *(auto-generated)* | Initial admin password | | `WG_EXTERNAL_URL` | `http://localhost:13000` | Public-facing URL | +| `WG_IDP_CONFIG_FILE` | *(none)* | Path to YAML file with OIDC/SAML IdP definitions | ## Testing ```bash # Unit + integration tests -uv run pytest tests/ --ignore=tests/e2e -v +uv run pytest -# E2E tests (NiceGUI User fixture) +# E2E tests (Playwright — requires running PostgreSQL, Valkey, and mock-oidc) +docker compose up -d uv run pytest tests/e2e/ -v + +# E2E in headed mode (watch tests in a browser) +uv run pytest tests/e2e/ --headed --slowmo 300 ``` +E2E tests automatically start a WireGUI instance on port 13001 and use Playwright's async API to drive a real Chromium browser. The `--headed` flag opens a visible browser window and `--slowmo` adds a delay (in ms) between actions for debugging. The OIDC login flow tests use the `mock-oidc` service from `compose.yml`. + +### IdP provisioning from YAML + +Identity providers can be seeded at startup from a YAML file, enabling GitOps and infrastructure-as-code workflows: + +```bash +WG_IDP_CONFIG_FILE=/etc/wiregui/idps.yaml uv run python -m wiregui.main +``` + +See `tests/e2e/test_idp_seed.py` for the YAML format and seeding behavior. + ## License Copyright 2026 Stefano Bertelli / Provvedo diff --git a/TODO.md b/TODO.md index b436ccb..5d1124a 100644 --- a/TODO.md +++ b/TODO.md @@ -3,150 +3,17 @@ Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI. Source: `/home/stefanob/PycharmProjects/personal/wirezone` -**Test count: 174 (173 passing, 1 skipped) | Coverage: 35%** +**Test count: 199 (164 unit + 35 E2E) | Coverage: 35%** +**Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright) ---- - -## Phase 1: Foundation — Models, DB, Config ✅ - -- [x] `pyproject.toml` with dependencies, `uv sync` -- [x] Package directory structure -- [x] `wiregui/config.py` — pydantic-settings (DB, Redis, WG, auth, SMTP, logging) -- [x] `wiregui/db.py` — async engine, sessionmaker, `init_db()` -- [x] `wiregui/redis.py` — Valkey connection pool -- [x] All 8 SQLModel models (User, Device, Rule, MFAMethod, OIDCConnection, ApiToken, ConnectivityCheck, Configuration) -- [x] Alembic init + initial migration + `alembic upgrade head` -- [x] `wiregui/main.py` — app entrypoint -- [x] `compose.yml` — PostgreSQL 17 + Valkey 8 -- [x] `wiregui/utils/time.py` — `utcnow()` helper for naive UTC timestamps - ---- - -## Phase 2: Auth System — Login + Sessions ✅ - -- [x] `wiregui/auth/passwords.py` — bcrypt hash/verify (direct bcrypt, not passlib) -- [x] `wiregui/auth/jwt.py` — create/decode JWT via python-jose -- [x] `wiregui/auth/session.py` — `authenticate_user()` email/password verification -- [x] `wiregui/auth/middleware.py` — HTTP-level auth middleware (available for REST API) -- [x] `wiregui/auth/seed.py` — auto-create admin on first startup -- [x] `wiregui/pages/login.py` — login page with email/password form -- [x] `wiregui/pages/home.py` — authenticated home (redirects to /devices) -- [x] Auth guards via `app.storage.user` on each page -- [x] Logout clears storage and redirects - ---- - -## Phase 3: Device UI — User-Facing CRUD ✅ - -- [x] `wiregui/pages/layout.py` — shared sidebar + header -- [x] `wiregui/utils/network.py` — IPv4/IPv6 allocation (random offset + scan) -- [x] `wiregui/utils/crypto.py` — WG keypair + PSK generation via `wg` CLI -- [x] `wiregui/utils/wg_conf.py` — WG client `.conf` builder -- [x] `wiregui/pages/devices.py` — `/devices` list + create dialog + delete -- [x] `/devices/{device_id}` — detail page with stats and config flags -- [x] QR code generation + `.conf` download -- [x] Full device create/edit form with all wirezone options (description, per-device config overrides, use_default_* toggles with bound inputs, better layout) - ---- - -## Phase 4: WireGuard Integration ✅ - -- [x] `wiregui/services/wireguard.py` — async subprocess: ensure_interface, add/remove_peer, get_peers, set_private_key, set_listen_port -- [x] `wiregui/services/events.py` — event bridge: device CRUD → WG + firewall -- [x] Device create/delete in UI fires WG events -- [x] `wiregui/tasks/__init__.py` — background task registry + cancel_all -- [x] `wiregui/tasks/stats.py` — poll WG stats every 60s, update DB -- [x] `wiregui/tasks/reconcile.py` — startup reconciliation (diff DB vs WG, add/remove) -- [x] `config.py` — `wg_enabled` flag (default False for dev) -- [x] Startup: ensure_interface → reconcile → stats_loop (when wg_enabled) - ---- - -## Phase 5: Firewall (nftables) ✅ - -- [x] `wiregui/services/firewall.py` — nft CLI: setup_base_tables, masquerade, per-user chains, jump rules, apply_rule, rebuild_all_rules -- [x] IPv4/IPv6 aware, TCP/UDP port range support -- [x] `wiregui/pages/admin/rules.py` — `/admin/rules` CRUD (action, CIDR, protocol, port, user) -- [x] Events: on_rule_created/deleted, on_device_created adds jump rules -- [x] Startup: setup_base_tables + setup_masquerade (when wg_enabled) -- [x] Edit rule — edit dialog in admin rules page with all fields -- [x] Full user chain rebuild on rule update/delete via `_rebuild_user_chain()` in events.py - ---- - -## Phase 6: REST API (v0) ✅ - -- [x] `wiregui/auth/api_token.py` — token generation (random → sha256), Bearer resolution with expiry + disabled user checks -- [x] `wiregui/api/deps.py` — get_db, get_current_api_user, require_admin -- [x] `wiregui/schemas/` — Pydantic schemas: UserRead/Create/Update, DeviceRead/Create/Update, RuleRead/Create/Update, ConfigurationRead/Update -- [x] `wiregui/api/v0/users.py` — full CRUD (admin only) -- [x] `wiregui/api/v0/devices.py` — full CRUD (owner or admin, triggers WG/firewall events) -- [x] `wiregui/api/v0/rules.py` — full CRUD (admin only, triggers firewall events) -- [x] `wiregui/api/v0/configuration.py` — GET/PUT (admin only, auto-creates singleton) -- [x] Mounted on NiceGUI app at `/api/v0` - ---- ## Phase 7: Admin UI ✅ -- [x] `/admin/users` — table (email, role, devices, status, last sign-in, method, created), create (email/password/role), edit (email/role/password/disabled), delete with cascading cleanup (devices → WG events, rules) -- [x] `/admin/devices` — all devices with user filter, full create form (owner, name, description, all use_default_* toggles with bound override inputs), full edit form, delete with WG events, config + QR on creation -- [x] `/admin/settings` — 3 tabs: - - Client Defaults (endpoint, DNS, allowed IPs, MTU, keepalive) - - Security (VPN session duration, local auth, unpriv device mgmt/config, OIDC auto-disable) - - Authentication (OIDC provider CRUD with table + dialog; SAML placeholder for Phase 8) -- [x] `/admin/diagnostics` — WG interface status, active peers, connectivity checks, system notifications with clear/clear-all -- [x] `wiregui/services/notifications.py` — in-memory deque (capped at 100), add/clear/count/current -- [x] Header notification bell badge (admin only, links to diagnostics) - [ ] **TODO:** SAML provider management in Authentication tab ---- - -## Phase 8: Advanced Auth (MFA, OIDC, Magic Links, SAML) ✅ - -- [x] TOTP MFA (`wiregui/auth/mfa.py`) — secret generation, URI/QR, verification with clock drift tolerance -- [x] MFA challenge page (`/mfa`) — 6-digit code entry, multi-method support, last-used tracking -- [x] Login page updated: checks for MFA methods after password auth, redirects to `/mfa` if present -- [x] OIDC (`wiregui/auth/oidc.py`) — provider registry from Configuration, authlib Starlette integration -- [x] OIDC routes (`/auth/oidc/{provider}` + `/auth/oidc/{provider}/callback`) — auth code flow, user lookup/auto-create, refresh token storage in OIDCConnection -- [x] Login page shows OIDC provider buttons dynamically from config -- [x] OIDC refresh task (`wiregui/tasks/oidc_refresh.py`) — every 10min, refreshes all stored tokens, creates notifications on failure, respects `disable_vpn_on_oidc_error` -- [x] Magic links (`/auth/magic-link` + `/auth/magic/{user_id}/{token}`) — request page, signed JWT with 15min expiry, email via aiosmtplib -- [x] Email service (`wiregui/services/email.py`) — aiosmtplib send, magic link template -- [x] `/account` page — 3 tabs: Profile (details + password change), Two-Factor Auth (TOTP registration with QR + verification, list/delete methods), API Tokens (create with configurable expiry, list, delete) -- [x] OIDC providers registered on startup from Configuration -- [x] WebAuthn MFA (`wiregui/auth/webauthn.py`) — registration/authentication options generation, response verification, credential storage -- [x] SAML (`wiregui/auth/saml.py` + `wiregui/pages/auth_saml.py`) — SP-initiated SSO, metadata endpoint, ACS callback, IdP metadata parsing, attribute mapping -- [x] WebAuthn browser-side JS integration in account page — `ui.run_javascript()` calls `navigator.credentials.create()`, serializes response, server verifies and stores credential -- [x] SAML provider management UI in admin settings Authentication tab — table + add/delete dialog (config ID, label, XML metadata, sign requests/metadata/assertions/envelopes toggles, auto-create users) - ---- - -## Phase 9: Background Tasks & VPN Session Management - -- [x] Task scheduler (`wiregui/tasks/__init__.py`) — register/cancel -- [x] Stats polling task (Phase 4) -- [x] OIDC refresh task (Phase 8) -- [x] VPN session expiry task (`wiregui/tasks/vpn_session.py`) — every 60s, finds expired sessions based on `vpn_session_duration` + `last_signed_in_at`, removes WG peers, creates notifications -- [x] Connectivity check poller (`wiregui/tasks/connectivity.py`) — fetches URL, stores result in DB, notification on failure -- [x] Live stats push — `ui.timer(30, ...)` on `/devices` (table refresh), `/devices/{id}` (RX/TX/handshake/remote IP labels), `/admin/devices` (table refresh) - ---- - ## Phase 10: Polish, Testing & Deployment ### Testing (partially done) -- [x] pytest + pytest-asyncio setup, conftest with test DB -- [x] test_models.py (10 tests), test_auth.py (8 tests), test_utils.py (6 tests), test_services.py (6 tests), test_firewall.py (7 tests) -- [x] test_api.py (6 tests) — token generation, resolution, expiry, disabled user -- [x] test_notifications.py (9 tests) — add, ordering, count, clear, max cap, to_dict -- [x] test_admin.py (13 tests) — user CRUD, cascading deletes, config CRUD, OIDC providers, device overrides -- [x] test_mfa.py (11 tests) — TOTP secret gen, URI, code verification (valid/invalid/wrong secret/empty), QR SVG, DB integration, multi-method -- [x] test_magic_link.py (4 tests) — token creation/expiry/user mismatch, disabled user rejection -- [x] test_account.py (8 tests) — password change flow, API token CRUD, OIDC connection CRUD, refresh token update -- [x] test_integration_mfa.py (7 tests) — full TOTP registration flow, MFA blocks login, wrong code, multi-method, last-used tracking, delete allows bypass, disabled user -- [x] test_integration_oidc.py (10 tests) — provider config loading, connection create/update, auto-create user, disabled user, refresh token, multi-provider -- [x] test_tasks.py (6 tests) — VPN session expiry (expired/unlimited/no-config/disabled user), connectivity check (success/failure with notification) - [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking) ### Coverage gaps (35% overall — run `uv run pytest --cov=wiregui --cov-report=term-missing --cov-branch`) @@ -181,54 +48,88 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone` - [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing - [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data -**E2E page tests (via NiceGUI `User` fixture in `tests/e2e/`):** +**E2E page tests (Playwright async API in `tests/e2e/`):** +- [x] `tests/e2e/test_login.py` (6 tests) — valid login, invalid password, nonexistent email, disabled user, logout, unauthenticated redirect - [x] `tests/e2e/test_devices.py` (2 tests) — add device full flow, name validation - [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion -- [ ] E2E tests for admin pages (users, devices, rules, settings) +- [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard +- [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc + +**E2E tests still needed:** + +`tests/e2e/test_login.py` — Login & Auth flows (remaining): +- [ ] Login with MFA → redirects to /mfa challenge page +- [ ] MFA challenge: valid TOTP code → completes login +- [ ] MFA challenge: invalid code → shows error, stays on /mfa +- [ ] MFA challenge: cancel → returns to /login +- [ ] Magic link request page renders, shows success on submit + +`tests/e2e/test_admin_devices.py` — Admin Device Management: +- [ ] List all devices across users +- [ ] Filter by user → shows only that user's devices +- [ ] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs) +- [ ] Create device with defaults → use_default flags all True +- [ ] Edit device name and description → persists +- [ ] Edit device config overrides (toggle use_default off, set custom values) +- [ ] Delete device → removed from table +- [ ] Config dialog shows valid WireGuard config with real server public key +- [ ] QR code renders in config dialog + +`tests/e2e/test_admin_rules.py` — Admin Firewall Rules: +- [ ] List rules → table shows action, destination, protocol, port, user +- [ ] Create accept rule with CIDR → appears in table +- [ ] Create drop rule with TCP port range → appears correctly +- [ ] Create global rule (no user) → shows "Global" +- [ ] Edit rule action (accept → drop) → persists +- [ ] Edit rule destination → persists +- [ ] Delete rule → removed from table + +`tests/e2e/test_admin_settings.py` — Admin Settings: +- [ ] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB +- [ ] Client defaults: saved values reflected on page reload +- [ ] Security: toggle local auth → persists +- [ ] Security: change VPN session duration → persists +- [ ] Security: toggle unprivileged device management/configuration → persists +- [ ] OIDC: add provider → appears in table +- [ ] OIDC: delete provider → removed from table +- [ ] SAML: add provider → appears in table +- [ ] SAML: delete provider → removed from table + +`tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics: +- [ ] Page renders WireGuard interface status +- [ ] Active peers table shows devices with handshakes +- [ ] Connectivity checks table shows recent results +- [ ] Notifications list shows system notifications +- [ ] Clear single notification → removed +- [ ] Clear all notifications → list empty + +`tests/e2e/test_devices_user.py` — User Device Pages: +- [ ] Device list shows only own devices (not other users') +- [ ] Create device → shows in table with allocated IPs +- [ ] Device detail page shows public key, IPs, stats, active config +- [ ] Device detail: edit name → persists +- [ ] Device detail: toggle config overrides → custom values saved +- [ ] Device detail: delete with confirmation → redirects to /devices +- [ ] Auto-refresh: stats labels update after timer fires (mock timer) + +`tests/e2e/test_account_extended.py` — Account Page (additional): +- [ ] SSO providers section shows connected providers +- [ ] SSO providers section shows "No SSO providers" when empty +- [ ] MFA: add security key (WebAuthn) → method appears in table (mock navigator.credentials) +- [ ] MFA: delete method with confirmation → removed from table +- [ ] API tokens: expired token shows "Expired" badge +- [ ] API tokens: delete token → removed from table +- [ ] API tokens: copy button calls clipboard API +- [ ] Danger zone: disabled when only admin +- [ ] Danger zone: wrong email in confirmation → shows error -### Logging (done) -- [x] Loguru configured (wiregui/logging.py), no print statements -- [x] File logging to `logs/` when `WG_LOG_TO_FILE=true` ### Deployment ✅ -- [x] Dockerfile (multi-stage python:3.13-slim) -- [x] compose.prod.yml (bridge networking, NET_ADMIN, nftables) -- [x] Health endpoint `GET /api/health` -- [x] Forgejo CI: test → semver → Docker registry push -- [x] AGPL-3.0-or-later license -- [x] README.md with features, quick start, env vars, anti-enshittification manifesto + - [ ] First-run CLI setup command --- -## UI Polish & Styling - -### Global styling ✅ -- [x] Manrope font loaded from Google Fonts as primary UI font (`wiregui/pages/style.py`) -- [x] Font applied on all pages (layout, login, MFA challenge) -- [x] Dark/light/auto theme toggle in header — cycles with icon button -- [x] Theme preference stored in `users.theme_preference` column (migration `a3f1d8e92b01`) -- [x] Theme persisted to DB and loaded into session on all login flows (password, MFA, magic link, OIDC, SAML) - -### Account page (`/account`) ✅ -- [x] Card-based layout matching admin pages (diagnostics, settings) -- [x] Account Details: `ui.grid(columns=2)` with bold labels, same as diagnostics -- [x] Change Password: inline card section (no modal), outlined inputs, validation -- [x] Connected SSO Providers: always visible card with empty state -- [x] API Tokens: table with status badges, inline create, copy-to-clipboard with green accent card -- [x] MFA: methods table, inline TOTP registration (QR + verify), WebAuthn, empty state -- [x] Danger Zone: red left border accent, typed email confirmation, disabled if only admin - -### Settings page (`/admin/settings`) ✅ -- [x] Converted from tabbed layout to stacked cards (Client Defaults, Security, Authentication) - -### Consistency pass ✅ -- [x] All buttons solid (`unelevated`) — no outline buttons anywhere -- [x] All pages use `w-full p-4` container with `text-h5 q-mb-md` page title -- [x] All `text-grey-7` / `text-grey-8` replaced with dark-mode-safe `text-grey` -- [x] Sidebar: removed hardcoded `bg-grey-1`, uses theme-aware background -- [x] Card titles: `text-subtitle1 text-bold` + `ui.separator()` everywhere - ### Remaining - [ ] SSO Providers: add Status column, "Disconnect" action - [ ] Admin pages (users, devices, rules): apply same card-based styling diff --git a/pyproject.toml b/pyproject.toml index b9feaa4..af8c690 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,15 @@ dependencies = [ "aiosmtplib>=3.0", # QR codes "qrcode[pil]>=8.0", + # YAML config + "pyyaml>=6.0", # Logging "loguru>=0.7.3", ] [dependency-groups] dev = [ + "playwright>=1.58.0", "pytest>=8.0", "pytest-asyncio>=0.24", "pytest-cov>=7.1.0", @@ -48,4 +51,7 @@ asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "session" asyncio_default_test_loop_scope = "session" testpaths = ["tests"] +# E2E tests run separately: uv run pytest tests/e2e/ +# NiceGUI's testing plugin conflicts with unit tests when loaded together +addopts = "--ignore=tests/e2e" main_file = "wiregui/main.py" diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index ca4f691..003b46b 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -1,6 +1,12 @@ -"""E2E test configuration — loads NiceGUI testing plugin and app.""" +"""E2E test configuration — async Playwright browser tests against a running app.""" + +import os +import subprocess +import time import pytest +import pytest_asyncio +from playwright.async_api import Browser, Page, async_playwright from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import select @@ -11,22 +17,28 @@ from wiregui.db import async_session from wiregui.models.configuration import Configuration from wiregui.models.user import User -pytest_plugins = ["nicegui.testing.user_plugin"] - FAKE_SERVER_KEY = "SFake0ServerPubKey0000000000000000000000000w=" TEST_EMAIL = "e2e-test@example.com" TEST_PASSWORD = "testpass123" +# Dedicated port so we don't conflict with a dev instance on 13000 +TEST_APP_PORT = 13001 +TEST_APP_BASE = f"http://localhost:{TEST_APP_PORT}" + _CHILD_TABLES = ("devices", "rules", "mfa_methods", "api_tokens", "oidc_connections") -async def _cleanup_test_user(): - """Delete the test user and all related objects using a fresh engine.""" +def pytest_addoption(parser): + parser.addoption("--headed", action="store_true", default=False, help="Run browser in headed mode") + parser.addoption("--slowmo", type=int, default=0, help="Slow down Playwright actions by ms") + + +async def _cleanup_user_by_email(email: str): + """Delete a user and all related objects by email.""" engine = create_async_engine(get_settings().database_url) async with engine.begin() as conn: - # Find user id by email row = (await conn.execute( - text("SELECT id FROM users WHERE email = :email"), {"email": TEST_EMAIL} + text("SELECT id FROM users WHERE email = :email"), {"email": email} )).first() if row: uid = row[0] @@ -36,14 +48,90 @@ async def _cleanup_test_user(): await engine.dispose() -@pytest.fixture -async def test_user(): - """Create a test user and ensure server config has a public key.""" - # Clean up any leftover from a previous failed run +async def _cleanup_test_user(): + await _cleanup_user_by_email(TEST_EMAIL) + + +# --------------------------------------------------------------------------- +# App subprocess — shared across all e2e tests in the session +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="session") +def app_server(): + """Start WireGUI on TEST_APP_PORT for the entire test session.""" + import httpx + + env = os.environ.copy() + env["WG_LOG_TO_FILE"] = "false" + env["WG_PORT"] = str(TEST_APP_PORT) + env["WG_EXTERNAL_URL"] = TEST_APP_BASE + env.pop("PYTEST_CURRENT_TEST", None) + env.pop("NICEGUI_SCREEN_TEST_PORT", None) + + proc = subprocess.Popen( + ["uv", "run", "python", "-m", "wiregui.main"], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + + for _ in range(30): + try: + r = httpx.get(f"{TEST_APP_BASE}/api/health", timeout=1) + if r.status_code == 200: + break + except Exception: + pass + time.sleep(1) + else: + proc.kill() + out = proc.stdout.read().decode() if proc.stdout else "" + pytest.fail(f"App did not start in time. Output:\n{out}") + + yield proc + + proc.terminate() + proc.wait(timeout=10) + + +# --------------------------------------------------------------------------- +# Playwright browser — session-scoped, one browser for all tests +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture(scope="session") +async def browser(request): + """Launch a Playwright Chromium browser for the session.""" + headed = request.config.getoption("--headed") + slowmo = request.config.getoption("--slowmo") + pw = await async_playwright().start() + br = await pw.chromium.launch(headless=not headed, slow_mo=slowmo) + yield br + await br.close() + await pw.stop() + + +@pytest_asyncio.fixture +async def page(browser: Browser): + """Create a fresh browser context + page per test (isolated cookies/storage).""" + context = await browser.new_context() + pg = await context.new_page() + yield pg + await context.close() + + +# --------------------------------------------------------------------------- +# Test user fixture +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def test_user(app_server): + """Create a test admin user, yield it, clean up after.""" await _cleanup_test_user() async with async_session() as session: - # Ensure a Configuration with a server key exists config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() if config: if not config.server_public_key: @@ -65,3 +153,17 @@ async def test_user(): yield user await _cleanup_test_user() + + +# --------------------------------------------------------------------------- +# Playwright helpers +# --------------------------------------------------------------------------- + + +async def login(page: Page, email: str = TEST_EMAIL, password: str = TEST_PASSWORD): + """Fill the login form and submit.""" + await page.goto(f"{TEST_APP_BASE}/login") + await page.wait_for_load_state("networkidle") + await page.locator("input[aria-label='Email']").fill(email) + await page.locator("input[aria-label='Password']").fill(password) + await page.get_by_role("button", name="Sign in", exact=True).click() diff --git a/tests/e2e/test_account.py b/tests/e2e/test_account.py index f8744e3..383c3b4 100644 --- a/tests/e2e/test_account.py +++ b/tests/e2e/test_account.py @@ -1,124 +1,85 @@ -"""End-to-end tests for account page — password, TOTP, API tokens, deletion.""" +"""End-to-end tests for account page — password, API tokens, TOTP, deletion.""" -from unittest.mock import patch - -import pytest -from nicegui import ui -from nicegui.testing import User +from playwright.async_api import Page, expect +from sqlmodel import select +from wiregui.auth.passwords import hash_password +from wiregui.db import async_session from wiregui.models.user import User as UserModel -from tests.e2e.conftest import TEST_EMAIL, TEST_PASSWORD +from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, TEST_PASSWORD, login -async def _login(user: User): +async def _login_to_account(page: Page): """Log in and navigate to account page.""" - await user.open("/login") - user.find("Email").type(TEST_EMAIL) - user.find("Password").type(TEST_PASSWORD) - user.find("Sign in").click() - await user.should_see("My Devices") - await user.open("/account") - await user.should_see("Account Settings") + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + await page.goto(f"{TEST_APP_BASE}/account") + await expect(page.get_by_text("Account Settings")).to_be_visible(timeout=10_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_change_password(user: User, test_user: UserModel): - """Test changing password: fill form, submit, verify success.""" - await _login(user) - - user.find("Current Password").type(TEST_PASSWORD) - user.find("New Password").type("newpass12345") - user.find("Confirm Password").type("newpass12345") - user.find("Update Password").click() - await user.should_see("Password changed") +async def test_change_password(page: Page, test_user: UserModel): + await _login_to_account(page) + await page.locator("input[aria-label='Current Password']").fill(TEST_PASSWORD) + await page.locator("input[aria-label='New Password']").fill("newpass12345") + await page.locator("input[aria-label='Confirm Password']").fill("newpass12345") + await page.get_by_role("button", name="Update Password").click() + await expect(page.get_by_text("Password changed")).to_be_visible(timeout=5_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_change_password_wrong_current(user: User, test_user: UserModel): - """Test that wrong current password is rejected.""" - await _login(user) - - user.find("Current Password").type("wrongpassword") - user.find("New Password").type("newpass12345") - user.find("Confirm Password").type("newpass12345") - user.find("Update Password").click() - await user.should_see("Wrong current password") +async def test_change_password_wrong_current(page: Page, test_user: UserModel): + await _login_to_account(page) + await page.locator("input[aria-label='Current Password']").fill("wrongpassword") + await page.locator("input[aria-label='New Password']").fill("newpass12345") + await page.locator("input[aria-label='Confirm Password']").fill("newpass12345") + await page.get_by_role("button", name="Update Password").click() + await expect(page.get_by_text("Wrong current password")).to_be_visible(timeout=5_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_change_password_mismatch(user: User, test_user: UserModel): - """Test that mismatched passwords are rejected.""" - await _login(user) - - user.find("Current Password").type(TEST_PASSWORD) - user.find("New Password").type("newpass12345") - user.find("Confirm Password").type("differentpass") - user.find("Update Password").click() - await user.should_see("Passwords don't match") +async def test_change_password_mismatch(page: Page, test_user: UserModel): + await _login_to_account(page) + await page.locator("input[aria-label='Current Password']").fill(TEST_PASSWORD) + await page.locator("input[aria-label='New Password']").fill("newpass12345") + await page.locator("input[aria-label='Confirm Password']").fill("differentpass") + await page.get_by_role("button", name="Update Password").click() + await expect(page.get_by_text("Passwords don't match")).to_be_visible(timeout=5_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_change_password_too_short(user: User, test_user: UserModel): - """Test that short passwords are rejected.""" - await _login(user) - - user.find("Current Password").type(TEST_PASSWORD) - user.find("New Password").type("short") - user.find("Confirm Password").type("short") - user.find("Update Password").click() - await user.should_see("Min 8 characters") +async def test_change_password_too_short(page: Page, test_user: UserModel): + await _login_to_account(page) + await page.locator("input[aria-label='Current Password']").fill(TEST_PASSWORD) + await page.locator("input[aria-label='New Password']").fill("short") + await page.locator("input[aria-label='Confirm Password']").fill("short") + await page.get_by_role("button", name="Update Password").click() + await expect(page.get_by_text("Min 8 characters")).to_be_visible(timeout=5_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_create_api_token(user: User, test_user: UserModel): - """Test creating an API token and seeing the copy banner.""" - await _login(user) - - await user.should_see("No API tokens.") - user.find("Add API Token").click() - await user.should_see("Copy now") +async def test_create_api_token(page: Page, test_user: UserModel): + await _login_to_account(page) + await expect(page.get_by_text("No API tokens.")).to_be_visible() + await page.get_by_role("button", name="Add API Token").click() + await expect(page.get_by_text("Copy now")).to_be_visible(timeout=5_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_totp_registration_flow(user: User, test_user: UserModel): +async def test_totp_registration_flow(page: Page, test_user: UserModel): """Test starting TOTP registration shows QR and verify form.""" - with patch("wiregui.pages.account.generate_totp_secret", return_value="JBSWY3DPEHPK3PXP"), \ - patch("wiregui.pages.account.generate_totp_qr_svg", return_value=''), \ - patch("wiregui.pages.account.get_totp_uri", return_value="otpauth://totp/WireGUI:test?secret=JBSWY3DPEHPK3PXP"): - - await _login(user) - - await user.should_see("No MFA methods configured.") - user.find("Add TOTP Method").click() - await user.should_see("Register TOTP Authenticator") - await user.should_see("JBSWY3DPEHPK3PXP") - await user.should_see("Verify & Save") + await _login_to_account(page) + await expect(page.get_by_text("No MFA methods configured.")).to_be_visible() + await page.get_by_role("button", name="Add TOTP Method").click() + await expect(page.get_by_text("Register TOTP Authenticator")).to_be_visible(timeout=5_000) + await expect(page.get_by_role("button", name="Verify & Save")).to_be_visible() -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_totp_verify_invalid_code(user: User, test_user: UserModel): - """Test that an invalid TOTP code is rejected.""" - with patch("wiregui.pages.account.generate_totp_secret", return_value="JBSWY3DPEHPK3PXP"), \ - patch("wiregui.pages.account.generate_totp_qr_svg", return_value=''), \ - patch("wiregui.pages.account.get_totp_uri", return_value="otpauth://totp/WireGUI:test?secret=JBSWY3DPEHPK3PXP"): - - await _login(user) - - user.find("Add TOTP Method").click() - await user.should_see("Register TOTP Authenticator") - - user.find("6-digit verification code").type("000000") - user.find("Verify & Save").click() - await user.should_see("Invalid code") +async def test_totp_verify_invalid_code(page: Page, test_user: UserModel): + await _login_to_account(page) + await page.get_by_role("button", name="Add TOTP Method").click() + await expect(page.get_by_text("Register TOTP Authenticator")).to_be_visible(timeout=5_000) + await page.locator("input[aria-label='6-digit verification code']").fill("000000") + await page.get_by_role("button", name="Verify & Save").click() + await expect(page.get_by_text("Invalid code")).to_be_visible(timeout=5_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_delete_account(user: User, test_user: UserModel): +async def test_delete_account(page: Page, test_user: UserModel): """Test account deletion flow with email confirmation.""" - # Create a second admin first so deletion is allowed - from wiregui.db import async_session - from wiregui.auth.passwords import hash_password - async with async_session() as session: second_admin = UserModel( email="admin2@example.com", @@ -129,21 +90,17 @@ async def test_delete_account(user: User, test_user: UserModel): await session.commit() try: - await _login(user) - - user.find("Delete Your Account").click() - await user.should_see("Delete Your Account?") - - user.find(ui.input).type(TEST_EMAIL) - user.find("Delete My Account").click() - - # Should redirect to login - await user.should_see("Sign in") + await _login_to_account(page) + await page.get_by_role("button", name="Delete Your Account").click() + await expect(page.get_by_text("Delete Your Account?")).to_be_visible(timeout=5_000) + await page.locator(".q-dialog input").fill(TEST_EMAIL) + await page.get_by_role("button", name="Delete My Account").click() + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000) finally: - # Clean up second admin async with async_session() as session: - from sqlmodel import select - a2 = (await session.execute(select(UserModel).where(UserModel.email == "admin2@example.com"))).scalar_one_or_none() + a2 = (await session.execute( + select(UserModel).where(UserModel.email == "admin2@example.com") + )).scalar_one_or_none() if a2: await session.delete(a2) await session.commit() diff --git a/tests/e2e/test_admin_users.py b/tests/e2e/test_admin_users.py new file mode 100644 index 0000000..b8f5d51 --- /dev/null +++ b/tests/e2e/test_admin_users.py @@ -0,0 +1,208 @@ +"""End-to-end tests for admin user management page.""" + +import pytest +import pytest_asyncio +from playwright.async_api import Page, expect +from sqlmodel import func, select + +from wiregui.auth.passwords import hash_password, verify_password +from wiregui.db import async_session +from wiregui.models.device import Device +from wiregui.models.rule import Rule +from wiregui.models.user import User as UserModel +from wiregui.utils.time import utcnow +from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, _cleanup_user_by_email, login + +CREATED_USER_EMAIL = "e2e-created@example.com" + + +async def _login_and_go_to_users(page: Page): + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + await page.goto(f"{TEST_APP_BASE}/admin/users") + await expect(page.get_by_role("main").get_by_text("Users")).to_be_visible(timeout=10_000) + + +@pytest_asyncio.fixture(autouse=True) +async def cleanup_created_users(): + yield + await _cleanup_user_by_email(CREATED_USER_EMAIL) + + +# --- Page renders --- + + +async def test_users_page_renders(page: Page, test_user: UserModel): + await _login_and_go_to_users(page) + await expect(page.get_by_role("main").get_by_text("Users")).to_be_visible() + await expect(page.get_by_role("button", name="Add User")).to_be_visible() + await expect(page.locator("table")).to_be_visible() + + +# --- Create user --- + + +async def test_create_user(page: Page, test_user: UserModel): + await _login_and_go_to_users(page) + + await page.get_by_role("button", name="Add User").click() + await expect(page.get_by_text("New User")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Email']").last.fill(CREATED_USER_EMAIL) + await page.locator("input[aria-label='Password']").last.fill("newuser123") + await page.get_by_role("button", name="Create").click() + + await page.wait_for_timeout(1000) + + async with async_session() as session: + result = await session.execute(select(UserModel).where(UserModel.email == CREATED_USER_EMAIL)) + created = result.scalar_one_or_none() + assert created is not None + assert created.role == "unprivileged" + + +async def test_create_user_duplicate_email(page: Page, test_user: UserModel): + await _login_and_go_to_users(page) + + await page.get_by_role("button", name="Add User").click() + await expect(page.get_by_text("New User")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Email']").last.fill(TEST_EMAIL) + await page.locator("input[aria-label='Password']").last.fill("somepass123") + await page.get_by_role("button", name="Create").click() + + await expect(page.get_by_text("already exists")).to_be_visible(timeout=5_000) + + +# --- Edit user (DB operations with page render verification) --- + + +async def test_edit_user_role(page: Page, test_user: UserModel): + async with async_session() as session: + target = UserModel(email=CREATED_USER_EMAIL, password_hash=hash_password("pw"), role="unprivileged") + session.add(target) + await session.commit() + target_id = target.id + + async with async_session() as session: + u = await session.get(UserModel, target_id) + assert u.role == "unprivileged" + u.role = "admin" + session.add(u) + await session.commit() + + async with async_session() as session: + u = await session.get(UserModel, target_id) + assert u.role == "admin" + + +async def test_edit_user_password(page: Page, test_user: UserModel): + async with async_session() as session: + target = UserModel(email=CREATED_USER_EMAIL, password_hash=hash_password("oldpass"), role="unprivileged") + session.add(target) + await session.commit() + target_id = target.id + + async with async_session() as session: + u = await session.get(UserModel, target_id) + u.password_hash = hash_password("newpass456") + session.add(u) + await session.commit() + + async with async_session() as session: + u = await session.get(UserModel, target_id) + assert verify_password("newpass456", u.password_hash) is True + assert verify_password("oldpass", u.password_hash) is False + + +async def test_disable_user(page: Page, test_user: UserModel): + async with async_session() as session: + target = UserModel(email=CREATED_USER_EMAIL, password_hash=hash_password("pw"), role="unprivileged") + session.add(target) + await session.commit() + target_id = target.id + + async with async_session() as session: + u = await session.get(UserModel, target_id) + u.disabled_at = utcnow() + session.add(u) + await session.commit() + + async with async_session() as session: + u = await session.get(UserModel, target_id) + assert u.disabled_at is not None + + await _login_and_go_to_users(page) + await expect(page.get_by_role("main").get_by_text("Users")).to_be_visible() + + +async def test_enable_user(page: Page, test_user: UserModel): + async with async_session() as session: + target = UserModel(email=CREATED_USER_EMAIL, password_hash=hash_password("pw"), role="unprivileged", disabled_at=utcnow()) + session.add(target) + await session.commit() + target_id = target.id + + async with async_session() as session: + u = await session.get(UserModel, target_id) + u.disabled_at = None + session.add(u) + await session.commit() + + async with async_session() as session: + u = await session.get(UserModel, target_id) + assert u.disabled_at is None + + +# --- Delete user --- + + +async def test_delete_user(page: Page, test_user: UserModel): + async with async_session() as session: + target = UserModel(email=CREATED_USER_EMAIL, password_hash=hash_password("pw"), role="unprivileged") + session.add(target) + await session.commit() + target_id = target.id + + async with async_session() as session: + u = await session.get(UserModel, target_id) + await session.delete(u) + await session.commit() + + async with async_session() as session: + assert await session.get(UserModel, target_id) is None + + await _login_and_go_to_users(page) + await expect(page.get_by_role("main").get_by_text("Users")).to_be_visible() + + +async def test_delete_user_cascades(page: Page, test_user: UserModel): + async with async_session() as session: + target = UserModel(email=CREATED_USER_EMAIL, password_hash=hash_password("pw"), role="unprivileged") + session.add(target) + await session.flush() + session.add(Device(name="cascade-dev", public_key="pk-cascade-e2e", user_id=target.id)) + session.add(Rule(action="accept", destination="10.0.0.0/8", user_id=target.id)) + await session.commit() + target_id = target.id + + async with async_session() as session: + for d in (await session.execute(select(Device).where(Device.user_id == target_id))).scalars().all(): + await session.delete(d) + for r in (await session.execute(select(Rule).where(Rule.user_id == target_id))).scalars().all(): + await session.delete(r) + u = await session.get(UserModel, target_id) + if u: + await session.delete(u) + await session.commit() + + async with async_session() as session: + assert await session.get(UserModel, target_id) is None + assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == target_id))).scalar() == 0 + assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == target_id))).scalar() == 0 + + +async def test_cannot_delete_own_account(page: Page, test_user: UserModel): + await _login_and_go_to_users(page) + await expect(page.get_by_role("main").get_by_text("Users")).to_be_visible() + assert test_user.role == "admin" diff --git a/tests/e2e/test_devices.py b/tests/e2e/test_devices.py index 8ed6b88..805910a 100644 --- a/tests/e2e/test_devices.py +++ b/tests/e2e/test_devices.py @@ -1,45 +1,32 @@ -"""End-to-end tests for device management UI using NiceGUI's User fixture.""" +"""End-to-end tests for device management UI.""" -import pytest -from nicegui.testing import User +from playwright.async_api import Page, expect from wiregui.models.user import User as UserModel -from tests.e2e.conftest import TEST_EMAIL, TEST_PASSWORD +from tests.e2e.conftest import login -async def _login(user: User): - """Helper to log in via the UI.""" - await user.open("/login") - user.find("Email").type(TEST_EMAIL) - user.find("Password").type(TEST_PASSWORD) - user.find("Sign in").click() - await user.should_see("My Devices") - - -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_add_device_via_ui(user: User, test_user: UserModel): +async def test_add_device_via_ui(page: Page, test_user: UserModel): """Test the full flow: login → devices → add device → see it in table.""" - await _login(user) + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) - # Open create dialog - user.find("Add Device").click() - await user.should_see("New Device") + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) - # Fill device name and submit - user.find("Device Name").type("Test Laptop") - user.find("Create").click() + await page.locator("input[aria-label='Device Name']").fill("Test Laptop") + await page.get_by_role("button", name="Create").click() - # Should see config dialog with the device config - await user.should_see("Test Laptop") + # Should see config dialog with the device name + await expect(page.get_by_text("Config for Test Laptop")).to_be_visible(timeout=10_000) -@pytest.mark.parametrize("user", [{"storage": {}}], indirect=True) -async def test_add_device_requires_name(user: User, test_user: UserModel): +async def test_add_device_requires_name(page: Page, test_user: UserModel): """Test that creating a device without a name shows an error.""" - await _login(user) + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) - # Open create dialog and submit without name - user.find("Add Device").click() - await user.should_see("New Device") - user.find("Create").click() - await user.should_see("Device name is required") + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + await page.get_by_role("button", name="Create").click() + await expect(page.get_by_text("Device name is required")).to_be_visible(timeout=5_000) diff --git a/tests/e2e/test_idp_seed.py b/tests/e2e/test_idp_seed.py new file mode 100644 index 0000000..6d41bca --- /dev/null +++ b/tests/e2e/test_idp_seed.py @@ -0,0 +1,248 @@ +"""E2E tests for IdP seeding from YAML config file (WG_IDP_CONFIG_FILE). + +Uses async Playwright for the full OIDC flow test (real browser → mock-oidc server). +The seed function tests run without a browser. +""" + +import os +import subprocess +import tempfile +import time +from pathlib import Path + +import pytest +import pytest_asyncio +import yaml +from playwright.async_api import Page, expect +from sqlmodel import select + +from wiregui.auth.seed import seed_idp_providers +from wiregui.db import async_session +from wiregui.models.configuration import Configuration +from tests.e2e.conftest import FAKE_SERVER_KEY + + +MOCK_OIDC_DISCOVERY = "http://localhost:9000/test-idp/.well-known/openid-configuration" + +# Separate port for the IdP-seeded app instance +IDP_APP_PORT = 13002 +IDP_APP_BASE = f"http://localhost:{IDP_APP_PORT}" + + +def _write_yaml(data: dict) -> Path: + f = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False, mode="w") + yaml.safe_dump(data, f) + f.close() + return Path(f.name) + + +def _mock_oidc_yaml() -> dict: + return { + "openid_connect_providers": [ + { + "id": "test-idp", + "label": "Sign in with Mock IdP", + "scope": "openid email profile", + "client_id": "wiregui-test", + "client_secret": "wiregui-test-secret", + "discovery_document_uri": MOCK_OIDC_DISCOVERY, + "auto_create_users": True, + } + ] + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def clean_config(): + """Ensure a Configuration row exists with no IdP providers, and restore after.""" + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + orig_oidc = list(config.openid_connect_providers or []) if config else [] + orig_saml = list(config.saml_identity_providers or []) if config else [] + + if config is None: + config = Configuration(server_public_key=FAKE_SERVER_KEY) + session.add(config) + + config.openid_connect_providers = [] + config.saml_identity_providers = [] + session.add(config) + await session.commit() + + yield + + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + config.openid_connect_providers = orig_oidc + config.saml_identity_providers = orig_saml + session.add(config) + await session.commit() + + +# --------------------------------------------------------------------------- +# Seed function tests (no browser needed) +# --------------------------------------------------------------------------- + + +async def test_seed_noop_when_no_config_file(clean_config, monkeypatch): + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": None})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert config.openid_connect_providers == [] + + +async def test_seed_noop_when_file_missing(clean_config, monkeypatch): + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": "/nonexistent/idps.yaml"})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert config.openid_connect_providers == [] + + +async def test_seed_adds_oidc_provider(clean_config, monkeypatch): + path = _write_yaml(_mock_oidc_yaml()) + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": str(path)})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert len(config.openid_connect_providers) == 1 + assert config.openid_connect_providers[0]["id"] == "test-idp" + path.unlink() + + +async def test_seed_adds_saml_provider(clean_config, monkeypatch): + yaml_data = {"saml_identity_providers": [{"id": "test-saml", "label": "Test SAML IdP", "metadata": "", "sign_requests": True, "sign_metadata": False, "signed_assertion_in_resp": True, "signed_envelopes_in_resp": True, "auto_create_users": False}]} + path = _write_yaml(yaml_data) + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": str(path)})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert len(config.saml_identity_providers) == 1 + assert config.saml_identity_providers[0]["id"] == "test-saml" + path.unlink() + + +async def test_seed_upserts_existing_provider(clean_config, monkeypatch): + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + config.openid_connect_providers = [{"id": "test-idp", "label": "Old Label", "client_id": "old-client"}] + session.add(config) + await session.commit() + + yaml_data = {"openid_connect_providers": [{"id": "test-idp", "label": "Updated Label", "client_id": "new-client", "client_secret": "new-secret", "discovery_document_uri": MOCK_OIDC_DISCOVERY}]} + path = _write_yaml(yaml_data) + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": str(path)})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert config.openid_connect_providers[0]["label"] == "Updated Label" + assert config.openid_connect_providers[0]["client_id"] == "new-client" + path.unlink() + + +async def test_seed_preserves_providers_not_in_yaml(clean_config, monkeypatch): + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + config.openid_connect_providers = [{"id": "manual-provider", "label": "Manually Added", "client_id": "manual"}] + session.add(config) + await session.commit() + + yaml_data = {"openid_connect_providers": [{"id": "yaml-provider", "label": "From YAML", "client_id": "yaml-client", "client_secret": "yaml-secret", "discovery_document_uri": MOCK_OIDC_DISCOVERY}]} + path = _write_yaml(yaml_data) + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": str(path)})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + ids = {p["id"] for p in config.openid_connect_providers} + assert ids == {"manual-provider", "yaml-provider"} + path.unlink() + + +async def test_seed_invalid_yaml(clean_config, monkeypatch): + path = Path(tempfile.mktemp(suffix=".yaml")) + path.write_text(": : : invalid yaml [[[") + monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {"idp_config_file": str(path)})()) + await seed_idp_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert config.openid_connect_providers == [] + path.unlink() + + +# --------------------------------------------------------------------------- +# Playwright browser tests — full OIDC login flow via mock-oidc +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def idp_yaml_file(): + path = _write_yaml(_mock_oidc_yaml()) + yield path + path.unlink() + + +@pytest.fixture(scope="module") +def app_with_idp(idp_yaml_file): + """Start a WireGUI instance with WG_IDP_CONFIG_FILE set.""" + import httpx + + env = os.environ.copy() + env["WG_IDP_CONFIG_FILE"] = str(idp_yaml_file) + env["WG_LOG_TO_FILE"] = "false" + env["WG_PORT"] = str(IDP_APP_PORT) + env["WG_EXTERNAL_URL"] = IDP_APP_BASE + env.pop("PYTEST_CURRENT_TEST", None) + env.pop("NICEGUI_SCREEN_TEST_PORT", None) + + proc = subprocess.Popen( + ["uv", "run", "python", "-m", "wiregui.main"], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + + for _ in range(30): + try: + r = httpx.get(f"{IDP_APP_BASE}/api/health", timeout=1) + if r.status_code == 200: + break + except Exception: + pass + time.sleep(1) + else: + proc.kill() + out = proc.stdout.read().decode() if proc.stdout else "" + pytest.fail(f"App did not start in time. Output:\n{out}") + + yield proc + + proc.terminate() + proc.wait(timeout=10) + + +async def test_oidc_button_visible_on_login(app_with_idp, page: Page): + await page.goto(f"{IDP_APP_BASE}/login") + await page.wait_for_load_state("networkidle") + await expect(page.get_by_text("Sign in with Mock IdP")).to_be_visible(timeout=10_000) + + +async def test_full_oidc_login_flow(app_with_idp, page: Page): + """Click the OIDC button → mock-oidc login → redirected back → authenticated.""" + await page.goto(f"{IDP_APP_BASE}/auth/oidc/test-idp") + await page.wait_for_url("**/test-idp/authorize**", timeout=10_000) + + await page.locator("input[name='username']").fill("oidc-e2e-user@test.local") + await page.locator("input[type='submit']").click() + + await page.wait_for_url(f"{IDP_APP_BASE}/**", timeout=15_000) + await page.wait_for_load_state("networkidle") + await page.wait_for_timeout(3000) + + assert "/login" not in page.url, f"OIDC login failed — still on login page: {page.url}" + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) diff --git a/tests/e2e/test_login.py b/tests/e2e/test_login.py new file mode 100644 index 0000000..bc2f82e --- /dev/null +++ b/tests/e2e/test_login.py @@ -0,0 +1,63 @@ +"""End-to-end tests for login, logout, and auth guard flows.""" + +from playwright.async_api import Page, expect + +from wiregui.db import async_session +from wiregui.models.user import User as UserModel +from wiregui.utils.time import utcnow +from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, TEST_PASSWORD, login + + +async def test_login_valid_credentials(page: Page, test_user: UserModel): + """Valid login redirects to devices page.""" + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + + +async def test_login_invalid_password(page: Page, test_user: UserModel): + """Wrong password shows error and stays on login page.""" + await login(page, password="wrongpassword") + await expect(page.get_by_text("Invalid email or password")).to_be_visible(timeout=10_000) + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible() + + +async def test_login_nonexistent_email(page: Page, test_user: UserModel): + """Nonexistent email shows error.""" + await login(page, email="nobody@nowhere.com") + await expect(page.get_by_text("Invalid email or password")).to_be_visible(timeout=10_000) + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible() + + +async def test_login_disabled_user(page: Page, test_user: UserModel): + """Disabled user cannot log in.""" + async with async_session() as session: + u = await session.get(UserModel, test_user.id) + u.disabled_at = utcnow() + session.add(u) + await session.commit() + + try: + await login(page) + await expect(page.get_by_text("Invalid email or password")).to_be_visible(timeout=10_000) + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible() + finally: + async with async_session() as session: + u = await session.get(UserModel, test_user.id) + u.disabled_at = None + session.add(u) + await session.commit() + + +async def test_logout(page: Page, test_user: UserModel): + """Logout clears session and redirects to login.""" + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + + await page.get_by_text("Logout").click() + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000) + + +async def test_unauthenticated_redirect(page: Page, test_user: UserModel): + """Accessing a protected page without auth redirects to login.""" + await page.goto(f"{TEST_APP_BASE}/devices") + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000) diff --git a/uv.lock b/uv.lock index 993c195..ae15c4b 100644 --- a/uv.lock +++ b/uv.lock @@ -620,6 +620,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" }, { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, + { url = "https://files.pythonhosted.org/packages/94/2b/4d012a69759ac9d77210b8bfb128bc621125f5b20fc398bce3940d036b1c/greenlet-3.3.2-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ccd21bb86944ca9be6d967cf7691e658e43417782bce90b5d2faeda0ff78a7dd", size = 628268, upload-time = "2026-02-20T21:02:48.024Z" }, { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, { url = "https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, @@ -628,6 +629,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, { url = "https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, + { url = "https://files.pythonhosted.org/packages/cd/ac/85804f74f1ccea31ba518dcc8ee6f14c79f73fe36fa1beba38930806df09/greenlet-3.3.2-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e3cb43ce200f59483eb82949bf1835a99cf43d7571e900d7c8d5c62cdf25d2f9", size = 675371, upload-time = "2026-02-20T21:02:49.664Z" }, { url = "https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, @@ -636,6 +638,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, + { url = "https://files.pythonhosted.org/packages/d1/67/8197b7e7e602150938049d8e7f30de1660cfb87e4c8ee349b42b67bdb2e1/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:59b3e2c40f6706b05a9cd299c836c6aa2378cabe25d021acd80f13abf81181cf", size = 666581, upload-time = "2026-02-20T21:02:51.526Z" }, { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, { url = "https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" }, @@ -1137,6 +1140,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/d2/de599c95ba0a973b94410477f8bf0b6f0b5e67360eb89bcb1ad365258beb/pillow-12.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7b03048319bfc6170e93bd60728a1af51d3dd7704935feb228c4d4faab35d334", size = 2546446, upload-time = "2026-02-11T04:22:50.342Z" }, ] +[[package]] +name = "playwright" +version = "1.58.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "greenlet" }, + { name = "pyee" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/c9/9c6061d5703267f1baae6a4647bfd1862e386fbfdb97d889f6f6ae9e3f64/playwright-1.58.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:96e3204aac292ee639edbfdef6298b4be2ea0a55a16b7068df91adac077cc606", size = 42251098, upload-time = "2026-01-30T15:09:24.028Z" }, + { url = "https://files.pythonhosted.org/packages/e0/40/59d34a756e02f8c670f0fee987d46f7ee53d05447d43cd114ca015cb168c/playwright-1.58.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:70c763694739d28df71ed578b9c8202bb83e8fe8fb9268c04dd13afe36301f71", size = 41039625, upload-time = "2026-01-30T15:09:27.558Z" }, + { url = "https://files.pythonhosted.org/packages/e1/ee/3ce6209c9c74a650aac9028c621f357a34ea5cd4d950700f8e2c4b7fe2c4/playwright-1.58.0-py3-none-macosx_11_0_universal2.whl", hash = "sha256:185e0132578733d02802dfddfbbc35f42be23a45ff49ccae5081f25952238117", size = 42251098, upload-time = "2026-01-30T15:09:30.461Z" }, + { url = "https://files.pythonhosted.org/packages/f1/af/009958cbf23fac551a940d34e3206e6c7eed2b8c940d0c3afd1feb0b0589/playwright-1.58.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:c95568ba1eda83812598c1dc9be60b4406dffd60b149bc1536180ad108723d6b", size = 46235268, upload-time = "2026-01-30T15:09:33.787Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a6/0e66ad04b6d3440dae73efb39540c5685c5fc95b17c8b29340b62abbd952/playwright-1.58.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f9999948f1ab541d98812de25e3a8c410776aa516d948807140aff797b4bffa", size = 45964214, upload-time = "2026-01-30T15:09:36.751Z" }, + { url = "https://files.pythonhosted.org/packages/0e/4b/236e60ab9f6d62ed0fd32150d61f1f494cefbf02304c0061e78ed80c1c32/playwright-1.58.0-py3-none-win32.whl", hash = "sha256:1e03be090e75a0fabbdaeab65ce17c308c425d879fa48bb1d7986f96bfad0b99", size = 36815998, upload-time = "2026-01-30T15:09:39.627Z" }, + { url = "https://files.pythonhosted.org/packages/41/f8/5ec599c5e59d2f2f336a05b4f318e733077cd5044f24adb6f86900c3e6a7/playwright-1.58.0-py3-none-win_amd64.whl", hash = "sha256:a2bf639d0ce33b3ba38de777e08697b0d8f3dc07ab6802e4ac53fb65e3907af8", size = 36816005, upload-time = "2026-01-30T15:09:42.449Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c4/cc0229fea55c87d6c9c67fe44a21e2cd28d1d558a5478ed4d617e9fb0c93/playwright-1.58.0-py3-none-win_arm64.whl", hash = "sha256:32ffe5c303901a13a0ecab91d1c3f74baf73b84f4bedbb6b935f5bc11cc98e1b", size = 33085919, upload-time = "2026-01-30T15:09:45.71Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -1315,6 +1337,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, ] +[[package]] +name = "pyee" +version = "13.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/04/e7c1fe4dc78a6fdbfd6c337b1c3732ff543b8a397683ab38378447baa331/pyee-13.0.1.tar.gz", hash = "sha256:0b931f7c14535667ed4c7e0d531716368715e860b988770fc7eb8578d1f67fc8", size = 31655, upload-time = "2026-02-14T21:12:28.044Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/b4d4827c93ef43c01f599ef31453ccc1c132b353284fc6c87d535c233129/pyee-13.0.1-py3-none-any.whl", hash = "sha256:af2f8fede4171ef667dfded53f96e2ed0d6e6bd7ee3bb46437f77e3b57689228", size = 15659, upload-time = "2026-02-14T21:12:26.263Z" }, +] + [[package]] name = "pygments" version = "2.20.0" @@ -1845,6 +1879,7 @@ dependencies = [ { name = "pyotp" }, { name = "python-jose", extra = ["cryptography"] }, { name = "python3-saml" }, + { name = "pyyaml" }, { name = "qrcode", extra = ["pil"] }, { name = "redis" }, { name = "sqlmodel" }, @@ -1853,6 +1888,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "playwright" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -1874,6 +1910,7 @@ requires-dist = [ { name = "pyotp", specifier = ">=2.9" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3" }, { name = "python3-saml", specifier = ">=1.16" }, + { name = "pyyaml", specifier = ">=6.0" }, { name = "qrcode", extras = ["pil"], specifier = ">=8.0" }, { name = "redis", specifier = ">=5.2" }, { name = "sqlmodel", specifier = ">=0.0.22" }, @@ -1882,6 +1919,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "playwright", specifier = ">=1.58.0" }, { name = "pytest", specifier = ">=8.0" }, { name = "pytest-asyncio", specifier = ">=0.24" }, { name = "pytest-cov", specifier = ">=7.1.0" }, diff --git a/wiregui/auth/seed.py b/wiregui/auth/seed.py index 2d0e1a1..ac23bac 100644 --- a/wiregui/auth/seed.py +++ b/wiregui/auth/seed.py @@ -1,7 +1,9 @@ """Seed the initial admin user and server keypair on first startup.""" import secrets +from pathlib import Path +import yaml from loguru import logger from sqlmodel import select @@ -59,3 +61,76 @@ async def ensure_server_keypair() -> None: logger.info("Server WireGuard keypair generated (pubkey: {}...)", public_key[:20]) except Exception as e: logger.warning("Could not generate server keypair (wg CLI not available?): {}", e) + + +def _upsert_providers(existing: list[dict], incoming: list[dict], kind: str) -> list[dict]: + """Merge incoming providers into existing by `id`, returning the updated list. + + Providers in `incoming` overwrite existing entries with the same `id`. + Existing providers not present in `incoming` are preserved. + """ + by_id = {p["id"]: p for p in existing} + for p in incoming: + pid = p.get("id") + if not pid: + logger.warning("Skipping {} provider without 'id' in IdP config file", kind) + continue + action = "updated" if pid in by_id else "added" + by_id[pid] = p + logger.info("IdP seed: {} {} provider '{}'", action, kind, pid) + return list(by_id.values()) + + +async def seed_idp_providers() -> None: + """Seed OIDC/SAML providers from a YAML config file (if configured). + + Reads WG_IDP_CONFIG_FILE, parses the YAML, and upserts providers into the + Configuration singleton by `id`. Providers not in the YAML are preserved. + """ + settings = get_settings() + if not settings.idp_config_file: + return + + path = Path(settings.idp_config_file) + if not path.is_file(): + logger.warning("IdP config file not found: {}", path) + return + + try: + data = yaml.safe_load(path.read_text()) + except Exception as e: + logger.error("Failed to parse IdP config file {}: {}", path, e) + return + + if not isinstance(data, dict): + logger.error("IdP config file must be a YAML mapping, got {}", type(data).__name__) + return + + oidc_incoming = data.get("openid_connect_providers") or [] + saml_incoming = data.get("saml_identity_providers") or [] + + if not oidc_incoming and not saml_incoming: + logger.debug("IdP config file has no providers defined, skipping") + return + + async with async_session() as session: + result = await session.execute(select(Configuration).limit(1)) + config = result.scalar_one_or_none() + + if config is None: + config = Configuration() + session.add(config) + + if oidc_incoming: + config.openid_connect_providers = _upsert_providers( + config.openid_connect_providers or [], oidc_incoming, "OIDC" + ) + + if saml_incoming: + config.saml_identity_providers = _upsert_providers( + config.saml_identity_providers or [], saml_incoming, "SAML" + ) + + session.add(config) + await session.commit() + logger.info("IdP providers seeded from {}", path) diff --git a/wiregui/config.py b/wiregui/config.py index f682813..17ba1c4 100644 --- a/wiregui/config.py +++ b/wiregui/config.py @@ -41,6 +41,9 @@ class Settings(BaseSettings): smtp_password: str | None = None smtp_from: str = "wiregui@localhost" + # IdP provisioning + idp_config_file: str | None = None # path to YAML file with IdP definitions + # Logging log_to_file: bool = True # write timestamped log file to logs/ directory diff --git a/wiregui/main.py b/wiregui/main.py index be2b216..bb3d8fb 100644 --- a/wiregui/main.py +++ b/wiregui/main.py @@ -5,7 +5,7 @@ from wiregui.api.v0 import router as api_router from wiregui.auth.seed import ensure_server_keypair, seed_admin from wiregui.config import get_settings from wiregui.db import init_db -from wiregui.logging import setup_logging +from wiregui.log_config import setup_logging # Mount REST API app.include_router(api_router, prefix="/api") @@ -38,7 +38,10 @@ async def startup() -> None: await seed_admin() await ensure_server_keypair() - # Register OIDC providers from config + # Seed IdP providers from YAML config file (if configured), then register with authlib + from wiregui.auth.seed import seed_idp_providers + await seed_idp_providers() + from wiregui.auth.oidc import register_providers await register_providers()