feat: comprehensive test suite + SAML auth fixes + mock SAML IdP
Some checks failed
Dev / test (push) Failing after 3m14s
Dev / docker (push) Has been skipped

Tests (198 unit + 70 e2e = 268 total):
- Add test_api_deps.py: Bearer token auth, get_current_api_user, require_admin
- Add test_wireguard_extended.py: ensure_interface, set_private_key, set_listen_port
- Add test_firewall_extended.py: _nft/_nft_batch errors, jump rules, policies
- Add test_mfa_login.py: MFA redirect, TOTP verify, invalid code, cancel
- Add test_magic_link_page.py: page render, submit, empty email, back to login
- Add test_admin_devices.py: list, filter, create, edit, delete, config dialog
- Add test_admin_rules.py: list, create, edit, delete (all DB-verified)
- Add test_admin_settings.py: defaults, security, OIDC/SAML providers
- Add test_saml_login.py: button visible, redirect, metadata, full login flow

Bug fixes:
- Fix SAML callback to use /auth/complete bridge (same fix as OIDC)
- Fix missing get_settings import in admin settings page
- Add SAML provider buttons to login page
- Make SAML strict mode configurable per-provider

Infrastructure:
- Add mock SimpleSAMLphp IdP to compose.yml with SP config
- Add mock-saml service to CI workflows (release + dev)
This commit is contained in:
Stefano Bertelli 2026-03-31 16:52:29 -05:00
parent 25cff5e4d9
commit 06b5a3dc12
18 changed files with 1768 additions and 47 deletions

View file

@ -34,11 +34,18 @@ jobs:
env: env:
SERVER_PORT: "9000" SERVER_PORT: "9000"
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}' JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
mock-saml:
image: kenchan0130/simplesamlphp
env:
SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata
SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback
SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/
env: env:
CI: "true" CI: "true"
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
WG_REDIS_URL: redis://valkey:6379/0 WG_REDIS_URL: redis://valkey:6379/0
MOCK_OIDC_HOST: mock-oidc MOCK_OIDC_HOST: mock-oidc
MOCK_SAML_HOST: mock-saml
steps: steps:
- name: Install system dependencies and checkout - name: Install system dependencies and checkout
run: | run: |

View file

@ -35,11 +35,18 @@ jobs:
env: env:
SERVER_PORT: "9000" SERVER_PORT: "9000"
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}' JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
mock-saml:
image: kenchan0130/simplesamlphp
env:
SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata
SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback
SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/
env: env:
CI: "true" CI: "true"
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
WG_REDIS_URL: redis://valkey:6379/0 WG_REDIS_URL: redis://valkey:6379/0
MOCK_OIDC_HOST: mock-oidc MOCK_OIDC_HOST: mock-oidc
MOCK_SAML_HOST: mock-saml
steps: steps:
- name: Install system dependencies and checkout - name: Install system dependencies and checkout
run: | run: |

78
TODO.md
View file

@ -1,6 +1,6 @@
# WireGUI — Pending Items # WireGUI — Pending Items
**Test count: 174 (164 unit + 10 E2E) | Coverage: ~35%** **Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)**
--- ---
@ -11,7 +11,7 @@
Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI. Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI.
Source: `/home/stefanob/PycharmProjects/personal/wirezone` Source: `/home/stefanob/PycharmProjects/personal/wirezone`
**Test count: 199 (164 unit + 35 E2E) | Coverage: 35%** **Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)**
**Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright) **Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright)
@ -23,11 +23,11 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone`
### Testing (partially done) ### Testing (partially done)
- [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking) - [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking)
- [ ] `wiregui/api/deps.py` — test get_current_api_user with real Bearer header parsing, require_admin rejection - [x] `wiregui/api/deps.py` (11 tests) — resolve_bearer_token (valid/expired/invalid/disabled/no-expiry), get_current_api_user (missing header/bad scheme/invalid token/valid token), require_admin (admin/unprivileged)
- [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port - [x] `wiregui/services/wireguard.py` (6 tests) — ensure_interface (exists/creates new), set_private_key, set_listen_port, configure_interface (no config/sets key+port)
- [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6 - [x] `wiregui/services/firewall.py` (17 tests) — _nft error/success, _nft_batch error/stdin, add_device_jump_rule (ipv4-only/ipv6-only/no-ips/both), setup_base_tables error handling, masquerade error, peer-to-peer/lan-to-peers policies, get_ruleset fallback
- [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error - [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error
- [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing - [x] `wiregui/auth/saml.py` — full SAML flow tested via mock SimpleSAMLphp IdP (e2e)
- [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data - [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data
- [ ] E2E tests for admin pages (users, devices, rules, settings) - [ ] E2E tests for admin pages (users, devices, rules, settings)
@ -37,46 +37,52 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone`
- [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion - [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion
- [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard - [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard
- [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc - [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc
- [x] `tests/e2e/test_mfa_login.py` (4 tests) — MFA redirect on login, valid TOTP completes login, invalid code error, cancel returns to login
- [x] `tests/e2e/test_magic_link_page.py` (4 tests) — page renders, success on submit, empty email error, back to login
- [x] `tests/e2e/test_admin_devices.py` (7 tests) — list all devices, filter by user, create with defaults, create with overrides, edit name/description, delete, config dialog with QR
- [x] `tests/e2e/test_admin_rules.py` (7 tests) — list rules table, create accept/drop/global rules, edit action/destination, delete rule (all verified in DB)
- [x] `tests/e2e/test_admin_settings.py` (9 tests) — client defaults save/reload, security toggles (local auth, VPN session, unprivileged), OIDC add/delete, SAML add/delete (all verified in DB)
- [x] `tests/e2e/test_saml_login.py` (4 tests) — SAML button visible, redirect to IdP, SP metadata endpoint, full SAML login flow via mock SimpleSAMLphp
**E2E tests still needed:** **E2E tests still needed:**
`tests/e2e/test_login.py` — Login & Auth flows (remaining): `tests/e2e/test_login.py` — Login & Auth flows (remaining):
- [ ] Login with MFA → redirects to /mfa challenge page - [x] Login with MFA → redirects to /mfa challenge page
- [ ] MFA challenge: valid TOTP code → completes login - [x] MFA challenge: valid TOTP code → completes login
- [ ] MFA challenge: invalid code → shows error, stays on /mfa - [x] MFA challenge: invalid code → shows error, stays on /mfa
- [ ] MFA challenge: cancel → returns to /login - [x] MFA challenge: cancel → returns to /login
- [ ] Magic link request page renders, shows success on submit - [x] Magic link request page renders, shows success on submit
`tests/e2e/test_admin_devices.py` — Admin Device Management: `tests/e2e/test_admin_devices.py` — Admin Device Management:
- [ ] List all devices across users - [x] List all devices across users
- [ ] Filter by user → shows only that user's devices - [x] Filter by user → shows only that user's devices
- [ ] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs) - [x] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs)
- [ ] Create device with defaults → use_default flags all True - [x] Create device with defaults → use_default flags all True
- [ ] Edit device name and description → persists - [x] Edit device name and description → persists
- [ ] Edit device config overrides (toggle use_default off, set custom values) - [x] Edit device config overrides (toggle use_default off, set custom values)
- [ ] Delete device → removed from table - [x] Delete device → removed from table
- [ ] Config dialog shows valid WireGuard config with real server public key - [x] Config dialog shows valid WireGuard config with real server public key
- [ ] QR code renders in config dialog - [x] QR code renders in config dialog
`tests/e2e/test_admin_rules.py` — Admin Firewall Rules: `tests/e2e/test_admin_rules.py` — Admin Firewall Rules:
- [ ] List rules → table shows action, destination, protocol, port, user - [x] List rules → table shows action, destination, protocol, port, user
- [ ] Create accept rule with CIDR → appears in table - [x] Create accept rule with CIDR → appears in table
- [ ] Create drop rule with TCP port range → appears correctly - [x] Create drop rule with TCP port range → appears correctly
- [ ] Create global rule (no user) → shows "Global" - [x] Create global rule (no user) → shows "Global"
- [ ] Edit rule action (accept → drop) → persists - [x] Edit rule action (accept → drop) → persists
- [ ] Edit rule destination → persists - [x] Edit rule destination → persists
- [ ] Delete rule → removed from table - [x] Delete rule → removed from table
`tests/e2e/test_admin_settings.py` — Admin Settings: `tests/e2e/test_admin_settings.py` — Admin Settings:
- [ ] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB - [x] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB
- [ ] Client defaults: saved values reflected on page reload - [x] Client defaults: saved values reflected on page reload
- [ ] Security: toggle local auth → persists - [x] Security: toggle local auth → persists
- [ ] Security: change VPN session duration → persists - [x] Security: change VPN session duration → persists
- [ ] Security: toggle unprivileged device management/configuration → persists - [x] Security: toggle unprivileged device management/configuration → persists
- [ ] OIDC: add provider → appears in table - [x] OIDC: add provider → appears in table
- [ ] OIDC: delete provider → removed from table - [x] OIDC: delete provider → removed from table
- [ ] SAML: add provider → appears in table - [x] SAML: add provider → appears in table
- [ ] SAML: delete provider → removed from table - [x] SAML: delete provider → removed from table
`tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics: `tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics:
- [ ] Page renders WireGuard interface status - [ ] Page renders WireGuard interface status

View file

@ -49,6 +49,21 @@ services:
] ]
} }
# Test SAML Identity Provider — SimpleSAMLphp as IdP
# IdP Metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
# Admin UI: http://localhost:8080/simplesaml (admin / secret)
# Test users: user1/password, user2/password
mock-saml:
image: kenchan0130/simplesamlphp
ports:
- "8080:8080"
environment:
SIMPLESAMLPHP_SP_ENTITY_ID: "http://localhost:13000/auth/saml/test-saml/metadata"
SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: "http://localhost:13000/auth/saml/test-saml/callback"
SIMPLESAMLPHP_IDP_BASE_URL: http://localhost:8080/simplesaml/
volumes:
- ./docker/mock-saml/saml20-sp-remote.php:/var/www/simplesamlphp/metadata/saml20-sp-remote.php:ro
volumes: volumes:
postgres_data: postgres_data:
valkey_data: valkey_data:

