diff --git a/.forgejo/workflows/dev.yml b/.forgejo/workflows/dev.yml
index 5a1a023..86a030f 100644
--- a/.forgejo/workflows/dev.yml
+++ b/.forgejo/workflows/dev.yml
@@ -34,11 +34,18 @@ jobs:
env:
SERVER_PORT: "9000"
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
+ mock-saml:
+ image: kenchan0130/simplesamlphp
+ env:
+ SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata
+ SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback
+ SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/
env:
CI: "true"
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
WG_REDIS_URL: redis://valkey:6379/0
MOCK_OIDC_HOST: mock-oidc
+ MOCK_SAML_HOST: mock-saml
steps:
- name: Install system dependencies and checkout
run: |
diff --git a/.forgejo/workflows/release.yml b/.forgejo/workflows/release.yml
index 187075f..263e1ae 100644
--- a/.forgejo/workflows/release.yml
+++ b/.forgejo/workflows/release.yml
@@ -35,11 +35,18 @@ jobs:
env:
SERVER_PORT: "9000"
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
+ mock-saml:
+ image: kenchan0130/simplesamlphp
+ env:
+ SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata
+ SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback
+ SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/
env:
CI: "true"
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
WG_REDIS_URL: redis://valkey:6379/0
MOCK_OIDC_HOST: mock-oidc
+ MOCK_SAML_HOST: mock-saml
steps:
- name: Install system dependencies and checkout
run: |
diff --git a/TODO.md b/TODO.md
index 340c382..e23431f 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,6 +1,6 @@
# WireGUI — Pending Items
-**Test count: 174 (164 unit + 10 E2E) | Coverage: ~35%**
+**Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)**
---
@@ -11,7 +11,7 @@
Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI.
Source: `/home/stefanob/PycharmProjects/personal/wirezone`
-**Test count: 199 (164 unit + 35 E2E) | Coverage: 35%**
+**Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)**
**Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright)
@@ -23,11 +23,11 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone`
### Testing (partially done)
- [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking)
-- [ ] `wiregui/api/deps.py` — test get_current_api_user with real Bearer header parsing, require_admin rejection
-- [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port
-- [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6
+- [x] `wiregui/api/deps.py` (11 tests) — resolve_bearer_token (valid/expired/invalid/disabled/no-expiry), get_current_api_user (missing header/bad scheme/invalid token/valid token), require_admin (admin/unprivileged)
+- [x] `wiregui/services/wireguard.py` (6 tests) — ensure_interface (exists/creates new), set_private_key, set_listen_port, configure_interface (no config/sets key+port)
+- [x] `wiregui/services/firewall.py` (17 tests) — _nft error/success, _nft_batch error/stdin, add_device_jump_rule (ipv4-only/ipv6-only/no-ips/both), setup_base_tables error handling, masquerade error, peer-to-peer/lan-to-peers policies, get_ruleset fallback
- [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error
-- [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing
+- [x] `wiregui/auth/saml.py` — full SAML flow tested via mock SimpleSAMLphp IdP (e2e)
- [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data
- [ ] E2E tests for admin pages (users, devices, rules, settings)
@@ -37,46 +37,52 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone`
- [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion
- [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard
- [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc
+- [x] `tests/e2e/test_mfa_login.py` (4 tests) — MFA redirect on login, valid TOTP completes login, invalid code error, cancel returns to login
+- [x] `tests/e2e/test_magic_link_page.py` (4 tests) — page renders, success on submit, empty email error, back to login
+- [x] `tests/e2e/test_admin_devices.py` (7 tests) — list all devices, filter by user, create with defaults, create with overrides, edit name/description, delete, config dialog with QR
+- [x] `tests/e2e/test_admin_rules.py` (7 tests) — list rules table, create accept/drop/global rules, edit action/destination, delete rule (all verified in DB)
+- [x] `tests/e2e/test_admin_settings.py` (9 tests) — client defaults save/reload, security toggles (local auth, VPN session, unprivileged), OIDC add/delete, SAML add/delete (all verified in DB)
+- [x] `tests/e2e/test_saml_login.py` (4 tests) — SAML button visible, redirect to IdP, SP metadata endpoint, full SAML login flow via mock SimpleSAMLphp
**E2E tests still needed:**
`tests/e2e/test_login.py` — Login & Auth flows (remaining):
-- [ ] Login with MFA → redirects to /mfa challenge page
-- [ ] MFA challenge: valid TOTP code → completes login
-- [ ] MFA challenge: invalid code → shows error, stays on /mfa
-- [ ] MFA challenge: cancel → returns to /login
-- [ ] Magic link request page renders, shows success on submit
+- [x] Login with MFA → redirects to /mfa challenge page
+- [x] MFA challenge: valid TOTP code → completes login
+- [x] MFA challenge: invalid code → shows error, stays on /mfa
+- [x] MFA challenge: cancel → returns to /login
+- [x] Magic link request page renders, shows success on submit
`tests/e2e/test_admin_devices.py` — Admin Device Management:
-- [ ] List all devices across users
-- [ ] Filter by user → shows only that user's devices
-- [ ] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs)
-- [ ] Create device with defaults → use_default flags all True
-- [ ] Edit device name and description → persists
-- [ ] Edit device config overrides (toggle use_default off, set custom values)
-- [ ] Delete device → removed from table
-- [ ] Config dialog shows valid WireGuard config with real server public key
-- [ ] QR code renders in config dialog
+- [x] List all devices across users
+- [x] Filter by user → shows only that user's devices
+- [x] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs)
+- [x] Create device with defaults → use_default flags all True
+- [x] Edit device name and description → persists
+- [x] Edit device config overrides (toggle use_default off, set custom values)
+- [x] Delete device → removed from table
+- [x] Config dialog shows valid WireGuard config with real server public key
+- [x] QR code renders in config dialog
`tests/e2e/test_admin_rules.py` — Admin Firewall Rules:
-- [ ] List rules → table shows action, destination, protocol, port, user
-- [ ] Create accept rule with CIDR → appears in table
-- [ ] Create drop rule with TCP port range → appears correctly
-- [ ] Create global rule (no user) → shows "Global"
-- [ ] Edit rule action (accept → drop) → persists
-- [ ] Edit rule destination → persists
-- [ ] Delete rule → removed from table
+- [x] List rules → table shows action, destination, protocol, port, user
+- [x] Create accept rule with CIDR → appears in table
+- [x] Create drop rule with TCP port range → appears correctly
+- [x] Create global rule (no user) → shows "Global"
+- [x] Edit rule action (accept → drop) → persists
+- [x] Edit rule destination → persists
+- [x] Delete rule → removed from table
`tests/e2e/test_admin_settings.py` — Admin Settings:
-- [ ] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB
-- [ ] Client defaults: saved values reflected on page reload
-- [ ] Security: toggle local auth → persists
-- [ ] Security: change VPN session duration → persists
-- [ ] Security: toggle unprivileged device management/configuration → persists
-- [ ] OIDC: add provider → appears in table
-- [ ] OIDC: delete provider → removed from table
-- [ ] SAML: add provider → appears in table
-- [ ] SAML: delete provider → removed from table
+- [x] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB
+- [x] Client defaults: saved values reflected on page reload
+- [x] Security: toggle local auth → persists
+- [x] Security: change VPN session duration → persists
+- [x] Security: toggle unprivileged device management/configuration → persists
+- [x] OIDC: add provider → appears in table
+- [x] OIDC: delete provider → removed from table
+- [x] SAML: add provider → appears in table
+- [x] SAML: delete provider → removed from table
`tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics:
- [ ] Page renders WireGuard interface status
diff --git a/compose.yml b/compose.yml
index a26f327..30dd691 100644
--- a/compose.yml
+++ b/compose.yml
@@ -49,6 +49,21 @@ services:
]
}
+ # Test SAML Identity Provider — SimpleSAMLphp as IdP
+ # IdP Metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
+ # Admin UI: http://localhost:8080/simplesaml (admin / secret)
+ # Test users: user1/password, user2/password
+ mock-saml:
+ image: kenchan0130/simplesamlphp
+ ports:
+ - "8080:8080"
+ environment:
+ SIMPLESAMLPHP_SP_ENTITY_ID: "http://localhost:13000/auth/saml/test-saml/metadata"
+ SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: "http://localhost:13000/auth/saml/test-saml/callback"
+ SIMPLESAMLPHP_IDP_BASE_URL: http://localhost:8080/simplesaml/
+ volumes:
+ - ./docker/mock-saml/saml20-sp-remote.php:/var/www/simplesamlphp/metadata/saml20-sp-remote.php:ro
+
volumes:
postgres_data:
valkey_data:
diff --git a/docker/mock-saml/saml20-sp-remote.php b/docker/mock-saml/saml20-sp-remote.php
new file mode 100644
index 0000000..099f8a2
--- /dev/null
+++ b/docker/mock-saml/saml20-sp-remote.php
@@ -0,0 +1,15 @@
+ 'http://localhost:13000/auth/saml/test-saml/callback',
+];
+
+// E2E test instance
+$metadata['http://localhost:13003/auth/saml/test-saml/metadata'] = [
+ 'AssertionConsumerService' => 'http://localhost:13003/auth/saml/test-saml/callback',
+];
\ No newline at end of file
diff --git a/tests/e2e/test_admin_devices.py b/tests/e2e/test_admin_devices.py
new file mode 100644
index 0000000..b44a262
--- /dev/null
+++ b/tests/e2e/test_admin_devices.py
@@ -0,0 +1,239 @@
+"""E2E tests for admin device management page."""
+
+import pytest_asyncio
+from playwright.async_api import Page, expect
+from sqlmodel import select
+
+from wiregui.auth.passwords import hash_password
+from wiregui.db import async_session
+from wiregui.models.device import Device
+from wiregui.models.user import User
+from wiregui.utils.crypto import generate_keypair, generate_preshared_key
+from tests.e2e.conftest import (
+ TEST_APP_BASE,
+ TEST_EMAIL,
+ TEST_PASSWORD,
+ _cleanup_user_by_email,
+ login,
+)
+
+SECOND_USER_EMAIL = "e2e-device-user2@example.com"
+
+
+@pytest_asyncio.fixture
+async def second_user(test_user):
+ """Create a second user with a device for filtering tests."""
+ await _cleanup_user_by_email(SECOND_USER_EMAIL)
+
+ async with async_session() as session:
+ user = User(
+ email=SECOND_USER_EMAIL,
+ password_hash=hash_password("pass12345"),
+ role="unprivileged",
+ )
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ yield user
+
+ await _cleanup_user_by_email(SECOND_USER_EMAIL)
+
+
+@pytest_asyncio.fixture
+async def devices_for_both_users(test_user, second_user):
+ """Create one device per user for table/filter tests."""
+ _, pub1 = generate_keypair()
+ _, pub2 = generate_keypair()
+ psk1 = generate_preshared_key()
+ psk2 = generate_preshared_key()
+
+ async with async_session() as session:
+ d1 = Device(
+ name="admin-laptop",
+ public_key=pub1,
+ preshared_key=psk1,
+ ipv4="10.0.0.10",
+ user_id=test_user.id,
+ )
+ d2 = Device(
+ name="user2-phone",
+ public_key=pub2,
+ preshared_key=psk2,
+ ipv4="10.0.0.11",
+ user_id=second_user.id,
+ )
+ session.add_all([d1, d2])
+ await session.commit()
+
+ yield d1, d2
+
+ # Cleanup handled by user fixture cascade
+
+
+async def _go_to_admin_devices(page: Page):
+ """Login as admin and navigate to admin devices page."""
+ await login(page)
+ await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
+ await page.goto(f"{TEST_APP_BASE}/admin/devices")
+ await expect(page.locator("role=main").get_by_text("All Devices")).to_be_visible(timeout=10_000)
+
+
+async def test_list_all_devices(page: Page, devices_for_both_users):
+ """Admin devices page lists devices from all users."""
+ await _go_to_admin_devices(page)
+ await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
+
+
+async def test_filter_by_user(page: Page, second_user, devices_for_both_users):
+ """Filtering by user shows only that user's devices."""
+ await _go_to_admin_devices(page)
+ await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
+
+ # Filter to second user
+ await page.locator("label:has-text('Filter by User')").click()
+ await page.get_by_role("option", name=SECOND_USER_EMAIL).click()
+ await page.wait_for_timeout(1000)
+
+ await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("admin-laptop")).not_to_be_visible()
+
+ # Filter back to all
+ await page.locator("label:has-text('Filter by User')").click()
+ await page.get_by_role("option", name="All Users").click()
+ await page.wait_for_timeout(1000)
+
+ await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
+
+
+async def test_create_device_with_defaults(page: Page, test_user):
+ """Create device with all defaults — config dialog appears."""
+ await _go_to_admin_devices(page)
+ await page.get_by_role("button", name="Add Device").click()
+ await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
+
+ await page.locator("input[aria-label='Device Name']").fill("default-test-device")
+ await page.get_by_role("button", name="Create").click()
+
+ # Config dialog should appear with WireGuard config
+ await expect(page.get_by_text("Config for default-test-device")).to_be_visible(timeout=10_000)
+ await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
+ await page.get_by_role("button", name="Close").click()
+ await page.wait_for_timeout(500)
+
+ # Device should be in the table
+ await expect(page.get_by_role("cell", name="default-test-device").first).to_be_visible(timeout=5_000)
+
+
+async def test_create_device_with_overrides(page: Page, test_user):
+ """Create device with custom config overrides."""
+ await _go_to_admin_devices(page)
+ await page.get_by_role("button", name="Add Device").click()
+ await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
+
+ await page.locator("input[aria-label='Device Name']").fill("custom-override-dev")
+ await page.locator("input[aria-label='Description (optional)']").fill("Custom overrides test")
+
+ # Toggle off DNS default and set custom — Quasar switches use .q-toggle
+ await page.locator(".q-toggle", has_text="Use default DNS").click()
+ dns_input = page.locator("input[aria-label='DNS Servers']")
+ await dns_input.clear()
+ await dns_input.fill("8.8.8.8, 8.8.4.4")
+
+ # Toggle off MTU default and set custom
+ await page.locator(".q-toggle", has_text="Use default MTU").click()
+ mtu_input = page.locator("input[aria-label='MTU']")
+ await mtu_input.clear()
+ await mtu_input.fill("1400")
+
+ await page.get_by_role("button", name="Create").click()
+
+ await expect(page.get_by_text("Config for custom-override-dev")).to_be_visible(timeout=10_000)
+ await page.get_by_role("button", name="Close").click()
+ await page.wait_for_timeout(500)
+
+ await expect(page.get_by_role("cell", name="custom-override-dev").first).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ result = await session.execute(
+ select(Device).where(Device.name == "custom-override-dev")
+ .order_by(Device.inserted_at.desc()).limit(1)
+ )
+ device = result.scalar_one()
+ assert device.use_default_dns is False
+ assert "8.8.8.8" in device.dns
+ assert device.use_default_mtu is False
+ assert device.mtu == 1400
+
+
+async def test_edit_device_name_and_description(page: Page, devices_for_both_users):
+ """Edit a device name and description via the edit dialog."""
+ await _go_to_admin_devices(page)
+ await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
+
+ # Click edit button on admin-laptop row — Quasar slot buttons with icon
+ row = page.locator("tr", has_text="admin-laptop")
+ await row.locator(".q-btn").first.click()
+
+ await expect(page.get_by_text("Edit Device")).to_be_visible(timeout=5_000)
+
+ name_input = page.locator(".q-dialog input[aria-label='Device Name']")
+ await name_input.clear()
+ await name_input.fill("admin-laptop-renamed")
+
+ desc_input = page.locator(".q-dialog input[aria-label='Description']")
+ await desc_input.clear()
+ await desc_input.fill("Updated description")
+
+ await page.get_by_role("button", name="Save").click()
+ await expect(page.get_by_text("Device updated")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("admin-laptop-renamed")).to_be_visible(timeout=5_000)
+
+
+async def test_delete_device(page: Page, test_user):
+ """Delete a device — removed from table."""
+ _, pub = generate_keypair()
+ async with async_session() as session:
+ d = Device(
+ name="delete-me-device",
+ public_key=pub,
+ preshared_key=generate_preshared_key(),
+ ipv4="10.0.0.99",
+ user_id=test_user.id,
+ )
+ session.add(d)
+ await session.commit()
+
+ await _go_to_admin_devices(page)
+ await expect(page.get_by_role("cell", name="delete-me-device")).to_be_visible(timeout=5_000)
+
+ # Click the delete (second) button in the row
+ row = page.locator("tr", has_text="delete-me-device")
+ await row.locator(".q-btn").nth(1).click()
+
+ await expect(page.get_by_text("Deleted delete-me-device")).to_be_visible(timeout=5_000)
+ await page.wait_for_timeout(1000)
+ await expect(page.get_by_role("cell", name="delete-me-device")).not_to_be_visible()
+
+
+async def test_config_dialog_shows_wg_config(page: Page, test_user):
+ """Config dialog after device creation shows valid WireGuard config."""
+ await _go_to_admin_devices(page)
+ await page.get_by_role("button", name="Add Device").click()
+ await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
+
+ await page.locator("input[aria-label='Device Name']").fill("config-test-device")
+ await page.get_by_role("button", name="Create").click()
+
+ await expect(page.get_by_text("Config for config-test-device")).to_be_visible(timeout=10_000)
+ await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("[Peer]")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("PrivateKey")).to_be_visible()
+ await expect(page.get_by_role("button", name="Download .conf")).to_be_visible()
+
+ # QR code should be rendered
+ await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)
\ No newline at end of file
diff --git a/tests/e2e/test_admin_rules.py b/tests/e2e/test_admin_rules.py
new file mode 100644
index 0000000..6aa5b4e
--- /dev/null
+++ b/tests/e2e/test_admin_rules.py
@@ -0,0 +1,227 @@
+"""E2E tests for admin firewall rules management page."""
+
+from uuid import UUID
+
+import pytest_asyncio
+from playwright.async_api import Page, expect
+from sqlmodel import select
+
+from wiregui.db import async_session
+from wiregui.models.rule import Rule
+from wiregui.models.user import User
+from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, login
+
+
+async def _cleanup_test_rules():
+ """Remove rules created by tests (identified by test-specific destinations)."""
+ async with async_session() as session:
+ result = await session.execute(
+ select(Rule).where(Rule.destination.in_([
+ "10.99.0.0/16", "10.88.0.0/16", "10.77.0.0/16",
+ "10.66.0.0/16", "10.55.0.0/16",
+ ]))
+ )
+ for rule in result.scalars().all():
+ await session.delete(rule)
+ await session.commit()
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def clean_rules(app_server):
+ """Clean up test rules before and after each test."""
+ await _cleanup_test_rules()
+ yield
+ await _cleanup_test_rules()
+
+
+async def _go_to_rules(page: Page):
+ """Login and navigate to admin rules page."""
+ await login(page)
+ await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
+ await page.goto(f"{TEST_APP_BASE}/admin/rules")
+ await expect(page.locator("role=main").get_by_text("Firewall Rules")).to_be_visible(timeout=10_000)
+
+
+async def _create_rule_via_dialog(
+ page: Page, *, action: str = "accept", destination: str = "10.99.0.0/16",
+ protocol: str = "any", port_range: str = "", user: str = "global",
+):
+ """Open create dialog and fill in a rule."""
+ await page.get_by_role("button", name="Add Rule").click()
+ await expect(page.get_by_text("New Firewall Rule")).to_be_visible(timeout=5_000)
+
+ # Action select
+ if action != "accept":
+ await page.locator(".q-dialog label:has-text('Action')").click()
+ await page.get_by_role("option", name=action).click()
+
+ # Destination
+ await page.locator(".q-dialog input[aria-label='Destination (CIDR)']").fill(destination)
+
+ # Protocol
+ if protocol != "any":
+ await page.locator(".q-dialog label:has-text('Protocol')").click()
+ await page.get_by_role("option", name=protocol).click()
+
+ # Port range
+ if port_range:
+ await page.locator(".q-dialog input[aria-label='Port Range']").fill(port_range)
+
+ # User
+ if user != "global":
+ await page.locator(".q-dialog label:has-text('Applies to')").click()
+ await page.get_by_role("option", name=user).click()
+
+ await page.get_by_role("button", name="Create").click()
+ await page.wait_for_timeout(500)
+
+
+async def test_list_rules_table(page: Page, test_user: User):
+ """Rules page renders table with correct columns."""
+ # Seed a rule in DB
+ async with async_session() as session:
+ rule = Rule(action="accept", destination="10.99.0.0/16", port_type="tcp",
+ port_range="443", user_id=test_user.id)
+ session.add(rule)
+ await session.commit()
+
+ await _go_to_rules(page)
+
+ await expect(page.get_by_role("cell", name="accept")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible()
+ await expect(page.get_by_role("cell", name="tcp")).to_be_visible()
+ await expect(page.get_by_role("cell", name="443")).to_be_visible()
+ await expect(page.get_by_role("cell", name=TEST_EMAIL)).to_be_visible()
+
+
+async def test_create_accept_rule_with_cidr(page: Page, test_user: User):
+ """Create an accept rule with CIDR — verify in table and DB."""
+ await _go_to_rules(page)
+ await _create_rule_via_dialog(page, action="accept", destination="10.88.0.0/16")
+
+ await expect(page.get_by_role("cell", name="10.88.0.0/16")).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ result = await session.execute(select(Rule).where(Rule.destination == "10.88.0.0/16"))
+ rule = result.scalar_one()
+ assert rule.action == "accept"
+ assert rule.port_type is None
+ assert rule.port_range is None
+ assert rule.user_id is None
+
+
+async def test_create_drop_rule_with_tcp_port_range(page: Page, test_user: User):
+ """Create a drop rule with TCP port range — verify in table and DB."""
+ await _go_to_rules(page)
+ await _create_rule_via_dialog(
+ page, action="drop", destination="10.77.0.0/16",
+ protocol="tcp", port_range="80-443",
+ )
+
+ await expect(page.get_by_role("cell", name="10.77.0.0/16")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_role("cell", name="drop").first).to_be_visible()
+
+ # Verify in DB
+ async with async_session() as session:
+ result = await session.execute(select(Rule).where(Rule.destination == "10.77.0.0/16"))
+ rule = result.scalar_one()
+ assert rule.action == "drop"
+ assert rule.port_type == "tcp"
+ assert rule.port_range == "80-443"
+
+
+async def test_create_global_rule(page: Page, test_user: User):
+ """Create a global rule (no user) — shows 'Global' in table and DB has null user_id."""
+ await _go_to_rules(page)
+ await _create_rule_via_dialog(page, destination="10.66.0.0/16", user="global")
+
+ await expect(page.get_by_role("cell", name="10.66.0.0/16")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_role("cell", name="Global")).to_be_visible()
+
+ # Verify in DB
+ async with async_session() as session:
+ result = await session.execute(select(Rule).where(Rule.destination == "10.66.0.0/16"))
+ rule = result.scalar_one()
+ assert rule.user_id is None
+
+
+async def test_edit_rule_action(page: Page, test_user: User):
+ """Edit rule action from accept to drop — verify in table and DB."""
+ async with async_session() as session:
+ rule = Rule(action="accept", destination="10.55.0.0/16")
+ session.add(rule)
+ await session.commit()
+ rule_id = rule.id
+
+ await _go_to_rules(page)
+ await expect(page.get_by_role("cell", name="10.55.0.0/16")).to_be_visible(timeout=5_000)
+
+ # Click edit (first button in the row)
+ row = page.locator("tr", has_text="10.55.0.0/16")
+ await row.locator(".q-btn").first.click()
+ await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
+
+ # Change action to drop
+ await page.locator(".q-dialog label:has-text('Action')").click()
+ await page.get_by_role("option", name="drop").click()
+
+ await page.get_by_role("button", name="Save").click()
+ await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ rule = await session.get(Rule, rule_id)
+ assert rule.action == "drop"
+
+
+async def test_edit_rule_destination(page: Page, test_user: User):
+ """Edit rule destination — verify in table and DB."""
+ async with async_session() as session:
+ rule = Rule(action="accept", destination="10.99.0.0/16")
+ session.add(rule)
+ await session.commit()
+ rule_id = rule.id
+
+ await _go_to_rules(page)
+ await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
+
+ row = page.locator("tr", has_text="10.99.0.0/16")
+ await row.locator(".q-btn").first.click()
+ await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
+
+ dest_input = page.locator(".q-dialog input[aria-label='Destination (CIDR)']")
+ await dest_input.clear()
+ await dest_input.fill("10.88.0.0/16")
+
+ await page.get_by_role("button", name="Save").click()
+ await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ rule = await session.get(Rule, rule_id)
+ assert rule.destination == "10.88.0.0/16"
+
+
+async def test_delete_rule(page: Page, test_user: User):
+ """Delete a rule — removed from table and DB."""
+ async with async_session() as session:
+ rule = Rule(action="accept", destination="10.99.0.0/16")
+ session.add(rule)
+ await session.commit()
+ rule_id = rule.id
+
+ await _go_to_rules(page)
+ await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
+
+ # Click delete (second button in the row)
+ row = page.locator("tr", has_text="10.99.0.0/16")
+ await row.locator(".q-btn").nth(1).click()
+ await page.wait_for_timeout(1000)
+
+ await expect(page.get_by_role("cell", name="10.99.0.0/16")).not_to_be_visible()
+
+ # Verify in DB
+ async with async_session() as session:
+ rule = await session.get(Rule, rule_id)
+ assert rule is None
\ No newline at end of file
diff --git a/tests/e2e/test_admin_settings.py b/tests/e2e/test_admin_settings.py
new file mode 100644
index 0000000..bae28e6
--- /dev/null
+++ b/tests/e2e/test_admin_settings.py
@@ -0,0 +1,281 @@
+"""E2E tests for admin settings page — client defaults, security, OIDC/SAML providers."""
+
+import pytest_asyncio
+from playwright.async_api import Page, expect
+from sqlmodel import select
+
+from wiregui.db import async_session
+from wiregui.models.configuration import Configuration
+from wiregui.models.user import User
+from tests.e2e.conftest import TEST_APP_BASE, login
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def reset_config(app_server):
+ """Snapshot config before test, restore after."""
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
+ if not c:
+ yield
+ return
+ snap = {
+ "default_client_endpoint": c.default_client_endpoint,
+ "default_client_dns": list(c.default_client_dns),
+ "default_client_mtu": c.default_client_mtu,
+ "default_client_persistent_keepalive": c.default_client_persistent_keepalive,
+ "default_client_allowed_ips": list(c.default_client_allowed_ips),
+ "vpn_session_duration": c.vpn_session_duration,
+ "local_auth_enabled": c.local_auth_enabled,
+ "allow_unprivileged_device_management": c.allow_unprivileged_device_management,
+ "allow_unprivileged_device_configuration": c.allow_unprivileged_device_configuration,
+ "openid_connect_providers": list(c.openid_connect_providers or []),
+ "saml_identity_providers": list(c.saml_identity_providers or []),
+ }
+ cid = c.id
+
+ yield
+
+ async with async_session() as session:
+ c = await session.get(Configuration, cid)
+ if c:
+ for k, v in snap.items():
+ setattr(c, k, v)
+ session.add(c)
+ await session.commit()
+
+
+async def _go_to_settings(page: Page):
+ await login(page)
+ await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
+ await page.goto(f"{TEST_APP_BASE}/admin/settings")
+ await expect(page.get_by_text("Default Client Configuration")).to_be_visible(timeout=10_000)
+
+
+# --- Client Defaults ---
+
+
+async def test_save_client_defaults(page: Page, test_user: User):
+ """Save endpoint, DNS, MTU, keepalive, allowed IPs — verify persists in DB."""
+ await _go_to_settings(page)
+
+ endpoint = page.locator("input[aria-label='Endpoint']")
+ await endpoint.clear()
+ await endpoint.fill("vpn.test.local")
+
+ dns = page.locator("input[aria-label='DNS Servers']")
+ await dns.clear()
+ await dns.fill("9.9.9.9, 149.112.112.112")
+
+ mtu = page.locator("input[aria-label='MTU']")
+ await mtu.clear()
+ await mtu.fill("1420")
+
+ keepalive = page.locator("input[aria-label='Persistent Keepalive']")
+ await keepalive.clear()
+ await keepalive.fill("30")
+
+ allowed = page.locator("input[aria-label='Allowed IPs']")
+ await allowed.clear()
+ await allowed.fill("10.0.0.0/8, 192.168.0.0/16")
+
+ await page.get_by_role("button", name="Save Defaults").click()
+ await expect(page.get_by_text("Client defaults saved")).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ assert c.default_client_endpoint == "vpn.test.local"
+ assert c.default_client_dns == ["9.9.9.9", "149.112.112.112"]
+ assert c.default_client_mtu == 1420
+ assert c.default_client_persistent_keepalive == 30
+ assert c.default_client_allowed_ips == ["10.0.0.0/8", "192.168.0.0/16"]
+
+
+async def test_client_defaults_persist_on_reload(page: Page, test_user: User):
+ """Saved defaults are reflected after page reload."""
+ # Set values via DB
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ c.default_client_endpoint = "reload-test.vpn"
+ c.default_client_dns = ["8.8.8.8"]
+ c.default_client_mtu = 1500
+ c.default_client_persistent_keepalive = 15
+ c.default_client_allowed_ips = ["172.16.0.0/12"]
+ session.add(c)
+ await session.commit()
+
+ await _go_to_settings(page)
+
+ await expect(page.locator("input[aria-label='Endpoint']")).to_have_value("reload-test.vpn")
+ await expect(page.locator("input[aria-label='DNS Servers']")).to_have_value("8.8.8.8")
+ await expect(page.locator("input[aria-label='MTU']")).to_have_value("1500")
+ await expect(page.locator("input[aria-label='Persistent Keepalive']")).to_have_value("15")
+ await expect(page.locator("input[aria-label='Allowed IPs']")).to_have_value("172.16.0.0/12")
+
+
+# --- Security ---
+
+
+async def test_save_security_local_auth_toggle(page: Page, test_user: User):
+ """Toggle local auth off — verify in DB."""
+ await _go_to_settings(page)
+
+ # Find the local auth switch and toggle it off
+ switch = page.locator(".q-toggle", has_text="Local Authentication")
+ await switch.click()
+
+ await page.get_by_role("button", name="Save Security Settings").click()
+ await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
+
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ assert c.local_auth_enabled is False
+
+
+async def test_save_vpn_session_duration(page: Page, test_user: User):
+ """Change VPN session duration — verify in DB."""
+ await _go_to_settings(page)
+
+ await page.locator("label:has-text('VPN Session Duration')").click()
+ await page.get_by_role("option", name="Every Day").click()
+
+ await page.get_by_role("button", name="Save Security Settings").click()
+ await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
+
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ assert c.vpn_session_duration == 86400
+
+
+async def test_save_unprivileged_toggles(page: Page, test_user: User):
+ """Toggle unprivileged device management/configuration — verify in DB."""
+ await _go_to_settings(page)
+
+ await page.locator(".q-toggle", has_text="Allow Unprivileged Device Management").click()
+ await page.locator(".q-toggle", has_text="Allow Unprivileged Device Configuration").click()
+
+ await page.get_by_role("button", name="Save Security Settings").click()
+ await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
+
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ # Toggled from default (True) to False
+ assert c.allow_unprivileged_device_management is False
+ assert c.allow_unprivileged_device_configuration is False
+
+
+# --- OIDC Providers ---
+
+
+async def test_add_oidc_provider(page: Page, test_user: User):
+ """Add an OIDC provider — appears in table and DB."""
+ await _go_to_settings(page)
+
+ await page.get_by_role("button", name="Add OIDC Provider").click()
+ await expect(page.get_by_text("OIDC Provider", exact=True)).to_be_visible(timeout=5_000)
+
+ await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-oidc")
+ await page.locator(".q-dialog input[aria-label='Label']").fill("E2E Test IdP")
+ await page.locator(".q-dialog input[aria-label='Client ID']").fill("test-client-id")
+ await page.locator(".q-dialog input[aria-label='Client Secret']").fill("test-client-secret")
+ await page.locator(".q-dialog input[aria-label='Discovery Document URI']").fill("https://idp.test/.well-known/openid-configuration")
+
+ await page.locator(".q-dialog").get_by_role("button", name="Save").click()
+ await expect(page.get_by_text("OIDC provider 'E2E Test IdP' saved")).to_be_visible(timeout=5_000)
+
+ await expect(page.get_by_role("cell", name="e2e-test-oidc")).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ provider = next((p for p in c.openid_connect_providers if p["id"] == "e2e-test-oidc"), None)
+ assert provider is not None
+ assert provider["label"] == "E2E Test IdP"
+ assert provider["client_id"] == "test-client-id"
+
+
+async def test_delete_oidc_provider(page: Page, test_user: User):
+ """Delete an OIDC provider — removed from table and DB."""
+ # Seed a provider
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ providers = list(c.openid_connect_providers or [])
+ providers.append({
+ "id": "delete-me-oidc", "label": "Delete Me", "scope": "openid",
+ "client_id": "x", "client_secret": "x",
+ "discovery_document_uri": "https://x/.well-known/openid-configuration",
+ })
+ c.openid_connect_providers = providers
+ session.add(c)
+ await session.commit()
+
+ await _go_to_settings(page)
+ await expect(page.get_by_role("cell", name="delete-me-oidc")).to_be_visible(timeout=5_000)
+
+ row = page.locator("tr", has_text="delete-me-oidc")
+ await row.locator(".q-btn").first.click()
+
+ await expect(page.get_by_text("OIDC provider deleted")).to_be_visible(timeout=5_000)
+ await page.wait_for_timeout(500)
+ await expect(page.get_by_role("cell", name="delete-me-oidc")).not_to_be_visible()
+
+ # Verify in DB
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ assert not any(p["id"] == "delete-me-oidc" for p in c.openid_connect_providers)
+
+
+# --- SAML Providers ---
+
+
+async def test_add_saml_provider(page: Page, test_user: User):
+ """Add a SAML provider — appears in table and DB."""
+ await _go_to_settings(page)
+
+ await page.get_by_role("button", name="Add SAML Provider").click()
+ await expect(page.get_by_text("SAML Identity Provider", exact=True)).to_be_visible(timeout=5_000)
+
+ await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-saml")
+ await page.locator(".q-dialog input[aria-label='Label']").fill("E2E SAML IdP")
+ await page.locator(".q-dialog textarea").fill("test")
+
+ await page.locator(".q-dialog").get_by_role("button", name="Save").click()
+ await expect(page.get_by_text("SAML provider 'E2E SAML IdP' saved")).to_be_visible(timeout=5_000)
+
+ await expect(page.get_by_role("cell", name="e2e-test-saml")).to_be_visible(timeout=5_000)
+
+ # Verify in DB
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ provider = next((p for p in c.saml_identity_providers if p["id"] == "e2e-test-saml"), None)
+ assert provider is not None
+ assert provider["label"] == "E2E SAML IdP"
+
+
+async def test_delete_saml_provider(page: Page, test_user: User):
+ """Delete a SAML provider — removed from table and DB."""
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ providers = list(c.saml_identity_providers or [])
+ providers.append({
+ "id": "delete-me-saml", "label": "Delete Me SAML",
+ "metadata": "",
+ })
+ c.saml_identity_providers = providers
+ session.add(c)
+ await session.commit()
+
+ await _go_to_settings(page)
+ await expect(page.get_by_role("cell", name="delete-me-saml")).to_be_visible(timeout=5_000)
+
+ row = page.locator("tr", has_text="delete-me-saml")
+ await row.locator(".q-btn").first.click()
+
+ await expect(page.get_by_text("SAML provider deleted")).to_be_visible(timeout=5_000)
+ await page.wait_for_timeout(500)
+ await expect(page.get_by_role("cell", name="delete-me-saml")).not_to_be_visible()
+
+ # Verify in DB
+ async with async_session() as session:
+ c = (await session.execute(select(Configuration).limit(1))).scalar_one()
+ assert not any(p["id"] == "delete-me-saml" for p in c.saml_identity_providers)
\ No newline at end of file
diff --git a/tests/e2e/test_magic_link_page.py b/tests/e2e/test_magic_link_page.py
new file mode 100644
index 0000000..c4676ec
--- /dev/null
+++ b/tests/e2e/test_magic_link_page.py
@@ -0,0 +1,41 @@
+"""E2E tests for magic link request page."""
+
+from playwright.async_api import Page, expect
+
+from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL
+from wiregui.models.user import User
+
+
+async def test_magic_link_page_renders(page: Page, test_user: User):
+ """Magic link request page renders with email input and submit button."""
+ await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
+ await page.wait_for_load_state("networkidle")
+ await expect(page.get_by_text("Sign in with magic link")).to_be_visible(timeout=10_000)
+ await expect(page.locator("input[aria-label='Email']")).to_be_visible()
+ await expect(page.get_by_role("button", name="Send Magic Link")).to_be_visible()
+ await expect(page.get_by_role("button", name="Back to login")).to_be_visible()
+
+
+async def test_magic_link_shows_success_on_submit(page: Page, test_user: User):
+ """Submitting an email shows success message (regardless of whether account exists)."""
+ await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
+ await page.wait_for_load_state("networkidle")
+ await page.locator("input[aria-label='Email']").fill(TEST_EMAIL)
+ await page.get_by_role("button", name="Send Magic Link").click()
+ await expect(page.get_by_text("a sign-in link has been sent")).to_be_visible(timeout=5_000)
+
+
+async def test_magic_link_empty_email_shows_error(page: Page, test_user: User):
+ """Submitting without email shows error."""
+ await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
+ await page.wait_for_load_state("networkidle")
+ await page.get_by_role("button", name="Send Magic Link").click()
+ await expect(page.get_by_text("Enter your email")).to_be_visible(timeout=5_000)
+
+
+async def test_magic_link_back_to_login(page: Page, test_user: User):
+ """Back to login button navigates to login page."""
+ await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
+ await page.wait_for_load_state("networkidle")
+ await page.get_by_role("button", name="Back to login").click()
+ await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)
\ No newline at end of file
diff --git a/tests/e2e/test_mfa_login.py b/tests/e2e/test_mfa_login.py
new file mode 100644
index 0000000..0de2e99
--- /dev/null
+++ b/tests/e2e/test_mfa_login.py
@@ -0,0 +1,111 @@
+"""E2E tests for MFA login flow — login with TOTP redirects to /mfa challenge page."""
+
+import pyotp
+import pytest_asyncio
+from playwright.async_api import Page, expect
+
+from wiregui.auth.mfa import generate_totp_secret
+from wiregui.auth.passwords import hash_password
+from wiregui.db import async_session
+from wiregui.models.mfa_method import MFAMethod
+from wiregui.models.user import User
+from tests.e2e.conftest import (
+ FAKE_SERVER_KEY,
+ TEST_APP_BASE,
+ TEST_PASSWORD,
+ _cleanup_user_by_email,
+)
+
+MFA_EMAIL = "e2e-mfa@example.com"
+MFA_PASSWORD = "mfapass123"
+TOTP_SECRET = generate_totp_secret()
+
+
+@pytest_asyncio.fixture
+async def mfa_user(app_server):
+ """Create a user with a TOTP MFA method, clean up after."""
+ await _cleanup_user_by_email(MFA_EMAIL)
+
+ async with async_session() as session:
+ from sqlmodel import select
+ from wiregui.models.configuration import Configuration
+
+ config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
+ if config:
+ if not config.server_public_key:
+ config.server_public_key = FAKE_SERVER_KEY
+ session.add(config)
+ else:
+ config = Configuration(server_public_key=FAKE_SERVER_KEY)
+ session.add(config)
+
+ user = User(
+ email=MFA_EMAIL,
+ password_hash=hash_password(MFA_PASSWORD),
+ role="admin",
+ )
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ mfa = MFAMethod(
+ name="Test TOTP",
+ type="totp",
+ payload={"secret": TOTP_SECRET},
+ user_id=user.id,
+ )
+ session.add(mfa)
+ await session.commit()
+
+ yield user
+
+ await _cleanup_user_by_email(MFA_EMAIL)
+
+
+async def _login_mfa_user(page: Page):
+ """Fill login form for the MFA user and submit."""
+ await page.goto(f"{TEST_APP_BASE}/login")
+ await page.wait_for_load_state("networkidle")
+ await page.locator("input[aria-label='Email']").fill(MFA_EMAIL)
+ await page.locator("input[aria-label='Password']").fill(MFA_PASSWORD)
+ await page.get_by_role("button", name="Sign in", exact=True).click()
+
+
+async def test_mfa_login_redirects_to_challenge(page: Page, mfa_user: User):
+ """Login with MFA-enabled user redirects to /mfa challenge page."""
+ await _login_mfa_user(page)
+ await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
+ await expect(page.locator("input[aria-label='Authentication Code']")).to_be_visible()
+
+
+async def test_mfa_valid_totp_completes_login(page: Page, mfa_user: User):
+ """Entering a valid TOTP code on /mfa completes login."""
+ await _login_mfa_user(page)
+ await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
+
+ code = pyotp.TOTP(TOTP_SECRET).now()
+ await page.locator("input[aria-label='Authentication Code']").fill(code)
+ await page.get_by_role("button", name="Verify").click()
+
+ await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
+
+
+async def test_mfa_invalid_code_shows_error(page: Page, mfa_user: User):
+ """Entering an invalid TOTP code shows error and stays on /mfa."""
+ await _login_mfa_user(page)
+ await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
+
+ await page.locator("input[aria-label='Authentication Code']").fill("000000")
+ await page.get_by_role("button", name="Verify").click()
+
+ await expect(page.get_by_text("Invalid code")).to_be_visible(timeout=5_000)
+ await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible()
+
+
+async def test_mfa_cancel_returns_to_login(page: Page, mfa_user: User):
+ """Clicking Cancel on /mfa clears session and returns to login."""
+ await _login_mfa_user(page)
+ await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
+
+ await page.get_by_role("button", name="Cancel").click()
+ await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)
\ No newline at end of file
diff --git a/tests/e2e/test_saml_login.py b/tests/e2e/test_saml_login.py
new file mode 100644
index 0000000..2942750
--- /dev/null
+++ b/tests/e2e/test_saml_login.py
@@ -0,0 +1,177 @@
+"""E2E tests for SAML authentication — mock SimpleSAMLphp IdP.
+
+Requires mock-saml service running (docker compose up -d mock-saml).
+IdP metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
+Test users: user1/user1pass, user2/user2pass
+"""
+
+import os
+import subprocess
+import time
+
+import httpx
+import pytest
+import pytest_asyncio
+from playwright.async_api import Page, expect
+from sqlmodel import select
+
+from wiregui.db import async_session
+from wiregui.models.configuration import Configuration
+from wiregui.models.user import User
+from tests.e2e.conftest import FAKE_SERVER_KEY, _cleanup_user_by_email
+
+MOCK_SAML_HOST = os.environ.get("MOCK_SAML_HOST", "localhost")
+MOCK_SAML_METADATA_URL = f"http://{MOCK_SAML_HOST}:8080/simplesaml/saml2/idp/metadata.php"
+
+# Separate app port for SAML tests (like OIDC IdP tests)
+SAML_APP_PORT = 13003
+SAML_APP_BASE = f"http://localhost:{SAML_APP_PORT}"
+
+SAML_TEST_EMAIL = "user1@example.com"
+
+
+def _fetch_idp_metadata() -> str:
+ """Fetch IdP metadata XML from the mock SAML server."""
+ try:
+ r = httpx.get(MOCK_SAML_METADATA_URL, timeout=5)
+ r.raise_for_status()
+ return r.text
+ except Exception:
+ pytest.skip(f"Mock SAML IdP not available at {MOCK_SAML_METADATA_URL}")
+
+
+def _saml_provider_config(metadata: str) -> dict:
+ return {
+ "id": "test-saml",
+ "label": "Sign in with Mock SAML",
+ "metadata": metadata,
+ "sign_requests": False,
+ "sign_metadata": False,
+ "signed_assertion_in_resp": False,
+ "signed_envelopes_in_resp": False,
+ "auto_create_users": True,
+ "strict": False, # Relaxed for test IdP with expired certs
+ }
+
+
+@pytest_asyncio.fixture(scope="module")
+async def saml_metadata():
+ return _fetch_idp_metadata()
+
+
+@pytest.fixture(scope="module")
+def app_with_saml(saml_metadata):
+ """Start a WireGUI instance with a SAML provider seeded in the DB."""
+ import asyncio
+
+ # Seed the SAML provider config into the database
+ async def _seed():
+ async with async_session() as session:
+ config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
+ if config is None:
+ config = Configuration(server_public_key=FAKE_SERVER_KEY)
+ session.add(config)
+ await session.flush()
+
+ providers = list(config.saml_identity_providers or [])
+ providers = [p for p in providers if p.get("id") != "test-saml"]
+ providers.append(_saml_provider_config(saml_metadata))
+ config.saml_identity_providers = providers
+ session.add(config)
+ await session.commit()
+
+ asyncio.get_event_loop().run_until_complete(_seed())
+
+ env = os.environ.copy()
+ env["WG_LOG_TO_FILE"] = "false"
+ env["WG_PORT"] = str(SAML_APP_PORT)
+ env["WG_EXTERNAL_URL"] = SAML_APP_BASE
+ env.pop("PYTEST_CURRENT_TEST", None)
+ env.pop("NICEGUI_SCREEN_TEST_PORT", None)
+
+ proc = subprocess.Popen(
+ ["uv", "run", "python", "-m", "wiregui.main"],
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ )
+
+ for _ in range(30):
+ try:
+ r = httpx.get(f"{SAML_APP_BASE}/api/health", timeout=1)
+ if r.status_code == 200:
+ break
+ except Exception:
+ pass
+ time.sleep(1)
+ else:
+ proc.kill()
+ out = proc.stdout.read().decode() if proc.stdout else ""
+ pytest.fail(f"App did not start in time. Output:\n{out}")
+
+ yield proc
+
+ proc.terminate()
+ proc.wait(timeout=10)
+
+ # Clean up seeded provider and test user
+ async def _cleanup():
+ await _cleanup_user_by_email(SAML_TEST_EMAIL)
+ async with async_session() as session:
+ config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
+ if config:
+ config.saml_identity_providers = [
+ p for p in (config.saml_identity_providers or []) if p.get("id") != "test-saml"
+ ]
+ session.add(config)
+ await session.commit()
+
+ asyncio.get_event_loop().run_until_complete(_cleanup())
+
+
+async def test_saml_button_visible_on_login(app_with_saml, page: Page):
+ """Login page shows SAML provider button."""
+ await page.goto(f"{SAML_APP_BASE}/login")
+ await page.wait_for_load_state("networkidle")
+ await expect(page.get_by_text("Sign in with Mock SAML")).to_be_visible(timeout=10_000)
+
+
+async def test_saml_redirect_to_idp(app_with_saml, page: Page):
+ """Clicking SAML login redirects to the SimpleSAMLphp IdP login page."""
+ await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
+ # Should redirect to the SimpleSAMLphp SSO service
+ await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
+
+
+async def test_saml_sp_metadata_endpoint(app_with_saml, page: Page):
+ """SP metadata endpoint returns valid XML."""
+ response = await page.request.get(f"{SAML_APP_BASE}/auth/saml/test-saml/metadata")
+ assert response.status == 200
+ body = await response.text()
+ assert "EntityDescriptor" in body
+ assert "AssertionConsumerService" in body
+
+
+async def test_full_saml_login_flow(app_with_saml, page: Page):
+ """Full SAML SSO flow: app → IdP login → callback → authenticated."""
+ await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
+ await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
+
+ # SimpleSAMLphp login form
+ await page.locator("input[name='username']").fill("user1")
+ await page.locator("input[name='password']").fill("password")
+ await page.locator("button[type='submit'], input[type='submit']").first.click()
+
+ # Should redirect back to the app after SAML response
+ await page.wait_for_url(f"{SAML_APP_BASE}/**", timeout=15_000)
+ await page.wait_for_load_state("networkidle")
+ await page.wait_for_timeout(3000)
+
+ assert "/login" not in page.url, f"SAML login failed — still on login page: {page.url}"
+
+ # Verify user was auto-created in DB
+ async with async_session() as session:
+ result = await session.execute(select(User).where(User.email == SAML_TEST_EMAIL))
+ user = result.scalar_one_or_none()
+ assert user is not None, f"Expected user {SAML_TEST_EMAIL} to be auto-created"
+ assert user.last_signed_in_method == "saml:test-saml"
\ No newline at end of file
diff --git a/tests/test_api_deps.py b/tests/test_api_deps.py
new file mode 100644
index 0000000..64d8a32
--- /dev/null
+++ b/tests/test_api_deps.py
@@ -0,0 +1,263 @@
+"""Tests for API dependency injection — Bearer token auth and admin guard."""
+
+import hashlib
+from datetime import timedelta
+from uuid import uuid4
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from wiregui.auth.api_token import generate_api_token
+from wiregui.auth.passwords import hash_password
+from wiregui.db import async_session
+from wiregui.models.api_token import ApiToken
+from wiregui.models.user import User
+from wiregui.utils.time import utcnow
+
+
+# ========== resolve_bearer_token ==========
+
+
+async def test_resolve_valid_token():
+ """Valid, non-expired token resolves to user."""
+ from wiregui.auth.api_token import resolve_bearer_token
+
+ plaintext, token_hash = generate_api_token()
+
+ async with async_session() as session:
+ user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ api_token = ApiToken(
+ token_hash=token_hash,
+ user_id=user.id,
+ expires_at=utcnow() + timedelta(hours=1),
+ )
+ session.add(api_token)
+ await session.commit()
+
+ try:
+ async with async_session() as session:
+ resolved = await resolve_bearer_token(session, plaintext)
+ assert resolved is not None
+ assert resolved.id == user.id
+ assert resolved.email == "api-test@test.com"
+ finally:
+ async with async_session() as session:
+ await session.delete(await session.get(ApiToken, api_token.id))
+ await session.delete(await session.get(User, user.id))
+ await session.commit()
+
+
+async def test_resolve_expired_token():
+ """Expired token returns None."""
+ from wiregui.auth.api_token import resolve_bearer_token
+
+ plaintext, token_hash = generate_api_token()
+
+ async with async_session() as session:
+ user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ api_token = ApiToken(
+ token_hash=token_hash,
+ user_id=user.id,
+ expires_at=utcnow() - timedelta(hours=1), # already expired
+ )
+ session.add(api_token)
+ await session.commit()
+
+ try:
+ async with async_session() as session:
+ resolved = await resolve_bearer_token(session, plaintext)
+ assert resolved is None
+ finally:
+ async with async_session() as session:
+ await session.delete(await session.get(ApiToken, api_token.id))
+ await session.delete(await session.get(User, user.id))
+ await session.commit()
+
+
+async def test_resolve_invalid_token():
+ """Nonexistent token returns None."""
+ from wiregui.auth.api_token import resolve_bearer_token
+
+ async with async_session() as session:
+ resolved = await resolve_bearer_token(session, "totally-bogus-token")
+ assert resolved is None
+
+
+async def test_resolve_token_disabled_user():
+ """Token for disabled user returns None."""
+ from wiregui.auth.api_token import resolve_bearer_token
+
+ plaintext, token_hash = generate_api_token()
+
+ async with async_session() as session:
+ user = User(
+ email="api-disabled@test.com", password_hash=hash_password("x"),
+ role="admin", disabled_at=utcnow(),
+ )
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ api_token = ApiToken(
+ token_hash=token_hash,
+ user_id=user.id,
+ expires_at=utcnow() + timedelta(hours=1),
+ )
+ session.add(api_token)
+ await session.commit()
+
+ try:
+ async with async_session() as session:
+ resolved = await resolve_bearer_token(session, plaintext)
+ assert resolved is None
+ finally:
+ async with async_session() as session:
+ await session.delete(await session.get(ApiToken, api_token.id))
+ await session.delete(await session.get(User, user.id))
+ await session.commit()
+
+
+async def test_resolve_token_no_expiry():
+ """Token without expires_at (never expires) resolves successfully."""
+ from wiregui.auth.api_token import resolve_bearer_token
+
+ plaintext, token_hash = generate_api_token()
+
+ async with async_session() as session:
+ user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ api_token = ApiToken(
+ token_hash=token_hash,
+ user_id=user.id,
+ expires_at=None,
+ )
+ session.add(api_token)
+ await session.commit()
+
+ try:
+ async with async_session() as session:
+ resolved = await resolve_bearer_token(session, plaintext)
+ assert resolved is not None
+ assert resolved.id == user.id
+ finally:
+ async with async_session() as session:
+ await session.delete(await session.get(ApiToken, api_token.id))
+ await session.delete(await session.get(User, user.id))
+ await session.commit()
+
+
+# ========== get_current_api_user (via FastAPI deps) ==========
+
+
+async def test_get_current_api_user_missing_header():
+ """Missing Authorization header raises 401."""
+ from fastapi import HTTPException
+ from wiregui.api.deps import get_current_api_user
+
+ request = MagicMock()
+ request.headers = {}
+
+ with pytest.raises(HTTPException) as exc_info:
+ await get_current_api_user(request, session=AsyncMock())
+ assert exc_info.value.status_code == 401
+ assert "Missing" in exc_info.value.detail
+
+
+async def test_get_current_api_user_bad_scheme():
+ """Non-Bearer auth scheme raises 401."""
+ from fastapi import HTTPException
+ from wiregui.api.deps import get_current_api_user
+
+ request = MagicMock()
+ request.headers = {"Authorization": "Basic dXNlcjpwYXNz"}
+
+ with pytest.raises(HTTPException) as exc_info:
+ await get_current_api_user(request, session=AsyncMock())
+ assert exc_info.value.status_code == 401
+
+
+async def test_get_current_api_user_invalid_token():
+ """Valid Bearer scheme but bogus token raises 401."""
+ from fastapi import HTTPException
+ from wiregui.api.deps import get_current_api_user
+
+ request = MagicMock()
+ request.headers = {"Authorization": "Bearer bogus-token-value"}
+
+ async with async_session() as session:
+ with pytest.raises(HTTPException) as exc_info:
+ await get_current_api_user(request, session=session)
+ assert exc_info.value.status_code == 401
+ assert "Invalid" in exc_info.value.detail
+
+
+async def test_get_current_api_user_valid_token():
+ """Valid Bearer token resolves to user."""
+ from wiregui.api.deps import get_current_api_user
+
+ plaintext, token_hash = generate_api_token()
+
+ async with async_session() as session:
+ user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
+ session.add(user)
+ await session.commit()
+ await session.refresh(user)
+
+ api_token = ApiToken(
+ token_hash=token_hash,
+ user_id=user.id,
+ expires_at=utcnow() + timedelta(hours=1),
+ )
+ session.add(api_token)
+ await session.commit()
+
+ try:
+ request = MagicMock()
+ request.headers = {"Authorization": f"Bearer {plaintext}"}
+
+ async with async_session() as session:
+ resolved = await get_current_api_user(request, session=session)
+ assert resolved.id == user.id
+ finally:
+ async with async_session() as session:
+ await session.delete(await session.get(ApiToken, api_token.id))
+ await session.delete(await session.get(User, user.id))
+ await session.commit()
+
+
+# ========== require_admin ==========
+
+
+async def test_require_admin_allows_admin():
+ """Admin user passes require_admin."""
+ from wiregui.api.deps import require_admin
+
+ admin_user = MagicMock(spec=User)
+ admin_user.role = "admin"
+ result = await require_admin(user=admin_user)
+ assert result == admin_user
+
+
+async def test_require_admin_rejects_unprivileged():
+ """Non-admin user gets 403."""
+ from fastapi import HTTPException
+ from wiregui.api.deps import require_admin
+
+ regular_user = MagicMock(spec=User)
+ regular_user.role = "unprivileged"
+
+ with pytest.raises(HTTPException) as exc_info:
+ await require_admin(user=regular_user)
+ assert exc_info.value.status_code == 403
+ assert "Admin" in exc_info.value.detail
\ No newline at end of file
diff --git a/tests/test_firewall_extended.py b/tests/test_firewall_extended.py
new file mode 100644
index 0000000..08a8df3
--- /dev/null
+++ b/tests/test_firewall_extended.py
@@ -0,0 +1,206 @@
+"""Extended firewall tests — _nft/_nft_batch error handling, add_device_jump_rule edge cases, policies."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from wiregui.services.firewall import (
+ _nft,
+ _nft_batch,
+ add_device_jump_rule,
+ setup_base_tables,
+ setup_masquerade,
+ apply_peer_to_peer_policy,
+ apply_lan_to_peers_policy,
+ get_ruleset,
+)
+
+
+# ========== _nft error handling ==========
+
+
+@patch("asyncio.create_subprocess_exec")
+async def test_nft_raises_on_failure(mock_exec):
+ """_nft raises RuntimeError on non-zero exit code."""
+ mock_proc = AsyncMock()
+ mock_proc.communicate.return_value = (b"", b"nft: error message")
+ mock_proc.returncode = 1
+ mock_exec.return_value = mock_proc
+
+ with pytest.raises(RuntimeError, match="nft.*failed"):
+ await _nft("list ruleset")
+
+
+@patch("asyncio.create_subprocess_exec")
+async def test_nft_returns_stdout_on_success(mock_exec):
+ """_nft returns stdout on success."""
+ mock_proc = AsyncMock()
+ mock_proc.communicate.return_value = (b"table inet wiregui {}", b"")
+ mock_proc.returncode = 0
+ mock_exec.return_value = mock_proc
+
+ result = await _nft("list ruleset")
+ assert "wiregui" in result
+
+
+# ========== _nft_batch error handling ==========
+
+
+@patch("asyncio.create_subprocess_exec")
+async def test_nft_batch_raises_on_failure(mock_exec):
+ """_nft_batch raises RuntimeError on non-zero exit code."""
+ mock_proc = AsyncMock()
+ mock_proc.communicate.return_value = (b"", b"Error: syntax error")
+ mock_proc.returncode = 1
+ mock_exec.return_value = mock_proc
+
+ with pytest.raises(RuntimeError, match="nft batch failed"):
+ await _nft_batch(["add table inet wiregui"])
+
+
+@patch("asyncio.create_subprocess_exec")
+async def test_nft_batch_sends_commands_via_stdin(mock_exec):
+ """_nft_batch sends all commands via stdin to nft -f -."""
+ mock_proc = AsyncMock()
+ mock_proc.communicate.return_value = (b"", b"")
+ mock_proc.returncode = 0
+ mock_exec.return_value = mock_proc
+
+ cmds = ["add table inet wiregui", "add chain inet wiregui test"]
+ await _nft_batch(cmds)
+
+ mock_exec.assert_awaited_once()
+ # Verify nft -f - was called
+ call_args = mock_exec.call_args[0]
+ assert call_args == ("nft", "-f", "-")
+ # Verify stdin data
+ stdin_data = mock_proc.communicate.call_args[0][0]
+ assert b"add table inet wiregui" in stdin_data
+ assert b"add chain inet wiregui test" in stdin_data
+
+
+# ========== add_device_jump_rule edge cases ==========
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_add_device_jump_rule_ipv4_only(mock_batch):
+ """Only IPv4 — generates single IPv4 jump rule."""
+ await add_device_jump_rule("user-id-1", "10.0.0.5", None)
+ mock_batch.assert_awaited_once()
+ cmds = mock_batch.call_args[0][0]
+ assert len(cmds) == 1
+ assert "ip saddr 10.0.0.5" in cmds[0]
+ assert "jump" in cmds[0]
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_add_device_jump_rule_ipv6_only(mock_batch):
+ """Only IPv6 — generates single IPv6 jump rule."""
+ await add_device_jump_rule("user-id-2", None, "fd00::5")
+ mock_batch.assert_awaited_once()
+ cmds = mock_batch.call_args[0][0]
+ assert len(cmds) == 1
+ assert "ip6 saddr fd00::5" in cmds[0]
+ assert "jump" in cmds[0]
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_add_device_jump_rule_no_ips(mock_batch):
+ """Neither IPv4 nor IPv6 — no nft commands issued."""
+ await add_device_jump_rule("user-id-3", None, None)
+ mock_batch.assert_not_awaited()
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_add_device_jump_rule_both_ips(mock_batch):
+ """Both IPv4 and IPv6 — generates two jump rules."""
+ await add_device_jump_rule("user-id-4", "10.0.0.7", "fd00::7")
+ mock_batch.assert_awaited_once()
+ cmds = mock_batch.call_args[0][0]
+ assert len(cmds) == 2
+ assert any("ip saddr 10.0.0.7" in c for c in cmds)
+ assert any("ip6 saddr fd00::7" in c for c in cmds)
+
+
+# ========== setup_base_tables — already exists ==========
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_setup_base_tables_already_exists(mock_batch):
+ """If table already exists (File exists error), don't raise."""
+ mock_batch.side_effect = RuntimeError("File exists")
+ await setup_base_tables() # should not raise
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_setup_base_tables_other_error_raises(mock_batch):
+ """Other nft errors should propagate."""
+ mock_batch.side_effect = RuntimeError("Permission denied")
+ with pytest.raises(RuntimeError, match="Permission denied"):
+ await setup_base_tables()
+
+
+# ========== setup_masquerade — error handling ==========
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_setup_masquerade_error_swallowed(mock_batch):
+ """Masquerade errors are logged but not raised."""
+ mock_batch.side_effect = RuntimeError("nft error")
+ await setup_masquerade(iface="wg0") # should not raise
+
+
+# ========== policy functions — command verification ==========
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_peer_to_peer_enabled(mock_batch):
+ """Enabling peer-to-peer generates accept rules."""
+ await apply_peer_to_peer_policy(True)
+ cmds = mock_batch.call_args[0][0]
+ assert any("accept" in c for c in cmds)
+ assert any("peer_to_peer" in c for c in cmds)
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_peer_to_peer_disabled(mock_batch):
+ """Disabling peer-to-peer generates drop rules."""
+ await apply_peer_to_peer_policy(False)
+ cmds = mock_batch.call_args[0][0]
+ assert any("drop" in c for c in cmds)
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_lan_to_peers_enabled(mock_batch):
+ """Enabling LAN-to-peers generates accept rules."""
+ await apply_lan_to_peers_policy(True)
+ cmds = mock_batch.call_args[0][0]
+ assert any("accept" in c for c in cmds)
+ assert any("lan_to_peers" in c for c in cmds)
+
+
+@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
+async def test_lan_to_peers_disabled(mock_batch):
+ """Disabling LAN-to-peers generates drop rules."""
+ await apply_lan_to_peers_policy(False)
+ cmds = mock_batch.call_args[0][0]
+ assert any("drop" in c for c in cmds)
+
+
+# ========== get_ruleset — error handling ==========
+
+
+@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
+async def test_get_ruleset_returns_output(mock_nft):
+ """get_ruleset returns nft list ruleset output."""
+ mock_nft.return_value = "table inet wiregui { ... }"
+ result = await get_ruleset()
+ assert "wiregui" in result
+
+
+@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
+async def test_get_ruleset_returns_fallback_on_error(mock_nft):
+ """get_ruleset returns friendly message when nft not available."""
+ mock_nft.side_effect = RuntimeError("nft not found")
+ result = await get_ruleset()
+ assert "not available" in result
\ No newline at end of file
diff --git a/tests/test_wireguard_extended.py b/tests/test_wireguard_extended.py
new file mode 100644
index 0000000..ab848df
--- /dev/null
+++ b/tests/test_wireguard_extended.py
@@ -0,0 +1,114 @@
+"""Tests for WireGuard service — ensure_interface, set_private_key, set_listen_port, configure_interface."""
+
+from unittest.mock import AsyncMock, patch, call
+
+from wiregui.services.wireguard import (
+ ensure_interface,
+ set_private_key,
+ set_listen_port,
+ configure_interface,
+)
+
+
+# ========== ensure_interface ==========
+
+
+@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
+async def test_ensure_interface_already_exists(mock_run):
+ """If interface exists (ip link show succeeds), do nothing."""
+ mock_run.return_value = ""
+ await ensure_interface(iface="wg-test")
+ # Only called once for ip link show
+ mock_run.assert_awaited_once_with(["ip", "link", "show", "wg-test"])
+
+
+@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
+async def test_ensure_interface_creates_new(mock_run):
+ """If interface doesn't exist, create it, assign IPs, bring up."""
+ call_count = 0
+
+ async def side_effect(args, input_data=None):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1 and args == ["ip", "link", "show", "wg-test"]:
+ raise RuntimeError("Device not found")
+ return ""
+
+ mock_run.side_effect = side_effect
+ await ensure_interface(iface="wg-test")
+
+ # Should have called: ip link show (fails), ip link add, ip addr add x2, ip link set up
+ assert mock_run.await_count == 5
+ calls = [c[0][0] for c in mock_run.call_args_list]
+ assert calls[1] == ["ip", "link", "add", "wg-test", "type", "wireguard"]
+ assert calls[2][0:3] == ["ip", "address", "add"]
+ assert calls[3][0:3] == ["ip", "address", "add"]
+ assert calls[4] == ["ip", "link", "set", "wg-test", "up"]
+
+
+# ========== set_private_key ==========
+
+
+@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
+async def test_set_private_key(mock_run):
+ """set_private_key calls wg set with private-key path."""
+ mock_run.return_value = ""
+ await set_private_key("/tmp/test.key", iface="wg-test")
+ mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "private-key", "/tmp/test.key"])
+
+
+# ========== set_listen_port ==========
+
+
+@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
+async def test_set_listen_port(mock_run):
+ """set_listen_port calls wg set with listen-port."""
+ mock_run.return_value = ""
+ await set_listen_port(51820, iface="wg-test")
+ mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "listen-port", "51820"])
+
+
+# ========== configure_interface ==========
+
+
+@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
+@patch("wiregui.db.async_session")
+async def test_configure_interface_no_config(mock_session_cls, mock_run):
+ """If no Configuration row exists, do not call wg set."""
+ from unittest.mock import MagicMock
+
+ mock_session = AsyncMock()
+ mock_result = MagicMock()
+ mock_result.scalar_one_or_none.return_value = None
+ mock_session.execute.return_value = mock_result
+ mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
+
+ await configure_interface(iface="wg-test")
+ mock_run.assert_not_awaited()
+
+
+@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
+@patch("wiregui.db.async_session")
+async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run):
+ """With valid config, writes key to temp file and calls wg set."""
+ from unittest.mock import MagicMock
+
+ mock_config = MagicMock()
+ mock_config.server_private_key = "test-private-key-value"
+
+ mock_session = AsyncMock()
+ mock_result = MagicMock()
+ mock_result.scalar_one_or_none.return_value = mock_config
+ mock_session.execute.return_value = mock_result
+ mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
+
+ mock_run.return_value = ""
+ await configure_interface(iface="wg-test")
+
+ mock_run.assert_awaited_once()
+ args = mock_run.call_args[0][0]
+ assert args[0:3] == ["wg", "set", "wg-test"]
+ assert "private-key" in args
+ assert "listen-port" in args
\ No newline at end of file
diff --git a/wiregui/auth/saml.py b/wiregui/auth/saml.py
index c624a35..c71054f 100644
--- a/wiregui/auth/saml.py
+++ b/wiregui/auth/saml.py
@@ -17,7 +17,7 @@ def _build_saml_settings(provider_config: dict) -> dict:
idp_settings = idp_data.get("idp", {})
return {
- "strict": True,
+ "strict": provider_config.get("strict", True),
"debug": False,
"sp": {
"entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata",
diff --git a/wiregui/pages/admin/settings.py b/wiregui/pages/admin/settings.py
index 72ec470..7a868c1 100644
--- a/wiregui/pages/admin/settings.py
+++ b/wiregui/pages/admin/settings.py
@@ -6,6 +6,7 @@ from loguru import logger
from nicegui import app, ui
from sqlmodel import select
+from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.pages.layout import layout
diff --git a/wiregui/pages/auth_saml.py b/wiregui/pages/auth_saml.py
index c9dccc2..9183f2b 100644
--- a/wiregui/pages/auth_saml.py
+++ b/wiregui/pages/auth_saml.py
@@ -101,14 +101,13 @@ async def saml_callback(provider_id: str, request: Request):
session.add(user)
await session.commit()
- request.session["authenticated"] = True
- request.session["user_id"] = str(user.id)
- request.session["email"] = user.email
- request.session["role"] = user.role
- request.session["theme_preference"] = user.theme_preference
+ # Store auth data in Starlette session — picked up by /auth/complete
+ request.session["oidc_user_id"] = str(user.id)
+ request.session["oidc_email"] = user.email
+ request.session["oidc_role"] = user.role
logger.info("SAML login: {} via {}", email, provider_id)
- return RedirectResponse(url="/", status_code=303)
+ return RedirectResponse(url="/auth/complete", status_code=303)
except Exception as e:
logger.error("SAML callback failed for {}: {}", provider_id, e)
diff --git a/wiregui/pages/login.py b/wiregui/pages/login.py
index f1b2110..35781d9 100644
--- a/wiregui/pages/login.py
+++ b/wiregui/pages/login.py
@@ -1,4 +1,4 @@
-"""Login page — email/password, MFA redirect, OIDC provider buttons."""
+"""Login page — email/password, MFA redirect, OIDC/SAML provider buttons."""
from nicegui import app, ui
from sqlmodel import select
@@ -6,6 +6,7 @@ from sqlmodel import select
from wiregui.auth.oidc import load_providers
from wiregui.auth.session import authenticate_user
from wiregui.db import async_session
+from wiregui.models.configuration import Configuration
from wiregui.models.mfa_method import MFAMethod
from wiregui.pages.style import apply_style
from wiregui.utils.time import utcnow
@@ -18,9 +19,13 @@ async def login_page():
apply_style()
- # Load OIDC providers for SSO buttons
+ # Load SSO providers for login buttons
oidc_providers = await load_providers()
+ async with async_session() as session:
+ config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
+ saml_providers = config.saml_identity_providers if config else []
+
async def try_login():
user = await authenticate_user(email.value, password.value)
if user is None:
@@ -76,8 +81,8 @@ async def login_page():
password.on("keydown.enter", try_login)
- # OIDC provider buttons
- if oidc_providers:
+ # SSO provider buttons
+ if oidc_providers or saml_providers:
ui.separator().classes("q-my-md")
ui.label("Or sign in with").classes("text-caption text-center w-full")
for provider in oidc_providers:
@@ -87,3 +92,10 @@ async def login_page():
label,
on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/oidc/{p}'"),
).props("color=primary unelevated").classes("w-full q-mt-xs")
+ for provider in saml_providers:
+ pid = provider.get("id", "")
+ label = provider.get("label", pid)
+ ui.button(
+ label,
+ on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/saml/{p}'"),
+ ).props("color=primary unelevated").classes("w-full q-mt-xs")