View file

@ -0,0 +1,15 @@
<?php
/**
* SAML 2.0 remote SP metadata for WireGUI testing.
* Registers SPs for dev (port 13000) and e2e test (port 13003).
*/
// Dev instance
$metadata['http://localhost:13000/auth/saml/test-saml/metadata'] = [
'AssertionConsumerService' => 'http://localhost:13000/auth/saml/test-saml/callback',
];
// E2E test instance
$metadata['http://localhost:13003/auth/saml/test-saml/metadata'] = [
'AssertionConsumerService' => 'http://localhost:13003/auth/saml/test-saml/callback',
];

View file

@ -0,0 +1,239 @@
"""E2E tests for admin device management page."""
import pytest_asyncio
from playwright.async_api import Page, expect
from sqlmodel import select
from wiregui.auth.passwords import hash_password
from wiregui.db import async_session
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
from tests.e2e.conftest import (
TEST_APP_BASE,
TEST_EMAIL,
TEST_PASSWORD,
_cleanup_user_by_email,
login,
)
SECOND_USER_EMAIL = "e2e-device-user2@example.com"
@pytest_asyncio.fixture
async def second_user(test_user):
"""Create a second user with a device for filtering tests."""
await _cleanup_user_by_email(SECOND_USER_EMAIL)
async with async_session() as session:
user = User(
email=SECOND_USER_EMAIL,
password_hash=hash_password("pass12345"),
role="unprivileged",
)
session.add(user)
await session.commit()
await session.refresh(user)
yield user
await _cleanup_user_by_email(SECOND_USER_EMAIL)
@pytest_asyncio.fixture
async def devices_for_both_users(test_user, second_user):
"""Create one device per user for table/filter tests."""
_, pub1 = generate_keypair()
_, pub2 = generate_keypair()
psk1 = generate_preshared_key()
psk2 = generate_preshared_key()
async with async_session() as session:
d1 = Device(
name="admin-laptop",
public_key=pub1,
preshared_key=psk1,
ipv4="10.0.0.10",
user_id=test_user.id,
)
d2 = Device(
name="user2-phone",
public_key=pub2,
preshared_key=psk2,
ipv4="10.0.0.11",
user_id=second_user.id,
)
session.add_all([d1, d2])
await session.commit()
yield d1, d2
# Cleanup handled by user fixture cascade
async def _go_to_admin_devices(page: Page):
"""Login as admin and navigate to admin devices page."""
await login(page)
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
await page.goto(f"{TEST_APP_BASE}/admin/devices")
await expect(page.locator("role=main").get_by_text("All Devices")).to_be_visible(timeout=10_000)
async def test_list_all_devices(page: Page, devices_for_both_users):
"""Admin devices page lists devices from all users."""
await _go_to_admin_devices(page)
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
async def test_filter_by_user(page: Page, second_user, devices_for_both_users):
"""Filtering by user shows only that user's devices."""
await _go_to_admin_devices(page)
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
# Filter to second user
await page.locator("label:has-text('Filter by User')").click()
await page.get_by_role("option", name=SECOND_USER_EMAIL).click()
await page.wait_for_timeout(1000)
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("admin-laptop")).not_to_be_visible()
# Filter back to all
await page.locator("label:has-text('Filter by User')").click()
await page.get_by_role("option", name="All Users").click()
await page.wait_for_timeout(1000)
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
async def test_create_device_with_defaults(page: Page, test_user):
"""Create device with all defaults — config dialog appears."""
await _go_to_admin_devices(page)
await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
await page.locator("input[aria-label='Device Name']").fill("default-test-device")
await page.get_by_role("button", name="Create").click()
# Config dialog should appear with WireGuard config
await expect(page.get_by_text("Config for default-test-device")).to_be_visible(timeout=10_000)
await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
await page.get_by_role("button", name="Close").click()
await page.wait_for_timeout(500)
# Device should be in the table
await expect(page.get_by_role("cell", name="default-test-device").first).to_be_visible(timeout=5_000)
async def test_create_device_with_overrides(page: Page, test_user):
"""Create device with custom config overrides."""
await _go_to_admin_devices(page)
await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
await page.locator("input[aria-label='Device Name']").fill("custom-override-dev")
await page.locator("input[aria-label='Description (optional)']").fill("Custom overrides test")
# Toggle off DNS default and set custom — Quasar switches use .q-toggle
await page.locator(".q-toggle", has_text="Use default DNS").click()
dns_input = page.locator("input[aria-label='DNS Servers']")
await dns_input.clear()
await dns_input.fill("8.8.8.8, 8.8.4.4")
# Toggle off MTU default and set custom
await page.locator(".q-toggle", has_text="Use default MTU").click()
mtu_input = page.locator("input[aria-label='MTU']")
await mtu_input.clear()
await mtu_input.fill("1400")
await page.get_by_role("button", name="Create").click()
await expect(page.get_by_text("Config for custom-override-dev")).to_be_visible(timeout=10_000)
await page.get_by_role("button", name="Close").click()
await page.wait_for_timeout(500)
await expect(page.get_by_role("cell", name="custom-override-dev").first).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
result = await session.execute(
select(Device).where(Device.name == "custom-override-dev")
.order_by(Device.inserted_at.desc()).limit(1)
)
device = result.scalar_one()
assert device.use_default_dns is False
assert "8.8.8.8" in device.dns
assert device.use_default_mtu is False
assert device.mtu == 1400
async def test_edit_device_name_and_description(page: Page, devices_for_both_users):
"""Edit a device name and description via the edit dialog."""
await _go_to_admin_devices(page)
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
# Click edit button on admin-laptop row — Quasar slot buttons with icon
row = page.locator("tr", has_text="admin-laptop")
await row.locator(".q-btn").first.click()
await expect(page.get_by_text("Edit Device")).to_be_visible(timeout=5_000)
name_input = page.locator(".q-dialog input[aria-label='Device Name']")
await name_input.clear()
await name_input.fill("admin-laptop-renamed")
desc_input = page.locator(".q-dialog input[aria-label='Description']")
await desc_input.clear()
await desc_input.fill("Updated description")
await page.get_by_role("button", name="Save").click()
await expect(page.get_by_text("Device updated")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("admin-laptop-renamed")).to_be_visible(timeout=5_000)
async def test_delete_device(page: Page, test_user):
"""Delete a device — removed from table."""
_, pub = generate_keypair()
async with async_session() as session:
d = Device(
name="delete-me-device",
public_key=pub,
preshared_key=generate_preshared_key(),
ipv4="10.0.0.99",
user_id=test_user.id,
)
session.add(d)
await session.commit()
await _go_to_admin_devices(page)
await expect(page.get_by_role("cell", name="delete-me-device")).to_be_visible(timeout=5_000)
# Click the delete (second) button in the row
row = page.locator("tr", has_text="delete-me-device")
await row.locator(".q-btn").nth(1).click()
await expect(page.get_by_text("Deleted delete-me-device")).to_be_visible(timeout=5_000)
await page.wait_for_timeout(1000)
await expect(page.get_by_role("cell", name="delete-me-device")).not_to_be_visible()
async def test_config_dialog_shows_wg_config(page: Page, test_user):
"""Config dialog after device creation shows valid WireGuard config."""
await _go_to_admin_devices(page)
await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
await page.locator("input[aria-label='Device Name']").fill("config-test-device")
await page.get_by_role("button", name="Create").click()
await expect(page.get_by_text("Config for config-test-device")).to_be_visible(timeout=10_000)
await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("[Peer]")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("PrivateKey")).to_be_visible()
await expect(page.get_by_role("button", name="Download .conf")).to_be_visible()
# QR code should be rendered
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)

View file

@ -0,0 +1,227 @@
"""E2E tests for admin firewall rules management page."""
from uuid import UUID
import pytest_asyncio
from playwright.async_api import Page, expect
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.rule import Rule
from wiregui.models.user import User
from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, login
async def _cleanup_test_rules():
"""Remove rules created by tests (identified by test-specific destinations)."""
async with async_session() as session:
result = await session.execute(
select(Rule).where(Rule.destination.in_([
"10.99.0.0/16", "10.88.0.0/16", "10.77.0.0/16",
"10.66.0.0/16", "10.55.0.0/16",
]))
)
for rule in result.scalars().all():
await session.delete(rule)
await session.commit()
@pytest_asyncio.fixture(autouse=True)
async def clean_rules(app_server):
"""Clean up test rules before and after each test."""
await _cleanup_test_rules()
yield
await _cleanup_test_rules()
async def _go_to_rules(page: Page):
"""Login and navigate to admin rules page."""
await login(page)
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
await page.goto(f"{TEST_APP_BASE}/admin/rules")
await expect(page.locator("role=main").get_by_text("Firewall Rules")).to_be_visible(timeout=10_000)
async def _create_rule_via_dialog(
page: Page, *, action: str = "accept", destination: str = "10.99.0.0/16",
protocol: str = "any", port_range: str = "", user: str = "global",
):
"""Open create dialog and fill in a rule."""
await page.get_by_role("button", name="Add Rule").click()
await expect(page.get_by_text("New Firewall Rule")).to_be_visible(timeout=5_000)
# Action select
if action != "accept":
await page.locator(".q-dialog label:has-text('Action')").click()
await page.get_by_role("option", name=action).click()
# Destination
await page.locator(".q-dialog input[aria-label='Destination (CIDR)']").fill(destination)
# Protocol
if protocol != "any":
await page.locator(".q-dialog label:has-text('Protocol')").click()
await page.get_by_role("option", name=protocol).click()
# Port range
if port_range:
await page.locator(".q-dialog input[aria-label='Port Range']").fill(port_range)
# User
if user != "global":
await page.locator(".q-dialog label:has-text('Applies to')").click()
await page.get_by_role("option", name=user).click()
await page.get_by_role("button", name="Create").click()
await page.wait_for_timeout(500)
async def test_list_rules_table(page: Page, test_user: User):
"""Rules page renders table with correct columns."""
# Seed a rule in DB
async with async_session() as session:
rule = Rule(action="accept", destination="10.99.0.0/16", port_type="tcp",
port_range="443", user_id=test_user.id)
session.add(rule)
await session.commit()
await _go_to_rules(page)
await expect(page.get_by_role("cell", name="accept")).to_be_visible(timeout=5_000)
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible()
await expect(page.get_by_role("cell", name="tcp")).to_be_visible()
await expect(page.get_by_role("cell", name="443")).to_be_visible()
await expect(page.get_by_role("cell", name=TEST_EMAIL)).to_be_visible()
async def test_create_accept_rule_with_cidr(page: Page, test_user: User):
"""Create an accept rule with CIDR — verify in table and DB."""
await _go_to_rules(page)
await _create_rule_via_dialog(page, action="accept", destination="10.88.0.0/16")
await expect(page.get_by_role("cell", name="10.88.0.0/16")).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
result = await session.execute(select(Rule).where(Rule.destination == "10.88.0.0/16"))
rule = result.scalar_one()
assert rule.action == "accept"
assert rule.port_type is None
assert rule.port_range is None
assert rule.user_id is None
async def test_create_drop_rule_with_tcp_port_range(page: Page, test_user: User):
"""Create a drop rule with TCP port range — verify in table and DB."""
await _go_to_rules(page)
await _create_rule_via_dialog(
page, action="drop", destination="10.77.0.0/16",
protocol="tcp", port_range="80-443",
)
await expect(page.get_by_role("cell", name="10.77.0.0/16")).to_be_visible(timeout=5_000)
await expect(page.get_by_role("cell", name="drop").first).to_be_visible()
# Verify in DB
async with async_session() as session:
result = await session.execute(select(Rule).where(Rule.destination == "10.77.0.0/16"))
rule = result.scalar_one()
assert rule.action == "drop"
assert rule.port_type == "tcp"
assert rule.port_range == "80-443"
async def test_create_global_rule(page: Page, test_user: User):
"""Create a global rule (no user) — shows 'Global' in table and DB has null user_id."""
await _go_to_rules(page)
await _create_rule_via_dialog(page, destination="10.66.0.0/16", user="global")
await expect(page.get_by_role("cell", name="10.66.0.0/16")).to_be_visible(timeout=5_000)
await expect(page.get_by_role("cell", name="Global")).to_be_visible()
# Verify in DB
async with async_session() as session:
result = await session.execute(select(Rule).where(Rule.destination == "10.66.0.0/16"))
rule = result.scalar_one()
assert rule.user_id is None
async def test_edit_rule_action(page: Page, test_user: User):
"""Edit rule action from accept to drop — verify in table and DB."""
async with async_session() as session:
rule = Rule(action="accept", destination="10.55.0.0/16")
session.add(rule)
await session.commit()
rule_id = rule.id
await _go_to_rules(page)
await expect(page.get_by_role("cell", name="10.55.0.0/16")).to_be_visible(timeout=5_000)
# Click edit (first button in the row)
row = page.locator("tr", has_text="10.55.0.0/16")
await row.locator(".q-btn").first.click()
await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
# Change action to drop
await page.locator(".q-dialog label:has-text('Action')").click()
await page.get_by_role("option", name="drop").click()
await page.get_by_role("button", name="Save").click()
await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
rule = await session.get(Rule, rule_id)
assert rule.action == "drop"
async def test_edit_rule_destination(page: Page, test_user: User):
"""Edit rule destination — verify in table and DB."""
async with async_session() as session:
rule = Rule(action="accept", destination="10.99.0.0/16")
session.add(rule)
await session.commit()
rule_id = rule.id
await _go_to_rules(page)
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
row = page.locator("tr", has_text="10.99.0.0/16")
await row.locator(".q-btn").first.click()
await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
dest_input = page.locator(".q-dialog input[aria-label='Destination (CIDR)']")
await dest_input.clear()
await dest_input.fill("10.88.0.0/16")
await page.get_by_role("button", name="Save").click()
await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
rule = await session.get(Rule, rule_id)
assert rule.destination == "10.88.0.0/16"
async def test_delete_rule(page: Page, test_user: User):
"""Delete a rule — removed from table and DB."""
async with async_session() as session:
rule = Rule(action="accept", destination="10.99.0.0/16")
session.add(rule)
await session.commit()
rule_id = rule.id
await _go_to_rules(page)
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
# Click delete (second button in the row)
row = page.locator("tr", has_text="10.99.0.0/16")
await row.locator(".q-btn").nth(1).click()
await page.wait_for_timeout(1000)
await expect(page.get_by_role("cell", name="10.99.0.0/16")).not_to_be_visible()
# Verify in DB
async with async_session() as session:
rule = await session.get(Rule, rule_id)
assert rule is None

View file

@ -0,0 +1,281 @@
"""E2E tests for admin settings page — client defaults, security, OIDC/SAML providers."""
import pytest_asyncio
from playwright.async_api import Page, expect
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.models.user import User
from tests.e2e.conftest import TEST_APP_BASE, login
@pytest_asyncio.fixture(autouse=True)
async def reset_config(app_server):
"""Snapshot config before test, restore after."""
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
if not c:
yield
return
snap = {
"default_client_endpoint": c.default_client_endpoint,
"default_client_dns": list(c.default_client_dns),
"default_client_mtu": c.default_client_mtu,
"default_client_persistent_keepalive": c.default_client_persistent_keepalive,
"default_client_allowed_ips": list(c.default_client_allowed_ips),
"vpn_session_duration": c.vpn_session_duration,
"local_auth_enabled": c.local_auth_enabled,
"allow_unprivileged_device_management": c.allow_unprivileged_device_management,
"allow_unprivileged_device_configuration": c.allow_unprivileged_device_configuration,
"openid_connect_providers": list(c.openid_connect_providers or []),
"saml_identity_providers": list(c.saml_identity_providers or []),
}
cid = c.id
yield
async with async_session() as session:
c = await session.get(Configuration, cid)
if c:
for k, v in snap.items():
setattr(c, k, v)
session.add(c)
await session.commit()
async def _go_to_settings(page: Page):
await login(page)
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
await page.goto(f"{TEST_APP_BASE}/admin/settings")
await expect(page.get_by_text("Default Client Configuration")).to_be_visible(timeout=10_000)
# --- Client Defaults ---
async def test_save_client_defaults(page: Page, test_user: User):
"""Save endpoint, DNS, MTU, keepalive, allowed IPs — verify persists in DB."""
await _go_to_settings(page)
endpoint = page.locator("input[aria-label='Endpoint']")
await endpoint.clear()
await endpoint.fill("vpn.test.local")
dns = page.locator("input[aria-label='DNS Servers']")
await dns.clear()
await dns.fill("9.9.9.9, 149.112.112.112")
mtu = page.locator("input[aria-label='MTU']")
await mtu.clear()
await mtu.fill("1420")
keepalive = page.locator("input[aria-label='Persistent Keepalive']")
await keepalive.clear()
await keepalive.fill("30")
allowed = page.locator("input[aria-label='Allowed IPs']")
await allowed.clear()
await allowed.fill("10.0.0.0/8, 192.168.0.0/16")
await page.get_by_role("button", name="Save Defaults").click()
await expect(page.get_by_text("Client defaults saved")).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
assert c.default_client_endpoint == "vpn.test.local"
assert c.default_client_dns == ["9.9.9.9", "149.112.112.112"]
assert c.default_client_mtu == 1420
assert c.default_client_persistent_keepalive == 30
assert c.default_client_allowed_ips == ["10.0.0.0/8", "192.168.0.0/16"]
async def test_client_defaults_persist_on_reload(page: Page, test_user: User):
"""Saved defaults are reflected after page reload."""
# Set values via DB
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
c.default_client_endpoint = "reload-test.vpn"
c.default_client_dns = ["8.8.8.8"]
c.default_client_mtu = 1500
c.default_client_persistent_keepalive = 15
c.default_client_allowed_ips = ["172.16.0.0/12"]
session.add(c)
await session.commit()
await _go_to_settings(page)
await expect(page.locator("input[aria-label='Endpoint']")).to_have_value("reload-test.vpn")
await expect(page.locator("input[aria-label='DNS Servers']")).to_have_value("8.8.8.8")
await expect(page.locator("input[aria-label='MTU']")).to_have_value("1500")
await expect(page.locator("input[aria-label='Persistent Keepalive']")).to_have_value("15")
await expect(page.locator("input[aria-label='Allowed IPs']")).to_have_value("172.16.0.0/12")
# --- Security ---
async def test_save_security_local_auth_toggle(page: Page, test_user: User):
"""Toggle local auth off — verify in DB."""
await _go_to_settings(page)
# Find the local auth switch and toggle it off
switch = page.locator(".q-toggle", has_text="Local Authentication")
await switch.click()
await page.get_by_role("button", name="Save Security Settings").click()
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
assert c.local_auth_enabled is False
async def test_save_vpn_session_duration(page: Page, test_user: User):
"""Change VPN session duration — verify in DB."""
await _go_to_settings(page)
await page.locator("label:has-text('VPN Session Duration')").click()
await page.get_by_role("option", name="Every Day").click()
await page.get_by_role("button", name="Save Security Settings").click()
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
assert c.vpn_session_duration == 86400
async def test_save_unprivileged_toggles(page: Page, test_user: User):
"""Toggle unprivileged device management/configuration — verify in DB."""
await _go_to_settings(page)
await page.locator(".q-toggle", has_text="Allow Unprivileged Device Management").click()
await page.locator(".q-toggle", has_text="Allow Unprivileged Device Configuration").click()
await page.get_by_role("button", name="Save Security Settings").click()
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
# Toggled from default (True) to False
assert c.allow_unprivileged_device_management is False
assert c.allow_unprivileged_device_configuration is False
# --- OIDC Providers ---
async def test_add_oidc_provider(page: Page, test_user: User):
"""Add an OIDC provider — appears in table and DB."""
await _go_to_settings(page)
await page.get_by_role("button", name="Add OIDC Provider").click()
await expect(page.get_by_text("OIDC Provider", exact=True)).to_be_visible(timeout=5_000)
await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-oidc")
await page.locator(".q-dialog input[aria-label='Label']").fill("E2E Test IdP")
await page.locator(".q-dialog input[aria-label='Client ID']").fill("test-client-id")
await page.locator(".q-dialog input[aria-label='Client Secret']").fill("test-client-secret")
await page.locator(".q-dialog input[aria-label='Discovery Document URI']").fill("https://idp.test/.well-known/openid-configuration")
await page.locator(".q-dialog").get_by_role("button", name="Save").click()
await expect(page.get_by_text("OIDC provider 'E2E Test IdP' saved")).to_be_visible(timeout=5_000)
await expect(page.get_by_role("cell", name="e2e-test-oidc")).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
provider = next((p for p in c.openid_connect_providers if p["id"] == "e2e-test-oidc"), None)
assert provider is not None
assert provider["label"] == "E2E Test IdP"
assert provider["client_id"] == "test-client-id"
async def test_delete_oidc_provider(page: Page, test_user: User):
"""Delete an OIDC provider — removed from table and DB."""
# Seed a provider
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
providers = list(c.openid_connect_providers or [])
providers.append({
"id": "delete-me-oidc", "label": "Delete Me", "scope": "openid",
"client_id": "x", "client_secret": "x",
"discovery_document_uri": "https://x/.well-known/openid-configuration",
})
c.openid_connect_providers = providers
session.add(c)
await session.commit()
await _go_to_settings(page)
await expect(page.get_by_role("cell", name="delete-me-oidc")).to_be_visible(timeout=5_000)
row = page.locator("tr", has_text="delete-me-oidc")
await row.locator(".q-btn").first.click()
await expect(page.get_by_text("OIDC provider deleted")).to_be_visible(timeout=5_000)
await page.wait_for_timeout(500)
await expect(page.get_by_role("cell", name="delete-me-oidc")).not_to_be_visible()
# Verify in DB
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
assert not any(p["id"] == "delete-me-oidc" for p in c.openid_connect_providers)
# --- SAML Providers ---
async def test_add_saml_provider(page: Page, test_user: User):
"""Add a SAML provider — appears in table and DB."""
await _go_to_settings(page)
await page.get_by_role("button", name="Add SAML Provider").click()
await expect(page.get_by_text("SAML Identity Provider", exact=True)).to_be_visible(timeout=5_000)
await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-saml")
await page.locator(".q-dialog input[aria-label='Label']").fill("E2E SAML IdP")
await page.locator(".q-dialog textarea").fill("<EntityDescriptor>test</EntityDescriptor>")
await page.locator(".q-dialog").get_by_role("button", name="Save").click()
await expect(page.get_by_text("SAML provider 'E2E SAML IdP' saved")).to_be_visible(timeout=5_000)
await expect(page.get_by_role("cell", name="e2e-test-saml")).to_be_visible(timeout=5_000)
# Verify in DB
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
provider = next((p for p in c.saml_identity_providers if p["id"] == "e2e-test-saml"), None)
assert provider is not None
assert provider["label"] == "E2E SAML IdP"
async def test_delete_saml_provider(page: Page, test_user: User):
"""Delete a SAML provider — removed from table and DB."""
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
providers = list(c.saml_identity_providers or [])
providers.append({
"id": "delete-me-saml", "label": "Delete Me SAML",
"metadata": "<EntityDescriptor/>",
})
c.saml_identity_providers = providers
session.add(c)
await session.commit()
await _go_to_settings(page)
await expect(page.get_by_role("cell", name="delete-me-saml")).to_be_visible(timeout=5_000)
row = page.locator("tr", has_text="delete-me-saml")
await row.locator(".q-btn").first.click()
await expect(page.get_by_text("SAML provider deleted")).to_be_visible(timeout=5_000)
await page.wait_for_timeout(500)
await expect(page.get_by_role("cell", name="delete-me-saml")).not_to_be_visible()
# Verify in DB
async with async_session() as session:
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
assert not any(p["id"] == "delete-me-saml" for p in c.saml_identity_providers)

View file

@ -0,0 +1,41 @@
"""E2E tests for magic link request page."""
from playwright.async_api import Page, expect
from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL
from wiregui.models.user import User
async def test_magic_link_page_renders(page: Page, test_user: User):
"""Magic link request page renders with email input and submit button."""
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
await page.wait_for_load_state("networkidle")
await expect(page.get_by_text("Sign in with magic link")).to_be_visible(timeout=10_000)
await expect(page.locator("input[aria-label='Email']")).to_be_visible()
await expect(page.get_by_role("button", name="Send Magic Link")).to_be_visible()
await expect(page.get_by_role("button", name="Back to login")).to_be_visible()
async def test_magic_link_shows_success_on_submit(page: Page, test_user: User):
"""Submitting an email shows success message (regardless of whether account exists)."""
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
await page.wait_for_load_state("networkidle")
await page.locator("input[aria-label='Email']").fill(TEST_EMAIL)
await page.get_by_role("button", name="Send Magic Link").click()
await expect(page.get_by_text("a sign-in link has been sent")).to_be_visible(timeout=5_000)
async def test_magic_link_empty_email_shows_error(page: Page, test_user: User):
"""Submitting without email shows error."""
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
await page.wait_for_load_state("networkidle")
await page.get_by_role("button", name="Send Magic Link").click()
await expect(page.get_by_text("Enter your email")).to_be_visible(timeout=5_000)
async def test_magic_link_back_to_login(page: Page, test_user: User):
"""Back to login button navigates to login page."""
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
await page.wait_for_load_state("networkidle")
await page.get_by_role("button", name="Back to login").click()
await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)

111
tests/e2e/test_mfa_login.py Normal file
View file

@ -0,0 +1,111 @@
"""E2E tests for MFA login flow — login with TOTP redirects to /mfa challenge page."""
import pyotp
import pytest_asyncio
from playwright.async_api import Page, expect
from wiregui.auth.mfa import generate_totp_secret
from wiregui.auth.passwords import hash_password
from wiregui.db import async_session
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
from tests.e2e.conftest import (
FAKE_SERVER_KEY,
TEST_APP_BASE,
TEST_PASSWORD,
_cleanup_user_by_email,
)
MFA_EMAIL = "e2e-mfa@example.com"
MFA_PASSWORD = "mfapass123"
TOTP_SECRET = generate_totp_secret()
@pytest_asyncio.fixture
async def mfa_user(app_server):
"""Create a user with a TOTP MFA method, clean up after."""
await _cleanup_user_by_email(MFA_EMAIL)
async with async_session() as session:
from sqlmodel import select
from wiregui.models.configuration import Configuration
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
if config:
if not config.server_public_key:
config.server_public_key = FAKE_SERVER_KEY
session.add(config)
else:
config = Configuration(server_public_key=FAKE_SERVER_KEY)
session.add(config)
user = User(
email=MFA_EMAIL,
password_hash=hash_password(MFA_PASSWORD),
role="admin",
)
session.add(user)
await session.commit()
await session.refresh(user)
mfa = MFAMethod(
name="Test TOTP",
type="totp",
payload={"secret": TOTP_SECRET},
user_id=user.id,
)
session.add(mfa)
await session.commit()
yield user
await _cleanup_user_by_email(MFA_EMAIL)
async def _login_mfa_user(page: Page):
"""Fill login form for the MFA user and submit."""
await page.goto(f"{TEST_APP_BASE}/login")
await page.wait_for_load_state("networkidle")
await page.locator("input[aria-label='Email']").fill(MFA_EMAIL)
await page.locator("input[aria-label='Password']").fill(MFA_PASSWORD)
await page.get_by_role("button", name="Sign in", exact=True).click()
async def test_mfa_login_redirects_to_challenge(page: Page, mfa_user: User):
"""Login with MFA-enabled user redirects to /mfa challenge page."""
await _login_mfa_user(page)
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
await expect(page.locator("input[aria-label='Authentication Code']")).to_be_visible()
async def test_mfa_valid_totp_completes_login(page: Page, mfa_user: User):
"""Entering a valid TOTP code on /mfa completes login."""
await _login_mfa_user(page)
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
code = pyotp.TOTP(TOTP_SECRET).now()
await page.locator("input[aria-label='Authentication Code']").fill(code)
await page.get_by_role("button", name="Verify").click()
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
async def test_mfa_invalid_code_shows_error(page: Page, mfa_user: User):
"""Entering an invalid TOTP code shows error and stays on /mfa."""
await _login_mfa_user(page)
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
await page.locator("input[aria-label='Authentication Code']").fill("000000")
await page.get_by_role("button", name="Verify").click()
await expect(page.get_by_text("Invalid code")).to_be_visible(timeout=5_000)
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible()
async def test_mfa_cancel_returns_to_login(page: Page, mfa_user: User):
"""Clicking Cancel on /mfa clears session and returns to login."""
await _login_mfa_user(page)
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
await page.get_by_role("button", name="Cancel").click()
await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)

View file

@ -0,0 +1,177 @@
"""E2E tests for SAML authentication — mock SimpleSAMLphp IdP.
Requires mock-saml service running (docker compose up -d mock-saml).
IdP metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
Test users: user1/user1pass, user2/user2pass
"""
import os
import subprocess
import time
import httpx
import pytest
import pytest_asyncio
from playwright.async_api import Page, expect
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.models.user import User
from tests.e2e.conftest import FAKE_SERVER_KEY, _cleanup_user_by_email
MOCK_SAML_HOST = os.environ.get("MOCK_SAML_HOST", "localhost")
MOCK_SAML_METADATA_URL = f"http://{MOCK_SAML_HOST}:8080/simplesaml/saml2/idp/metadata.php"
# Separate app port for SAML tests (like OIDC IdP tests)
SAML_APP_PORT = 13003
SAML_APP_BASE = f"http://localhost:{SAML_APP_PORT}"
SAML_TEST_EMAIL = "user1@example.com"
def _fetch_idp_metadata() -> str:
"""Fetch IdP metadata XML from the mock SAML server."""
try:
r = httpx.get(MOCK_SAML_METADATA_URL, timeout=5)
r.raise_for_status()
return r.text
except Exception:
pytest.skip(f"Mock SAML IdP not available at {MOCK_SAML_METADATA_URL}")
def _saml_provider_config(metadata: str) -> dict:
return {
"id": "test-saml",
"label": "Sign in with Mock SAML",
"metadata": metadata,
"sign_requests": False,
"sign_metadata": False,
"signed_assertion_in_resp": False,
"signed_envelopes_in_resp": False,
"auto_create_users": True,
"strict": False, # Relaxed for test IdP with expired certs
}
@pytest_asyncio.fixture(scope="module")
async def saml_metadata():
return _fetch_idp_metadata()
@pytest.fixture(scope="module")
def app_with_saml(saml_metadata):
"""Start a WireGUI instance with a SAML provider seeded in the DB."""
import asyncio
# Seed the SAML provider config into the database
async def _seed():
async with async_session() as session:
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
if config is None:
config = Configuration(server_public_key=FAKE_SERVER_KEY)
session.add(config)
await session.flush()
providers = list(config.saml_identity_providers or [])
providers = [p for p in providers if p.get("id") != "test-saml"]
providers.append(_saml_provider_config(saml_metadata))
config.saml_identity_providers = providers
session.add(config)
await session.commit()
asyncio.get_event_loop().run_until_complete(_seed())
env = os.environ.copy()
env["WG_LOG_TO_FILE"] = "false"
env["WG_PORT"] = str(SAML_APP_PORT)
env["WG_EXTERNAL_URL"] = SAML_APP_BASE
env.pop("PYTEST_CURRENT_TEST", None)
env.pop("NICEGUI_SCREEN_TEST_PORT", None)
proc = subprocess.Popen(
["uv", "run", "python", "-m", "wiregui.main"],
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
for _ in range(30):
try:
r = httpx.get(f"{SAML_APP_BASE}/api/health", timeout=1)
if r.status_code == 200:
break
except Exception:
pass
time.sleep(1)
else:
proc.kill()
out = proc.stdout.read().decode() if proc.stdout else ""
pytest.fail(f"App did not start in time. Output:\n{out}")
yield proc
proc.terminate()
proc.wait(timeout=10)
# Clean up seeded provider and test user
async def _cleanup():
await _cleanup_user_by_email(SAML_TEST_EMAIL)
async with async_session() as session:
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
if config:
config.saml_identity_providers = [
p for p in (config.saml_identity_providers or []) if p.get("id") != "test-saml"
]
session.add(config)
await session.commit()
asyncio.get_event_loop().run_until_complete(_cleanup())
async def test_saml_button_visible_on_login(app_with_saml, page: Page):
"""Login page shows SAML provider button."""
await page.goto(f"{SAML_APP_BASE}/login")
await page.wait_for_load_state("networkidle")
await expect(page.get_by_text("Sign in with Mock SAML")).to_be_visible(timeout=10_000)
async def test_saml_redirect_to_idp(app_with_saml, page: Page):
"""Clicking SAML login redirects to the SimpleSAMLphp IdP login page."""
await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
# Should redirect to the SimpleSAMLphp SSO service
await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
async def test_saml_sp_metadata_endpoint(app_with_saml, page: Page):
"""SP metadata endpoint returns valid XML."""
response = await page.request.get(f"{SAML_APP_BASE}/auth/saml/test-saml/metadata")
assert response.status == 200
body = await response.text()
assert "EntityDescriptor" in body
assert "AssertionConsumerService" in body
async def test_full_saml_login_flow(app_with_saml, page: Page):
"""Full SAML SSO flow: app → IdP login → callback → authenticated."""
await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
# SimpleSAMLphp login form
await page.locator("input[name='username']").fill("user1")
await page.locator("input[name='password']").fill("password")
await page.locator("button[type='submit'], input[type='submit']").first.click()
# Should redirect back to the app after SAML response
await page.wait_for_url(f"{SAML_APP_BASE}/**", timeout=15_000)
await page.wait_for_load_state("networkidle")
await page.wait_for_timeout(3000)
assert "/login" not in page.url, f"SAML login failed — still on login page: {page.url}"
# Verify user was auto-created in DB
async with async_session() as session:
result = await session.execute(select(User).where(User.email == SAML_TEST_EMAIL))
user = result.scalar_one_or_none()
assert user is not None, f"Expected user {SAML_TEST_EMAIL} to be auto-created"
assert user.last_signed_in_method == "saml:test-saml"

263
tests/test_api_deps.py Normal file
View file

@ -0,0 +1,263 @@
"""Tests for API dependency injection — Bearer token auth and admin guard."""
import hashlib
from datetime import timedelta
from uuid import uuid4
import pytest
from unittest.mock import AsyncMock, MagicMock
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password
from wiregui.db import async_session
from wiregui.models.api_token import ApiToken
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# ========== resolve_bearer_token ==========
async def test_resolve_valid_token():
"""Valid, non-expired token resolves to user."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() + timedelta(hours=1),
)
session.add(api_token)
await session.commit()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
assert resolved.email == "api-test@test.com"
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
async def test_resolve_expired_token():
"""Expired token returns None."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() - timedelta(hours=1), # already expired
)
session.add(api_token)
await session.commit()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
async def test_resolve_invalid_token():
"""Nonexistent token returns None."""
from wiregui.auth.api_token import resolve_bearer_token
async with async_session() as session:
resolved = await resolve_bearer_token(session, "totally-bogus-token")
assert resolved is None
async def test_resolve_token_disabled_user():
"""Token for disabled user returns None."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(
email="api-disabled@test.com", password_hash=hash_password("x"),
role="admin", disabled_at=utcnow(),
)
session.add(user)
await session.commit()
await session.refresh(user)
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() + timedelta(hours=1),
)
session.add(api_token)
await session.commit()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
async def test_resolve_token_no_expiry():
"""Token without expires_at (never expires) resolves successfully."""
from wiregui.auth.api_token import resolve_bearer_token
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=None,
)
session.add(api_token)
await session.commit()
try:
async with async_session() as session:
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
# ========== get_current_api_user (via FastAPI deps) ==========
async def test_get_current_api_user_missing_header():
"""Missing Authorization header raises 401."""
from fastapi import HTTPException
from wiregui.api.deps import get_current_api_user
request = MagicMock()
request.headers = {}
with pytest.raises(HTTPException) as exc_info:
await get_current_api_user(request, session=AsyncMock())
assert exc_info.value.status_code == 401
assert "Missing" in exc_info.value.detail
async def test_get_current_api_user_bad_scheme():
"""Non-Bearer auth scheme raises 401."""
from fastapi import HTTPException
from wiregui.api.deps import get_current_api_user
request = MagicMock()
request.headers = {"Authorization": "Basic dXNlcjpwYXNz"}
with pytest.raises(HTTPException) as exc_info:
await get_current_api_user(request, session=AsyncMock())
assert exc_info.value.status_code == 401
async def test_get_current_api_user_invalid_token():
"""Valid Bearer scheme but bogus token raises 401."""
from fastapi import HTTPException
from wiregui.api.deps import get_current_api_user
request = MagicMock()
request.headers = {"Authorization": "Bearer bogus-token-value"}
async with async_session() as session:
with pytest.raises(HTTPException) as exc_info:
await get_current_api_user(request, session=session)
assert exc_info.value.status_code == 401
assert "Invalid" in exc_info.value.detail
async def test_get_current_api_user_valid_token():
"""Valid Bearer token resolves to user."""
from wiregui.api.deps import get_current_api_user
plaintext, token_hash = generate_api_token()
async with async_session() as session:
user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.commit()
await session.refresh(user)
api_token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() + timedelta(hours=1),
)
session.add(api_token)
await session.commit()
try:
request = MagicMock()
request.headers = {"Authorization": f"Bearer {plaintext}"}
async with async_session() as session:
resolved = await get_current_api_user(request, session=session)
assert resolved.id == user.id
finally:
async with async_session() as session:
await session.delete(await session.get(ApiToken, api_token.id))
await session.delete(await session.get(User, user.id))
await session.commit()
# ========== require_admin ==========
async def test_require_admin_allows_admin():
"""Admin user passes require_admin."""
from wiregui.api.deps import require_admin
admin_user = MagicMock(spec=User)
admin_user.role = "admin"
result = await require_admin(user=admin_user)
assert result == admin_user
async def test_require_admin_rejects_unprivileged():
"""Non-admin user gets 403."""
from fastapi import HTTPException
from wiregui.api.deps import require_admin
regular_user = MagicMock(spec=User)
regular_user.role = "unprivileged"
with pytest.raises(HTTPException) as exc_info:
await require_admin(user=regular_user)
assert exc_info.value.status_code == 403
assert "Admin" in exc_info.value.detail

View file

@ -0,0 +1,206 @@
"""Extended firewall tests — _nft/_nft_batch error handling, add_device_jump_rule edge cases, policies."""
from unittest.mock import AsyncMock, patch
import pytest
from wiregui.services.firewall import (
_nft,
_nft_batch,
add_device_jump_rule,
setup_base_tables,
setup_masquerade,
apply_peer_to_peer_policy,
apply_lan_to_peers_policy,
get_ruleset,
)
# ========== _nft error handling ==========
@patch("asyncio.create_subprocess_exec")
async def test_nft_raises_on_failure(mock_exec):
"""_nft raises RuntimeError on non-zero exit code."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"", b"nft: error message")
mock_proc.returncode = 1
mock_exec.return_value = mock_proc
with pytest.raises(RuntimeError, match="nft.*failed"):
await _nft("list ruleset")
@patch("asyncio.create_subprocess_exec")
async def test_nft_returns_stdout_on_success(mock_exec):
"""_nft returns stdout on success."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"table inet wiregui {}", b"")
mock_proc.returncode = 0
mock_exec.return_value = mock_proc
result = await _nft("list ruleset")
assert "wiregui" in result
# ========== _nft_batch error handling ==========
@patch("asyncio.create_subprocess_exec")
async def test_nft_batch_raises_on_failure(mock_exec):
"""_nft_batch raises RuntimeError on non-zero exit code."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"", b"Error: syntax error")
mock_proc.returncode = 1
mock_exec.return_value = mock_proc
with pytest.raises(RuntimeError, match="nft batch failed"):
await _nft_batch(["add table inet wiregui"])
@patch("asyncio.create_subprocess_exec")
async def test_nft_batch_sends_commands_via_stdin(mock_exec):
"""_nft_batch sends all commands via stdin to nft -f -."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"", b"")
mock_proc.returncode = 0
mock_exec.return_value = mock_proc
cmds = ["add table inet wiregui", "add chain inet wiregui test"]
await _nft_batch(cmds)
mock_exec.assert_awaited_once()
# Verify nft -f - was called
call_args = mock_exec.call_args[0]
assert call_args == ("nft", "-f", "-")
# Verify stdin data
stdin_data = mock_proc.communicate.call_args[0][0]
assert b"add table inet wiregui" in stdin_data
assert b"add chain inet wiregui test" in stdin_data
# ========== add_device_jump_rule edge cases ==========
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_ipv4_only(mock_batch):
"""Only IPv4 — generates single IPv4 jump rule."""
await add_device_jump_rule("user-id-1", "10.0.0.5", None)
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert len(cmds) == 1
assert "ip saddr 10.0.0.5" in cmds[0]
assert "jump" in cmds[0]
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_ipv6_only(mock_batch):
"""Only IPv6 — generates single IPv6 jump rule."""
await add_device_jump_rule("user-id-2", None, "fd00::5")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert len(cmds) == 1
assert "ip6 saddr fd00::5" in cmds[0]
assert "jump" in cmds[0]
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_no_ips(mock_batch):
"""Neither IPv4 nor IPv6 — no nft commands issued."""
await add_device_jump_rule("user-id-3", None, None)
mock_batch.assert_not_awaited()
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_both_ips(mock_batch):
"""Both IPv4 and IPv6 — generates two jump rules."""
await add_device_jump_rule("user-id-4", "10.0.0.7", "fd00::7")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert len(cmds) == 2
assert any("ip saddr 10.0.0.7" in c for c in cmds)
assert any("ip6 saddr fd00::7" in c for c in cmds)
# ========== setup_base_tables — already exists ==========
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_base_tables_already_exists(mock_batch):
"""If table already exists (File exists error), don't raise."""
mock_batch.side_effect = RuntimeError("File exists")
await setup_base_tables() # should not raise
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_base_tables_other_error_raises(mock_batch):
"""Other nft errors should propagate."""
mock_batch.side_effect = RuntimeError("Permission denied")
with pytest.raises(RuntimeError, match="Permission denied"):
await setup_base_tables()
# ========== setup_masquerade — error handling ==========
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_masquerade_error_swallowed(mock_batch):
"""Masquerade errors are logged but not raised."""
mock_batch.side_effect = RuntimeError("nft error")
await setup_masquerade(iface="wg0") # should not raise
# ========== policy functions — command verification ==========
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_peer_to_peer_enabled(mock_batch):
"""Enabling peer-to-peer generates accept rules."""
await apply_peer_to_peer_policy(True)
cmds = mock_batch.call_args[0][0]
assert any("accept" in c for c in cmds)
assert any("peer_to_peer" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_peer_to_peer_disabled(mock_batch):
"""Disabling peer-to-peer generates drop rules."""
await apply_peer_to_peer_policy(False)
cmds = mock_batch.call_args[0][0]
assert any("drop" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_lan_to_peers_enabled(mock_batch):
"""Enabling LAN-to-peers generates accept rules."""
await apply_lan_to_peers_policy(True)
cmds = mock_batch.call_args[0][0]
assert any("accept" in c for c in cmds)
assert any("lan_to_peers" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_lan_to_peers_disabled(mock_batch):
"""Disabling LAN-to-peers generates drop rules."""
await apply_lan_to_peers_policy(False)
cmds = mock_batch.call_args[0][0]
assert any("drop" in c for c in cmds)
# ========== get_ruleset — error handling ==========
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
async def test_get_ruleset_returns_output(mock_nft):
"""get_ruleset returns nft list ruleset output."""
mock_nft.return_value = "table inet wiregui { ... }"
result = await get_ruleset()
assert "wiregui" in result
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
async def test_get_ruleset_returns_fallback_on_error(mock_nft):
"""get_ruleset returns friendly message when nft not available."""
mock_nft.side_effect = RuntimeError("nft not found")
result = await get_ruleset()
assert "not available" in result

View file

@ -0,0 +1,114 @@
"""Tests for WireGuard service — ensure_interface, set_private_key, set_listen_port, configure_interface."""
from unittest.mock import AsyncMock, patch, call
from wiregui.services.wireguard import (
ensure_interface,
set_private_key,
set_listen_port,
configure_interface,
)
# ========== ensure_interface ==========
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_ensure_interface_already_exists(mock_run):
"""If interface exists (ip link show succeeds), do nothing."""
mock_run.return_value = ""
await ensure_interface(iface="wg-test")
# Only called once for ip link show
mock_run.assert_awaited_once_with(["ip", "link", "show", "wg-test"])
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_ensure_interface_creates_new(mock_run):
"""If interface doesn't exist, create it, assign IPs, bring up."""
call_count = 0
async def side_effect(args, input_data=None):
nonlocal call_count
call_count += 1
if call_count == 1 and args == ["ip", "link", "show", "wg-test"]:
raise RuntimeError("Device not found")
return ""
mock_run.side_effect = side_effect
await ensure_interface(iface="wg-test")
# Should have called: ip link show (fails), ip link add, ip addr add x2, ip link set up
assert mock_run.await_count == 5
calls = [c[0][0] for c in mock_run.call_args_list]
assert calls[1] == ["ip", "link", "add", "wg-test", "type", "wireguard"]
assert calls[2][0:3] == ["ip", "address", "add"]
assert calls[3][0:3] == ["ip", "address", "add"]
assert calls[4] == ["ip", "link", "set", "wg-test", "up"]
# ========== set_private_key ==========
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_set_private_key(mock_run):
"""set_private_key calls wg set with private-key path."""
mock_run.return_value = ""
await set_private_key("/tmp/test.key", iface="wg-test")
mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "private-key", "/tmp/test.key"])
# ========== set_listen_port ==========
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_set_listen_port(mock_run):
"""set_listen_port calls wg set with listen-port."""
mock_run.return_value = ""
await set_listen_port(51820, iface="wg-test")
mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "listen-port", "51820"])
# ========== configure_interface ==========
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
@patch("wiregui.db.async_session")
async def test_configure_interface_no_config(mock_session_cls, mock_run):
"""If no Configuration row exists, do not call wg set."""
from unittest.mock import MagicMock
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await configure_interface(iface="wg-test")
mock_run.assert_not_awaited()
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
@patch("wiregui.db.async_session")
async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run):
"""With valid config, writes key to temp file and calls wg set."""
from unittest.mock import MagicMock
mock_config = MagicMock()
mock_config.server_private_key = "test-private-key-value"
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_config
mock_session.execute.return_value = mock_result
mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
mock_run.return_value = ""
await configure_interface(iface="wg-test")
mock_run.assert_awaited_once()
args = mock_run.call_args[0][0]
assert args[0:3] == ["wg", "set", "wg-test"]
assert "private-key" in args
assert "listen-port" in args

View file

@ -17,7 +17,7 @@ def _build_saml_settings(provider_config: dict) -> dict:
idp_settings = idp_data.get("idp", {}) idp_settings = idp_data.get("idp", {})
return { return {
"strict": True, "strict": provider_config.get("strict", True),
"debug": False, "debug": False,
"sp": { "sp": {
"entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata", "entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata",

View file

@ -6,6 +6,7 @@ from loguru import logger
from nicegui import app, ui from nicegui import app, ui
from sqlmodel import select from sqlmodel import select
from wiregui.config import get_settings
from wiregui.db import async_session from wiregui.db import async_session
from wiregui.models.configuration import Configuration from wiregui.models.configuration import Configuration
from wiregui.pages.layout import layout from wiregui.pages.layout import layout

View file

@ -101,14 +101,13 @@ async def saml_callback(provider_id: str, request: Request):
session.add(user) session.add(user)
await session.commit() await session.commit()
request.session["authenticated"] = True # Store auth data in Starlette session — picked up by /auth/complete
request.session["user_id"] = str(user.id) request.session["oidc_user_id"] = str(user.id)
request.session["email"] = user.email request.session["oidc_email"] = user.email
request.session["role"] = user.role request.session["oidc_role"] = user.role
request.session["theme_preference"] = user.theme_preference
logger.info("SAML login: {} via {}", email, provider_id) logger.info("SAML login: {} via {}", email, provider_id)
return RedirectResponse(url="/", status_code=303) return RedirectResponse(url="/auth/complete", status_code=303)
except Exception as e: except Exception as e:
logger.error("SAML callback failed for {}: {}", provider_id, e) logger.error("SAML callback failed for {}: {}", provider_id, e)

View file

@ -1,4 +1,4 @@
"""Login page — email/password, MFA redirect, OIDC provider buttons.""" """Login page — email/password, MFA redirect, OIDC/SAML provider buttons."""
from nicegui import app, ui from nicegui import app, ui
from sqlmodel import select from sqlmodel import select
@ -6,6 +6,7 @@ from sqlmodel import select
from wiregui.auth.oidc import load_providers from wiregui.auth.oidc import load_providers
from wiregui.auth.session import authenticate_user from wiregui.auth.session import authenticate_user
from wiregui.db import async_session from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.models.mfa_method import MFAMethod from wiregui.models.mfa_method import MFAMethod
from wiregui.pages.style import apply_style from wiregui.pages.style import apply_style
from wiregui.utils.time import utcnow from wiregui.utils.time import utcnow
@ -18,9 +19,13 @@ async def login_page():
apply_style() apply_style()
# Load OIDC providers for SSO buttons # Load SSO providers for login buttons
oidc_providers = await load_providers() oidc_providers = await load_providers()
async with async_session() as session:
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
saml_providers = config.saml_identity_providers if config else []
async def try_login(): async def try_login():
user = await authenticate_user(email.value, password.value) user = await authenticate_user(email.value, password.value)
if user is None: if user is None:
@ -76,8 +81,8 @@ async def login_page():
password.on("keydown.enter", try_login) password.on("keydown.enter", try_login)
# OIDC provider buttons # SSO provider buttons
if oidc_providers: if oidc_providers or saml_providers:
ui.separator().classes("q-my-md") ui.separator().classes("q-my-md")
ui.label("Or sign in with").classes("text-caption text-center w-full") ui.label("Or sign in with").classes("text-caption text-center w-full")
for provider in oidc_providers: for provider in oidc_providers:
@ -87,3 +92,10 @@ async def login_page():
label, label,
on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/oidc/{p}'"), on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/oidc/{p}'"),
).props("color=primary unelevated").classes("w-full q-mt-xs") ).props("color=primary unelevated").classes("w-full q-mt-xs")
for provider in saml_providers:
pid = provider.get("id", "")
label = provider.get("label", pid)
ui.button(
label,
on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/saml/{p}'"),
).props("color=primary unelevated").classes("w-full q-mt-xs")