diff --git a/.forgejo/workflows/dev.yml b/.forgejo/workflows/dev.yml index 5a1a023..86a030f 100644 --- a/.forgejo/workflows/dev.yml +++ b/.forgejo/workflows/dev.yml @@ -34,11 +34,18 @@ jobs: env: SERVER_PORT: "9000" JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}' + mock-saml: + image: kenchan0130/simplesamlphp + env: + SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata + SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback + SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/ env: CI: "true" WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui WG_REDIS_URL: redis://valkey:6379/0 MOCK_OIDC_HOST: mock-oidc + MOCK_SAML_HOST: mock-saml steps: - name: Install system dependencies and checkout run: | diff --git a/.forgejo/workflows/release.yml b/.forgejo/workflows/release.yml index 187075f..263e1ae 100644 --- a/.forgejo/workflows/release.yml +++ b/.forgejo/workflows/release.yml @@ -35,11 +35,18 @@ jobs: env: SERVER_PORT: "9000" JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}' + mock-saml: + image: kenchan0130/simplesamlphp + env: + SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata + SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback + SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/ env: CI: "true" WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui WG_REDIS_URL: redis://valkey:6379/0 MOCK_OIDC_HOST: mock-oidc + MOCK_SAML_HOST: mock-saml steps: - name: Install system dependencies and checkout run: | diff --git a/TODO.md b/TODO.md index 340c382..e23431f 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ # WireGUI — Pending Items -**Test count: 174 (164 unit + 10 E2E) | Coverage: ~35%** +**Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)** --- @@ -11,7 +11,7 @@ Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI. Source: `/home/stefanob/PycharmProjects/personal/wirezone` -**Test count: 199 (164 unit + 35 E2E) | Coverage: 35%** +**Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)** **Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright) @@ -23,11 +23,11 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone` ### Testing (partially done) - [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking) -- [ ] `wiregui/api/deps.py` — test get_current_api_user with real Bearer header parsing, require_admin rejection -- [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port -- [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6 +- [x] `wiregui/api/deps.py` (11 tests) — resolve_bearer_token (valid/expired/invalid/disabled/no-expiry), get_current_api_user (missing header/bad scheme/invalid token/valid token), require_admin (admin/unprivileged) +- [x] `wiregui/services/wireguard.py` (6 tests) — ensure_interface (exists/creates new), set_private_key, set_listen_port, configure_interface (no config/sets key+port) +- [x] `wiregui/services/firewall.py` (17 tests) — _nft error/success, _nft_batch error/stdin, add_device_jump_rule (ipv4-only/ipv6-only/no-ips/both), setup_base_tables error handling, masquerade error, peer-to-peer/lan-to-peers policies, get_ruleset fallback - [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error -- [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing +- [x] `wiregui/auth/saml.py` — full SAML flow tested via mock SimpleSAMLphp IdP (e2e) - [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data - [ ] E2E tests for admin pages (users, devices, rules, settings) @@ -37,46 +37,52 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone` - [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion - [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard - [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc +- [x] `tests/e2e/test_mfa_login.py` (4 tests) — MFA redirect on login, valid TOTP completes login, invalid code error, cancel returns to login +- [x] `tests/e2e/test_magic_link_page.py` (4 tests) — page renders, success on submit, empty email error, back to login +- [x] `tests/e2e/test_admin_devices.py` (7 tests) — list all devices, filter by user, create with defaults, create with overrides, edit name/description, delete, config dialog with QR +- [x] `tests/e2e/test_admin_rules.py` (7 tests) — list rules table, create accept/drop/global rules, edit action/destination, delete rule (all verified in DB) +- [x] `tests/e2e/test_admin_settings.py` (9 tests) — client defaults save/reload, security toggles (local auth, VPN session, unprivileged), OIDC add/delete, SAML add/delete (all verified in DB) +- [x] `tests/e2e/test_saml_login.py` (4 tests) — SAML button visible, redirect to IdP, SP metadata endpoint, full SAML login flow via mock SimpleSAMLphp **E2E tests still needed:** `tests/e2e/test_login.py` — Login & Auth flows (remaining): -- [ ] Login with MFA → redirects to /mfa challenge page -- [ ] MFA challenge: valid TOTP code → completes login -- [ ] MFA challenge: invalid code → shows error, stays on /mfa -- [ ] MFA challenge: cancel → returns to /login -- [ ] Magic link request page renders, shows success on submit +- [x] Login with MFA → redirects to /mfa challenge page +- [x] MFA challenge: valid TOTP code → completes login +- [x] MFA challenge: invalid code → shows error, stays on /mfa +- [x] MFA challenge: cancel → returns to /login +- [x] Magic link request page renders, shows success on submit `tests/e2e/test_admin_devices.py` — Admin Device Management: -- [ ] List all devices across users -- [ ] Filter by user → shows only that user's devices -- [ ] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs) -- [ ] Create device with defaults → use_default flags all True -- [ ] Edit device name and description → persists -- [ ] Edit device config overrides (toggle use_default off, set custom values) -- [ ] Delete device → removed from table -- [ ] Config dialog shows valid WireGuard config with real server public key -- [ ] QR code renders in config dialog +- [x] List all devices across users +- [x] Filter by user → shows only that user's devices +- [x] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs) +- [x] Create device with defaults → use_default flags all True +- [x] Edit device name and description → persists +- [x] Edit device config overrides (toggle use_default off, set custom values) +- [x] Delete device → removed from table +- [x] Config dialog shows valid WireGuard config with real server public key +- [x] QR code renders in config dialog `tests/e2e/test_admin_rules.py` — Admin Firewall Rules: -- [ ] List rules → table shows action, destination, protocol, port, user -- [ ] Create accept rule with CIDR → appears in table -- [ ] Create drop rule with TCP port range → appears correctly -- [ ] Create global rule (no user) → shows "Global" -- [ ] Edit rule action (accept → drop) → persists -- [ ] Edit rule destination → persists -- [ ] Delete rule → removed from table +- [x] List rules → table shows action, destination, protocol, port, user +- [x] Create accept rule with CIDR → appears in table +- [x] Create drop rule with TCP port range → appears correctly +- [x] Create global rule (no user) → shows "Global" +- [x] Edit rule action (accept → drop) → persists +- [x] Edit rule destination → persists +- [x] Delete rule → removed from table `tests/e2e/test_admin_settings.py` — Admin Settings: -- [ ] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB -- [ ] Client defaults: saved values reflected on page reload -- [ ] Security: toggle local auth → persists -- [ ] Security: change VPN session duration → persists -- [ ] Security: toggle unprivileged device management/configuration → persists -- [ ] OIDC: add provider → appears in table -- [ ] OIDC: delete provider → removed from table -- [ ] SAML: add provider → appears in table -- [ ] SAML: delete provider → removed from table +- [x] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB +- [x] Client defaults: saved values reflected on page reload +- [x] Security: toggle local auth → persists +- [x] Security: change VPN session duration → persists +- [x] Security: toggle unprivileged device management/configuration → persists +- [x] OIDC: add provider → appears in table +- [x] OIDC: delete provider → removed from table +- [x] SAML: add provider → appears in table +- [x] SAML: delete provider → removed from table `tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics: - [ ] Page renders WireGuard interface status diff --git a/compose.yml b/compose.yml index a26f327..30dd691 100644 --- a/compose.yml +++ b/compose.yml @@ -49,6 +49,21 @@ services: ] } + # Test SAML Identity Provider — SimpleSAMLphp as IdP + # IdP Metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php + # Admin UI: http://localhost:8080/simplesaml (admin / secret) + # Test users: user1/password, user2/password + mock-saml: + image: kenchan0130/simplesamlphp + ports: + - "8080:8080" + environment: + SIMPLESAMLPHP_SP_ENTITY_ID: "http://localhost:13000/auth/saml/test-saml/metadata" + SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: "http://localhost:13000/auth/saml/test-saml/callback" + SIMPLESAMLPHP_IDP_BASE_URL: http://localhost:8080/simplesaml/ + volumes: + - ./docker/mock-saml/saml20-sp-remote.php:/var/www/simplesamlphp/metadata/saml20-sp-remote.php:ro + volumes: postgres_data: valkey_data: diff --git a/docker/mock-saml/saml20-sp-remote.php b/docker/mock-saml/saml20-sp-remote.php new file mode 100644 index 0000000..099f8a2 --- /dev/null +++ b/docker/mock-saml/saml20-sp-remote.php @@ -0,0 +1,15 @@ + 'http://localhost:13000/auth/saml/test-saml/callback', +]; + +// E2E test instance +$metadata['http://localhost:13003/auth/saml/test-saml/metadata'] = [ + 'AssertionConsumerService' => 'http://localhost:13003/auth/saml/test-saml/callback', +]; \ No newline at end of file diff --git a/tests/e2e/test_admin_devices.py b/tests/e2e/test_admin_devices.py new file mode 100644 index 0000000..b44a262 --- /dev/null +++ b/tests/e2e/test_admin_devices.py @@ -0,0 +1,239 @@ +"""E2E tests for admin device management page.""" + +import pytest_asyncio +from playwright.async_api import Page, expect +from sqlmodel import select + +from wiregui.auth.passwords import hash_password +from wiregui.db import async_session +from wiregui.models.device import Device +from wiregui.models.user import User +from wiregui.utils.crypto import generate_keypair, generate_preshared_key +from tests.e2e.conftest import ( + TEST_APP_BASE, + TEST_EMAIL, + TEST_PASSWORD, + _cleanup_user_by_email, + login, +) + +SECOND_USER_EMAIL = "e2e-device-user2@example.com" + + +@pytest_asyncio.fixture +async def second_user(test_user): + """Create a second user with a device for filtering tests.""" + await _cleanup_user_by_email(SECOND_USER_EMAIL) + + async with async_session() as session: + user = User( + email=SECOND_USER_EMAIL, + password_hash=hash_password("pass12345"), + role="unprivileged", + ) + session.add(user) + await session.commit() + await session.refresh(user) + + yield user + + await _cleanup_user_by_email(SECOND_USER_EMAIL) + + +@pytest_asyncio.fixture +async def devices_for_both_users(test_user, second_user): + """Create one device per user for table/filter tests.""" + _, pub1 = generate_keypair() + _, pub2 = generate_keypair() + psk1 = generate_preshared_key() + psk2 = generate_preshared_key() + + async with async_session() as session: + d1 = Device( + name="admin-laptop", + public_key=pub1, + preshared_key=psk1, + ipv4="10.0.0.10", + user_id=test_user.id, + ) + d2 = Device( + name="user2-phone", + public_key=pub2, + preshared_key=psk2, + ipv4="10.0.0.11", + user_id=second_user.id, + ) + session.add_all([d1, d2]) + await session.commit() + + yield d1, d2 + + # Cleanup handled by user fixture cascade + + +async def _go_to_admin_devices(page: Page): + """Login as admin and navigate to admin devices page.""" + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + await page.goto(f"{TEST_APP_BASE}/admin/devices") + await expect(page.locator("role=main").get_by_text("All Devices")).to_be_visible(timeout=10_000) + + +async def test_list_all_devices(page: Page, devices_for_both_users): + """Admin devices page lists devices from all users.""" + await _go_to_admin_devices(page) + await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000) + + +async def test_filter_by_user(page: Page, second_user, devices_for_both_users): + """Filtering by user shows only that user's devices.""" + await _go_to_admin_devices(page) + await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000) + + # Filter to second user + await page.locator("label:has-text('Filter by User')").click() + await page.get_by_role("option", name=SECOND_USER_EMAIL).click() + await page.wait_for_timeout(1000) + + await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("admin-laptop")).not_to_be_visible() + + # Filter back to all + await page.locator("label:has-text('Filter by User')").click() + await page.get_by_role("option", name="All Users").click() + await page.wait_for_timeout(1000) + + await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000) + + +async def test_create_device_with_defaults(page: Page, test_user): + """Create device with all defaults — config dialog appears.""" + await _go_to_admin_devices(page) + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("default-test-device") + await page.get_by_role("button", name="Create").click() + + # Config dialog should appear with WireGuard config + await expect(page.get_by_text("Config for default-test-device")).to_be_visible(timeout=10_000) + await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000) + await page.get_by_role("button", name="Close").click() + await page.wait_for_timeout(500) + + # Device should be in the table + await expect(page.get_by_role("cell", name="default-test-device").first).to_be_visible(timeout=5_000) + + +async def test_create_device_with_overrides(page: Page, test_user): + """Create device with custom config overrides.""" + await _go_to_admin_devices(page) + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("custom-override-dev") + await page.locator("input[aria-label='Description (optional)']").fill("Custom overrides test") + + # Toggle off DNS default and set custom — Quasar switches use .q-toggle + await page.locator(".q-toggle", has_text="Use default DNS").click() + dns_input = page.locator("input[aria-label='DNS Servers']") + await dns_input.clear() + await dns_input.fill("8.8.8.8, 8.8.4.4") + + # Toggle off MTU default and set custom + await page.locator(".q-toggle", has_text="Use default MTU").click() + mtu_input = page.locator("input[aria-label='MTU']") + await mtu_input.clear() + await mtu_input.fill("1400") + + await page.get_by_role("button", name="Create").click() + + await expect(page.get_by_text("Config for custom-override-dev")).to_be_visible(timeout=10_000) + await page.get_by_role("button", name="Close").click() + await page.wait_for_timeout(500) + + await expect(page.get_by_role("cell", name="custom-override-dev").first).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + result = await session.execute( + select(Device).where(Device.name == "custom-override-dev") + .order_by(Device.inserted_at.desc()).limit(1) + ) + device = result.scalar_one() + assert device.use_default_dns is False + assert "8.8.8.8" in device.dns + assert device.use_default_mtu is False + assert device.mtu == 1400 + + +async def test_edit_device_name_and_description(page: Page, devices_for_both_users): + """Edit a device name and description via the edit dialog.""" + await _go_to_admin_devices(page) + await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000) + + # Click edit button on admin-laptop row — Quasar slot buttons with icon + row = page.locator("tr", has_text="admin-laptop") + await row.locator(".q-btn").first.click() + + await expect(page.get_by_text("Edit Device")).to_be_visible(timeout=5_000) + + name_input = page.locator(".q-dialog input[aria-label='Device Name']") + await name_input.clear() + await name_input.fill("admin-laptop-renamed") + + desc_input = page.locator(".q-dialog input[aria-label='Description']") + await desc_input.clear() + await desc_input.fill("Updated description") + + await page.get_by_role("button", name="Save").click() + await expect(page.get_by_text("Device updated")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("admin-laptop-renamed")).to_be_visible(timeout=5_000) + + +async def test_delete_device(page: Page, test_user): + """Delete a device — removed from table.""" + _, pub = generate_keypair() + async with async_session() as session: + d = Device( + name="delete-me-device", + public_key=pub, + preshared_key=generate_preshared_key(), + ipv4="10.0.0.99", + user_id=test_user.id, + ) + session.add(d) + await session.commit() + + await _go_to_admin_devices(page) + await expect(page.get_by_role("cell", name="delete-me-device")).to_be_visible(timeout=5_000) + + # Click the delete (second) button in the row + row = page.locator("tr", has_text="delete-me-device") + await row.locator(".q-btn").nth(1).click() + + await expect(page.get_by_text("Deleted delete-me-device")).to_be_visible(timeout=5_000) + await page.wait_for_timeout(1000) + await expect(page.get_by_role("cell", name="delete-me-device")).not_to_be_visible() + + +async def test_config_dialog_shows_wg_config(page: Page, test_user): + """Config dialog after device creation shows valid WireGuard config.""" + await _go_to_admin_devices(page) + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("config-test-device") + await page.get_by_role("button", name="Create").click() + + await expect(page.get_by_text("Config for config-test-device")).to_be_visible(timeout=10_000) + await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("[Peer]")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("PrivateKey")).to_be_visible() + await expect(page.get_by_role("button", name="Download .conf")).to_be_visible() + + # QR code should be rendered + await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000) \ No newline at end of file diff --git a/tests/e2e/test_admin_rules.py b/tests/e2e/test_admin_rules.py new file mode 100644 index 0000000..6aa5b4e --- /dev/null +++ b/tests/e2e/test_admin_rules.py @@ -0,0 +1,227 @@ +"""E2E tests for admin firewall rules management page.""" + +from uuid import UUID + +import pytest_asyncio +from playwright.async_api import Page, expect +from sqlmodel import select + +from wiregui.db import async_session +from wiregui.models.rule import Rule +from wiregui.models.user import User +from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, login + + +async def _cleanup_test_rules(): + """Remove rules created by tests (identified by test-specific destinations).""" + async with async_session() as session: + result = await session.execute( + select(Rule).where(Rule.destination.in_([ + "10.99.0.0/16", "10.88.0.0/16", "10.77.0.0/16", + "10.66.0.0/16", "10.55.0.0/16", + ])) + ) + for rule in result.scalars().all(): + await session.delete(rule) + await session.commit() + + +@pytest_asyncio.fixture(autouse=True) +async def clean_rules(app_server): + """Clean up test rules before and after each test.""" + await _cleanup_test_rules() + yield + await _cleanup_test_rules() + + +async def _go_to_rules(page: Page): + """Login and navigate to admin rules page.""" + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + await page.goto(f"{TEST_APP_BASE}/admin/rules") + await expect(page.locator("role=main").get_by_text("Firewall Rules")).to_be_visible(timeout=10_000) + + +async def _create_rule_via_dialog( + page: Page, *, action: str = "accept", destination: str = "10.99.0.0/16", + protocol: str = "any", port_range: str = "", user: str = "global", +): + """Open create dialog and fill in a rule.""" + await page.get_by_role("button", name="Add Rule").click() + await expect(page.get_by_text("New Firewall Rule")).to_be_visible(timeout=5_000) + + # Action select + if action != "accept": + await page.locator(".q-dialog label:has-text('Action')").click() + await page.get_by_role("option", name=action).click() + + # Destination + await page.locator(".q-dialog input[aria-label='Destination (CIDR)']").fill(destination) + + # Protocol + if protocol != "any": + await page.locator(".q-dialog label:has-text('Protocol')").click() + await page.get_by_role("option", name=protocol).click() + + # Port range + if port_range: + await page.locator(".q-dialog input[aria-label='Port Range']").fill(port_range) + + # User + if user != "global": + await page.locator(".q-dialog label:has-text('Applies to')").click() + await page.get_by_role("option", name=user).click() + + await page.get_by_role("button", name="Create").click() + await page.wait_for_timeout(500) + + +async def test_list_rules_table(page: Page, test_user: User): + """Rules page renders table with correct columns.""" + # Seed a rule in DB + async with async_session() as session: + rule = Rule(action="accept", destination="10.99.0.0/16", port_type="tcp", + port_range="443", user_id=test_user.id) + session.add(rule) + await session.commit() + + await _go_to_rules(page) + + await expect(page.get_by_role("cell", name="accept")).to_be_visible(timeout=5_000) + await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible() + await expect(page.get_by_role("cell", name="tcp")).to_be_visible() + await expect(page.get_by_role("cell", name="443")).to_be_visible() + await expect(page.get_by_role("cell", name=TEST_EMAIL)).to_be_visible() + + +async def test_create_accept_rule_with_cidr(page: Page, test_user: User): + """Create an accept rule with CIDR — verify in table and DB.""" + await _go_to_rules(page) + await _create_rule_via_dialog(page, action="accept", destination="10.88.0.0/16") + + await expect(page.get_by_role("cell", name="10.88.0.0/16")).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + result = await session.execute(select(Rule).where(Rule.destination == "10.88.0.0/16")) + rule = result.scalar_one() + assert rule.action == "accept" + assert rule.port_type is None + assert rule.port_range is None + assert rule.user_id is None + + +async def test_create_drop_rule_with_tcp_port_range(page: Page, test_user: User): + """Create a drop rule with TCP port range — verify in table and DB.""" + await _go_to_rules(page) + await _create_rule_via_dialog( + page, action="drop", destination="10.77.0.0/16", + protocol="tcp", port_range="80-443", + ) + + await expect(page.get_by_role("cell", name="10.77.0.0/16")).to_be_visible(timeout=5_000) + await expect(page.get_by_role("cell", name="drop").first).to_be_visible() + + # Verify in DB + async with async_session() as session: + result = await session.execute(select(Rule).where(Rule.destination == "10.77.0.0/16")) + rule = result.scalar_one() + assert rule.action == "drop" + assert rule.port_type == "tcp" + assert rule.port_range == "80-443" + + +async def test_create_global_rule(page: Page, test_user: User): + """Create a global rule (no user) — shows 'Global' in table and DB has null user_id.""" + await _go_to_rules(page) + await _create_rule_via_dialog(page, destination="10.66.0.0/16", user="global") + + await expect(page.get_by_role("cell", name="10.66.0.0/16")).to_be_visible(timeout=5_000) + await expect(page.get_by_role("cell", name="Global")).to_be_visible() + + # Verify in DB + async with async_session() as session: + result = await session.execute(select(Rule).where(Rule.destination == "10.66.0.0/16")) + rule = result.scalar_one() + assert rule.user_id is None + + +async def test_edit_rule_action(page: Page, test_user: User): + """Edit rule action from accept to drop — verify in table and DB.""" + async with async_session() as session: + rule = Rule(action="accept", destination="10.55.0.0/16") + session.add(rule) + await session.commit() + rule_id = rule.id + + await _go_to_rules(page) + await expect(page.get_by_role("cell", name="10.55.0.0/16")).to_be_visible(timeout=5_000) + + # Click edit (first button in the row) + row = page.locator("tr", has_text="10.55.0.0/16") + await row.locator(".q-btn").first.click() + await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000) + + # Change action to drop + await page.locator(".q-dialog label:has-text('Action')").click() + await page.get_by_role("option", name="drop").click() + + await page.get_by_role("button", name="Save").click() + await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + rule = await session.get(Rule, rule_id) + assert rule.action == "drop" + + +async def test_edit_rule_destination(page: Page, test_user: User): + """Edit rule destination — verify in table and DB.""" + async with async_session() as session: + rule = Rule(action="accept", destination="10.99.0.0/16") + session.add(rule) + await session.commit() + rule_id = rule.id + + await _go_to_rules(page) + await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000) + + row = page.locator("tr", has_text="10.99.0.0/16") + await row.locator(".q-btn").first.click() + await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000) + + dest_input = page.locator(".q-dialog input[aria-label='Destination (CIDR)']") + await dest_input.clear() + await dest_input.fill("10.88.0.0/16") + + await page.get_by_role("button", name="Save").click() + await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + rule = await session.get(Rule, rule_id) + assert rule.destination == "10.88.0.0/16" + + +async def test_delete_rule(page: Page, test_user: User): + """Delete a rule — removed from table and DB.""" + async with async_session() as session: + rule = Rule(action="accept", destination="10.99.0.0/16") + session.add(rule) + await session.commit() + rule_id = rule.id + + await _go_to_rules(page) + await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000) + + # Click delete (second button in the row) + row = page.locator("tr", has_text="10.99.0.0/16") + await row.locator(".q-btn").nth(1).click() + await page.wait_for_timeout(1000) + + await expect(page.get_by_role("cell", name="10.99.0.0/16")).not_to_be_visible() + + # Verify in DB + async with async_session() as session: + rule = await session.get(Rule, rule_id) + assert rule is None \ No newline at end of file diff --git a/tests/e2e/test_admin_settings.py b/tests/e2e/test_admin_settings.py new file mode 100644 index 0000000..bae28e6 --- /dev/null +++ b/tests/e2e/test_admin_settings.py @@ -0,0 +1,281 @@ +"""E2E tests for admin settings page — client defaults, security, OIDC/SAML providers.""" + +import pytest_asyncio +from playwright.async_api import Page, expect +from sqlmodel import select + +from wiregui.db import async_session +from wiregui.models.configuration import Configuration +from wiregui.models.user import User +from tests.e2e.conftest import TEST_APP_BASE, login + + +@pytest_asyncio.fixture(autouse=True) +async def reset_config(app_server): + """Snapshot config before test, restore after.""" + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if not c: + yield + return + snap = { + "default_client_endpoint": c.default_client_endpoint, + "default_client_dns": list(c.default_client_dns), + "default_client_mtu": c.default_client_mtu, + "default_client_persistent_keepalive": c.default_client_persistent_keepalive, + "default_client_allowed_ips": list(c.default_client_allowed_ips), + "vpn_session_duration": c.vpn_session_duration, + "local_auth_enabled": c.local_auth_enabled, + "allow_unprivileged_device_management": c.allow_unprivileged_device_management, + "allow_unprivileged_device_configuration": c.allow_unprivileged_device_configuration, + "openid_connect_providers": list(c.openid_connect_providers or []), + "saml_identity_providers": list(c.saml_identity_providers or []), + } + cid = c.id + + yield + + async with async_session() as session: + c = await session.get(Configuration, cid) + if c: + for k, v in snap.items(): + setattr(c, k, v) + session.add(c) + await session.commit() + + +async def _go_to_settings(page: Page): + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + await page.goto(f"{TEST_APP_BASE}/admin/settings") + await expect(page.get_by_text("Default Client Configuration")).to_be_visible(timeout=10_000) + + +# --- Client Defaults --- + + +async def test_save_client_defaults(page: Page, test_user: User): + """Save endpoint, DNS, MTU, keepalive, allowed IPs — verify persists in DB.""" + await _go_to_settings(page) + + endpoint = page.locator("input[aria-label='Endpoint']") + await endpoint.clear() + await endpoint.fill("vpn.test.local") + + dns = page.locator("input[aria-label='DNS Servers']") + await dns.clear() + await dns.fill("9.9.9.9, 149.112.112.112") + + mtu = page.locator("input[aria-label='MTU']") + await mtu.clear() + await mtu.fill("1420") + + keepalive = page.locator("input[aria-label='Persistent Keepalive']") + await keepalive.clear() + await keepalive.fill("30") + + allowed = page.locator("input[aria-label='Allowed IPs']") + await allowed.clear() + await allowed.fill("10.0.0.0/8, 192.168.0.0/16") + + await page.get_by_role("button", name="Save Defaults").click() + await expect(page.get_by_text("Client defaults saved")).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert c.default_client_endpoint == "vpn.test.local" + assert c.default_client_dns == ["9.9.9.9", "149.112.112.112"] + assert c.default_client_mtu == 1420 + assert c.default_client_persistent_keepalive == 30 + assert c.default_client_allowed_ips == ["10.0.0.0/8", "192.168.0.0/16"] + + +async def test_client_defaults_persist_on_reload(page: Page, test_user: User): + """Saved defaults are reflected after page reload.""" + # Set values via DB + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + c.default_client_endpoint = "reload-test.vpn" + c.default_client_dns = ["8.8.8.8"] + c.default_client_mtu = 1500 + c.default_client_persistent_keepalive = 15 + c.default_client_allowed_ips = ["172.16.0.0/12"] + session.add(c) + await session.commit() + + await _go_to_settings(page) + + await expect(page.locator("input[aria-label='Endpoint']")).to_have_value("reload-test.vpn") + await expect(page.locator("input[aria-label='DNS Servers']")).to_have_value("8.8.8.8") + await expect(page.locator("input[aria-label='MTU']")).to_have_value("1500") + await expect(page.locator("input[aria-label='Persistent Keepalive']")).to_have_value("15") + await expect(page.locator("input[aria-label='Allowed IPs']")).to_have_value("172.16.0.0/12") + + +# --- Security --- + + +async def test_save_security_local_auth_toggle(page: Page, test_user: User): + """Toggle local auth off — verify in DB.""" + await _go_to_settings(page) + + # Find the local auth switch and toggle it off + switch = page.locator(".q-toggle", has_text="Local Authentication") + await switch.click() + + await page.get_by_role("button", name="Save Security Settings").click() + await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000) + + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert c.local_auth_enabled is False + + +async def test_save_vpn_session_duration(page: Page, test_user: User): + """Change VPN session duration — verify in DB.""" + await _go_to_settings(page) + + await page.locator("label:has-text('VPN Session Duration')").click() + await page.get_by_role("option", name="Every Day").click() + + await page.get_by_role("button", name="Save Security Settings").click() + await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000) + + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert c.vpn_session_duration == 86400 + + +async def test_save_unprivileged_toggles(page: Page, test_user: User): + """Toggle unprivileged device management/configuration — verify in DB.""" + await _go_to_settings(page) + + await page.locator(".q-toggle", has_text="Allow Unprivileged Device Management").click() + await page.locator(".q-toggle", has_text="Allow Unprivileged Device Configuration").click() + + await page.get_by_role("button", name="Save Security Settings").click() + await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000) + + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + # Toggled from default (True) to False + assert c.allow_unprivileged_device_management is False + assert c.allow_unprivileged_device_configuration is False + + +# --- OIDC Providers --- + + +async def test_add_oidc_provider(page: Page, test_user: User): + """Add an OIDC provider — appears in table and DB.""" + await _go_to_settings(page) + + await page.get_by_role("button", name="Add OIDC Provider").click() + await expect(page.get_by_text("OIDC Provider", exact=True)).to_be_visible(timeout=5_000) + + await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-oidc") + await page.locator(".q-dialog input[aria-label='Label']").fill("E2E Test IdP") + await page.locator(".q-dialog input[aria-label='Client ID']").fill("test-client-id") + await page.locator(".q-dialog input[aria-label='Client Secret']").fill("test-client-secret") + await page.locator(".q-dialog input[aria-label='Discovery Document URI']").fill("https://idp.test/.well-known/openid-configuration") + + await page.locator(".q-dialog").get_by_role("button", name="Save").click() + await expect(page.get_by_text("OIDC provider 'E2E Test IdP' saved")).to_be_visible(timeout=5_000) + + await expect(page.get_by_role("cell", name="e2e-test-oidc")).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + provider = next((p for p in c.openid_connect_providers if p["id"] == "e2e-test-oidc"), None) + assert provider is not None + assert provider["label"] == "E2E Test IdP" + assert provider["client_id"] == "test-client-id" + + +async def test_delete_oidc_provider(page: Page, test_user: User): + """Delete an OIDC provider — removed from table and DB.""" + # Seed a provider + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + providers = list(c.openid_connect_providers or []) + providers.append({ + "id": "delete-me-oidc", "label": "Delete Me", "scope": "openid", + "client_id": "x", "client_secret": "x", + "discovery_document_uri": "https://x/.well-known/openid-configuration", + }) + c.openid_connect_providers = providers + session.add(c) + await session.commit() + + await _go_to_settings(page) + await expect(page.get_by_role("cell", name="delete-me-oidc")).to_be_visible(timeout=5_000) + + row = page.locator("tr", has_text="delete-me-oidc") + await row.locator(".q-btn").first.click() + + await expect(page.get_by_text("OIDC provider deleted")).to_be_visible(timeout=5_000) + await page.wait_for_timeout(500) + await expect(page.get_by_role("cell", name="delete-me-oidc")).not_to_be_visible() + + # Verify in DB + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert not any(p["id"] == "delete-me-oidc" for p in c.openid_connect_providers) + + +# --- SAML Providers --- + + +async def test_add_saml_provider(page: Page, test_user: User): + """Add a SAML provider — appears in table and DB.""" + await _go_to_settings(page) + + await page.get_by_role("button", name="Add SAML Provider").click() + await expect(page.get_by_text("SAML Identity Provider", exact=True)).to_be_visible(timeout=5_000) + + await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-saml") + await page.locator(".q-dialog input[aria-label='Label']").fill("E2E SAML IdP") + await page.locator(".q-dialog textarea").fill("test") + + await page.locator(".q-dialog").get_by_role("button", name="Save").click() + await expect(page.get_by_text("SAML provider 'E2E SAML IdP' saved")).to_be_visible(timeout=5_000) + + await expect(page.get_by_role("cell", name="e2e-test-saml")).to_be_visible(timeout=5_000) + + # Verify in DB + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + provider = next((p for p in c.saml_identity_providers if p["id"] == "e2e-test-saml"), None) + assert provider is not None + assert provider["label"] == "E2E SAML IdP" + + +async def test_delete_saml_provider(page: Page, test_user: User): + """Delete a SAML provider — removed from table and DB.""" + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + providers = list(c.saml_identity_providers or []) + providers.append({ + "id": "delete-me-saml", "label": "Delete Me SAML", + "metadata": "", + }) + c.saml_identity_providers = providers + session.add(c) + await session.commit() + + await _go_to_settings(page) + await expect(page.get_by_role("cell", name="delete-me-saml")).to_be_visible(timeout=5_000) + + row = page.locator("tr", has_text="delete-me-saml") + await row.locator(".q-btn").first.click() + + await expect(page.get_by_text("SAML provider deleted")).to_be_visible(timeout=5_000) + await page.wait_for_timeout(500) + await expect(page.get_by_role("cell", name="delete-me-saml")).not_to_be_visible() + + # Verify in DB + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one() + assert not any(p["id"] == "delete-me-saml" for p in c.saml_identity_providers) \ No newline at end of file diff --git a/tests/e2e/test_magic_link_page.py b/tests/e2e/test_magic_link_page.py new file mode 100644 index 0000000..c4676ec --- /dev/null +++ b/tests/e2e/test_magic_link_page.py @@ -0,0 +1,41 @@ +"""E2E tests for magic link request page.""" + +from playwright.async_api import Page, expect + +from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL +from wiregui.models.user import User + + +async def test_magic_link_page_renders(page: Page, test_user: User): + """Magic link request page renders with email input and submit button.""" + await page.goto(f"{TEST_APP_BASE}/auth/magic-link") + await page.wait_for_load_state("networkidle") + await expect(page.get_by_text("Sign in with magic link")).to_be_visible(timeout=10_000) + await expect(page.locator("input[aria-label='Email']")).to_be_visible() + await expect(page.get_by_role("button", name="Send Magic Link")).to_be_visible() + await expect(page.get_by_role("button", name="Back to login")).to_be_visible() + + +async def test_magic_link_shows_success_on_submit(page: Page, test_user: User): + """Submitting an email shows success message (regardless of whether account exists).""" + await page.goto(f"{TEST_APP_BASE}/auth/magic-link") + await page.wait_for_load_state("networkidle") + await page.locator("input[aria-label='Email']").fill(TEST_EMAIL) + await page.get_by_role("button", name="Send Magic Link").click() + await expect(page.get_by_text("a sign-in link has been sent")).to_be_visible(timeout=5_000) + + +async def test_magic_link_empty_email_shows_error(page: Page, test_user: User): + """Submitting without email shows error.""" + await page.goto(f"{TEST_APP_BASE}/auth/magic-link") + await page.wait_for_load_state("networkidle") + await page.get_by_role("button", name="Send Magic Link").click() + await expect(page.get_by_text("Enter your email")).to_be_visible(timeout=5_000) + + +async def test_magic_link_back_to_login(page: Page, test_user: User): + """Back to login button navigates to login page.""" + await page.goto(f"{TEST_APP_BASE}/auth/magic-link") + await page.wait_for_load_state("networkidle") + await page.get_by_role("button", name="Back to login").click() + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000) \ No newline at end of file diff --git a/tests/e2e/test_mfa_login.py b/tests/e2e/test_mfa_login.py new file mode 100644 index 0000000..0de2e99 --- /dev/null +++ b/tests/e2e/test_mfa_login.py @@ -0,0 +1,111 @@ +"""E2E tests for MFA login flow — login with TOTP redirects to /mfa challenge page.""" + +import pyotp +import pytest_asyncio +from playwright.async_api import Page, expect + +from wiregui.auth.mfa import generate_totp_secret +from wiregui.auth.passwords import hash_password +from wiregui.db import async_session +from wiregui.models.mfa_method import MFAMethod +from wiregui.models.user import User +from tests.e2e.conftest import ( + FAKE_SERVER_KEY, + TEST_APP_BASE, + TEST_PASSWORD, + _cleanup_user_by_email, +) + +MFA_EMAIL = "e2e-mfa@example.com" +MFA_PASSWORD = "mfapass123" +TOTP_SECRET = generate_totp_secret() + + +@pytest_asyncio.fixture +async def mfa_user(app_server): + """Create a user with a TOTP MFA method, clean up after.""" + await _cleanup_user_by_email(MFA_EMAIL) + + async with async_session() as session: + from sqlmodel import select + from wiregui.models.configuration import Configuration + + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if config: + if not config.server_public_key: + config.server_public_key = FAKE_SERVER_KEY + session.add(config) + else: + config = Configuration(server_public_key=FAKE_SERVER_KEY) + session.add(config) + + user = User( + email=MFA_EMAIL, + password_hash=hash_password(MFA_PASSWORD), + role="admin", + ) + session.add(user) + await session.commit() + await session.refresh(user) + + mfa = MFAMethod( + name="Test TOTP", + type="totp", + payload={"secret": TOTP_SECRET}, + user_id=user.id, + ) + session.add(mfa) + await session.commit() + + yield user + + await _cleanup_user_by_email(MFA_EMAIL) + + +async def _login_mfa_user(page: Page): + """Fill login form for the MFA user and submit.""" + await page.goto(f"{TEST_APP_BASE}/login") + await page.wait_for_load_state("networkidle") + await page.locator("input[aria-label='Email']").fill(MFA_EMAIL) + await page.locator("input[aria-label='Password']").fill(MFA_PASSWORD) + await page.get_by_role("button", name="Sign in", exact=True).click() + + +async def test_mfa_login_redirects_to_challenge(page: Page, mfa_user: User): + """Login with MFA-enabled user redirects to /mfa challenge page.""" + await _login_mfa_user(page) + await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000) + await expect(page.locator("input[aria-label='Authentication Code']")).to_be_visible() + + +async def test_mfa_valid_totp_completes_login(page: Page, mfa_user: User): + """Entering a valid TOTP code on /mfa completes login.""" + await _login_mfa_user(page) + await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000) + + code = pyotp.TOTP(TOTP_SECRET).now() + await page.locator("input[aria-label='Authentication Code']").fill(code) + await page.get_by_role("button", name="Verify").click() + + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + + +async def test_mfa_invalid_code_shows_error(page: Page, mfa_user: User): + """Entering an invalid TOTP code shows error and stays on /mfa.""" + await _login_mfa_user(page) + await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000) + + await page.locator("input[aria-label='Authentication Code']").fill("000000") + await page.get_by_role("button", name="Verify").click() + + await expect(page.get_by_text("Invalid code")).to_be_visible(timeout=5_000) + await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible() + + +async def test_mfa_cancel_returns_to_login(page: Page, mfa_user: User): + """Clicking Cancel on /mfa clears session and returns to login.""" + await _login_mfa_user(page) + await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000) + + await page.get_by_role("button", name="Cancel").click() + await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000) \ No newline at end of file diff --git a/tests/e2e/test_saml_login.py b/tests/e2e/test_saml_login.py new file mode 100644 index 0000000..2942750 --- /dev/null +++ b/tests/e2e/test_saml_login.py @@ -0,0 +1,177 @@ +"""E2E tests for SAML authentication — mock SimpleSAMLphp IdP. + +Requires mock-saml service running (docker compose up -d mock-saml). +IdP metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php +Test users: user1/user1pass, user2/user2pass +""" + +import os +import subprocess +import time + +import httpx +import pytest +import pytest_asyncio +from playwright.async_api import Page, expect +from sqlmodel import select + +from wiregui.db import async_session +from wiregui.models.configuration import Configuration +from wiregui.models.user import User +from tests.e2e.conftest import FAKE_SERVER_KEY, _cleanup_user_by_email + +MOCK_SAML_HOST = os.environ.get("MOCK_SAML_HOST", "localhost") +MOCK_SAML_METADATA_URL = f"http://{MOCK_SAML_HOST}:8080/simplesaml/saml2/idp/metadata.php" + +# Separate app port for SAML tests (like OIDC IdP tests) +SAML_APP_PORT = 13003 +SAML_APP_BASE = f"http://localhost:{SAML_APP_PORT}" + +SAML_TEST_EMAIL = "user1@example.com" + + +def _fetch_idp_metadata() -> str: + """Fetch IdP metadata XML from the mock SAML server.""" + try: + r = httpx.get(MOCK_SAML_METADATA_URL, timeout=5) + r.raise_for_status() + return r.text + except Exception: + pytest.skip(f"Mock SAML IdP not available at {MOCK_SAML_METADATA_URL}") + + +def _saml_provider_config(metadata: str) -> dict: + return { + "id": "test-saml", + "label": "Sign in with Mock SAML", + "metadata": metadata, + "sign_requests": False, + "sign_metadata": False, + "signed_assertion_in_resp": False, + "signed_envelopes_in_resp": False, + "auto_create_users": True, + "strict": False, # Relaxed for test IdP with expired certs + } + + +@pytest_asyncio.fixture(scope="module") +async def saml_metadata(): + return _fetch_idp_metadata() + + +@pytest.fixture(scope="module") +def app_with_saml(saml_metadata): + """Start a WireGUI instance with a SAML provider seeded in the DB.""" + import asyncio + + # Seed the SAML provider config into the database + async def _seed(): + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if config is None: + config = Configuration(server_public_key=FAKE_SERVER_KEY) + session.add(config) + await session.flush() + + providers = list(config.saml_identity_providers or []) + providers = [p for p in providers if p.get("id") != "test-saml"] + providers.append(_saml_provider_config(saml_metadata)) + config.saml_identity_providers = providers + session.add(config) + await session.commit() + + asyncio.get_event_loop().run_until_complete(_seed()) + + env = os.environ.copy() + env["WG_LOG_TO_FILE"] = "false" + env["WG_PORT"] = str(SAML_APP_PORT) + env["WG_EXTERNAL_URL"] = SAML_APP_BASE + env.pop("PYTEST_CURRENT_TEST", None) + env.pop("NICEGUI_SCREEN_TEST_PORT", None) + + proc = subprocess.Popen( + ["uv", "run", "python", "-m", "wiregui.main"], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + + for _ in range(30): + try: + r = httpx.get(f"{SAML_APP_BASE}/api/health", timeout=1) + if r.status_code == 200: + break + except Exception: + pass + time.sleep(1) + else: + proc.kill() + out = proc.stdout.read().decode() if proc.stdout else "" + pytest.fail(f"App did not start in time. Output:\n{out}") + + yield proc + + proc.terminate() + proc.wait(timeout=10) + + # Clean up seeded provider and test user + async def _cleanup(): + await _cleanup_user_by_email(SAML_TEST_EMAIL) + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if config: + config.saml_identity_providers = [ + p for p in (config.saml_identity_providers or []) if p.get("id") != "test-saml" + ] + session.add(config) + await session.commit() + + asyncio.get_event_loop().run_until_complete(_cleanup()) + + +async def test_saml_button_visible_on_login(app_with_saml, page: Page): + """Login page shows SAML provider button.""" + await page.goto(f"{SAML_APP_BASE}/login") + await page.wait_for_load_state("networkidle") + await expect(page.get_by_text("Sign in with Mock SAML")).to_be_visible(timeout=10_000) + + +async def test_saml_redirect_to_idp(app_with_saml, page: Page): + """Clicking SAML login redirects to the SimpleSAMLphp IdP login page.""" + await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml") + # Should redirect to the SimpleSAMLphp SSO service + await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000) + + +async def test_saml_sp_metadata_endpoint(app_with_saml, page: Page): + """SP metadata endpoint returns valid XML.""" + response = await page.request.get(f"{SAML_APP_BASE}/auth/saml/test-saml/metadata") + assert response.status == 200 + body = await response.text() + assert "EntityDescriptor" in body + assert "AssertionConsumerService" in body + + +async def test_full_saml_login_flow(app_with_saml, page: Page): + """Full SAML SSO flow: app → IdP login → callback → authenticated.""" + await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml") + await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000) + + # SimpleSAMLphp login form + await page.locator("input[name='username']").fill("user1") + await page.locator("input[name='password']").fill("password") + await page.locator("button[type='submit'], input[type='submit']").first.click() + + # Should redirect back to the app after SAML response + await page.wait_for_url(f"{SAML_APP_BASE}/**", timeout=15_000) + await page.wait_for_load_state("networkidle") + await page.wait_for_timeout(3000) + + assert "/login" not in page.url, f"SAML login failed — still on login page: {page.url}" + + # Verify user was auto-created in DB + async with async_session() as session: + result = await session.execute(select(User).where(User.email == SAML_TEST_EMAIL)) + user = result.scalar_one_or_none() + assert user is not None, f"Expected user {SAML_TEST_EMAIL} to be auto-created" + assert user.last_signed_in_method == "saml:test-saml" \ No newline at end of file diff --git a/tests/test_api_deps.py b/tests/test_api_deps.py new file mode 100644 index 0000000..64d8a32 --- /dev/null +++ b/tests/test_api_deps.py @@ -0,0 +1,263 @@ +"""Tests for API dependency injection — Bearer token auth and admin guard.""" + +import hashlib +from datetime import timedelta +from uuid import uuid4 + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from wiregui.auth.api_token import generate_api_token +from wiregui.auth.passwords import hash_password +from wiregui.db import async_session +from wiregui.models.api_token import ApiToken +from wiregui.models.user import User +from wiregui.utils.time import utcnow + + +# ========== resolve_bearer_token ========== + + +async def test_resolve_valid_token(): + """Valid, non-expired token resolves to user.""" + from wiregui.auth.api_token import resolve_bearer_token + + plaintext, token_hash = generate_api_token() + + async with async_session() as session: + user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.commit() + await session.refresh(user) + + api_token = ApiToken( + token_hash=token_hash, + user_id=user.id, + expires_at=utcnow() + timedelta(hours=1), + ) + session.add(api_token) + await session.commit() + + try: + async with async_session() as session: + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is not None + assert resolved.id == user.id + assert resolved.email == "api-test@test.com" + finally: + async with async_session() as session: + await session.delete(await session.get(ApiToken, api_token.id)) + await session.delete(await session.get(User, user.id)) + await session.commit() + + +async def test_resolve_expired_token(): + """Expired token returns None.""" + from wiregui.auth.api_token import resolve_bearer_token + + plaintext, token_hash = generate_api_token() + + async with async_session() as session: + user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.commit() + await session.refresh(user) + + api_token = ApiToken( + token_hash=token_hash, + user_id=user.id, + expires_at=utcnow() - timedelta(hours=1), # already expired + ) + session.add(api_token) + await session.commit() + + try: + async with async_session() as session: + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is None + finally: + async with async_session() as session: + await session.delete(await session.get(ApiToken, api_token.id)) + await session.delete(await session.get(User, user.id)) + await session.commit() + + +async def test_resolve_invalid_token(): + """Nonexistent token returns None.""" + from wiregui.auth.api_token import resolve_bearer_token + + async with async_session() as session: + resolved = await resolve_bearer_token(session, "totally-bogus-token") + assert resolved is None + + +async def test_resolve_token_disabled_user(): + """Token for disabled user returns None.""" + from wiregui.auth.api_token import resolve_bearer_token + + plaintext, token_hash = generate_api_token() + + async with async_session() as session: + user = User( + email="api-disabled@test.com", password_hash=hash_password("x"), + role="admin", disabled_at=utcnow(), + ) + session.add(user) + await session.commit() + await session.refresh(user) + + api_token = ApiToken( + token_hash=token_hash, + user_id=user.id, + expires_at=utcnow() + timedelta(hours=1), + ) + session.add(api_token) + await session.commit() + + try: + async with async_session() as session: + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is None + finally: + async with async_session() as session: + await session.delete(await session.get(ApiToken, api_token.id)) + await session.delete(await session.get(User, user.id)) + await session.commit() + + +async def test_resolve_token_no_expiry(): + """Token without expires_at (never expires) resolves successfully.""" + from wiregui.auth.api_token import resolve_bearer_token + + plaintext, token_hash = generate_api_token() + + async with async_session() as session: + user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.commit() + await session.refresh(user) + + api_token = ApiToken( + token_hash=token_hash, + user_id=user.id, + expires_at=None, + ) + session.add(api_token) + await session.commit() + + try: + async with async_session() as session: + resolved = await resolve_bearer_token(session, plaintext) + assert resolved is not None + assert resolved.id == user.id + finally: + async with async_session() as session: + await session.delete(await session.get(ApiToken, api_token.id)) + await session.delete(await session.get(User, user.id)) + await session.commit() + + +# ========== get_current_api_user (via FastAPI deps) ========== + + +async def test_get_current_api_user_missing_header(): + """Missing Authorization header raises 401.""" + from fastapi import HTTPException + from wiregui.api.deps import get_current_api_user + + request = MagicMock() + request.headers = {} + + with pytest.raises(HTTPException) as exc_info: + await get_current_api_user(request, session=AsyncMock()) + assert exc_info.value.status_code == 401 + assert "Missing" in exc_info.value.detail + + +async def test_get_current_api_user_bad_scheme(): + """Non-Bearer auth scheme raises 401.""" + from fastapi import HTTPException + from wiregui.api.deps import get_current_api_user + + request = MagicMock() + request.headers = {"Authorization": "Basic dXNlcjpwYXNz"} + + with pytest.raises(HTTPException) as exc_info: + await get_current_api_user(request, session=AsyncMock()) + assert exc_info.value.status_code == 401 + + +async def test_get_current_api_user_invalid_token(): + """Valid Bearer scheme but bogus token raises 401.""" + from fastapi import HTTPException + from wiregui.api.deps import get_current_api_user + + request = MagicMock() + request.headers = {"Authorization": "Bearer bogus-token-value"} + + async with async_session() as session: + with pytest.raises(HTTPException) as exc_info: + await get_current_api_user(request, session=session) + assert exc_info.value.status_code == 401 + assert "Invalid" in exc_info.value.detail + + +async def test_get_current_api_user_valid_token(): + """Valid Bearer token resolves to user.""" + from wiregui.api.deps import get_current_api_user + + plaintext, token_hash = generate_api_token() + + async with async_session() as session: + user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin") + session.add(user) + await session.commit() + await session.refresh(user) + + api_token = ApiToken( + token_hash=token_hash, + user_id=user.id, + expires_at=utcnow() + timedelta(hours=1), + ) + session.add(api_token) + await session.commit() + + try: + request = MagicMock() + request.headers = {"Authorization": f"Bearer {plaintext}"} + + async with async_session() as session: + resolved = await get_current_api_user(request, session=session) + assert resolved.id == user.id + finally: + async with async_session() as session: + await session.delete(await session.get(ApiToken, api_token.id)) + await session.delete(await session.get(User, user.id)) + await session.commit() + + +# ========== require_admin ========== + + +async def test_require_admin_allows_admin(): + """Admin user passes require_admin.""" + from wiregui.api.deps import require_admin + + admin_user = MagicMock(spec=User) + admin_user.role = "admin" + result = await require_admin(user=admin_user) + assert result == admin_user + + +async def test_require_admin_rejects_unprivileged(): + """Non-admin user gets 403.""" + from fastapi import HTTPException + from wiregui.api.deps import require_admin + + regular_user = MagicMock(spec=User) + regular_user.role = "unprivileged" + + with pytest.raises(HTTPException) as exc_info: + await require_admin(user=regular_user) + assert exc_info.value.status_code == 403 + assert "Admin" in exc_info.value.detail \ No newline at end of file diff --git a/tests/test_firewall_extended.py b/tests/test_firewall_extended.py new file mode 100644 index 0000000..08a8df3 --- /dev/null +++ b/tests/test_firewall_extended.py @@ -0,0 +1,206 @@ +"""Extended firewall tests — _nft/_nft_batch error handling, add_device_jump_rule edge cases, policies.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from wiregui.services.firewall import ( + _nft, + _nft_batch, + add_device_jump_rule, + setup_base_tables, + setup_masquerade, + apply_peer_to_peer_policy, + apply_lan_to_peers_policy, + get_ruleset, +) + + +# ========== _nft error handling ========== + + +@patch("asyncio.create_subprocess_exec") +async def test_nft_raises_on_failure(mock_exec): + """_nft raises RuntimeError on non-zero exit code.""" + mock_proc = AsyncMock() + mock_proc.communicate.return_value = (b"", b"nft: error message") + mock_proc.returncode = 1 + mock_exec.return_value = mock_proc + + with pytest.raises(RuntimeError, match="nft.*failed"): + await _nft("list ruleset") + + +@patch("asyncio.create_subprocess_exec") +async def test_nft_returns_stdout_on_success(mock_exec): + """_nft returns stdout on success.""" + mock_proc = AsyncMock() + mock_proc.communicate.return_value = (b"table inet wiregui {}", b"") + mock_proc.returncode = 0 + mock_exec.return_value = mock_proc + + result = await _nft("list ruleset") + assert "wiregui" in result + + +# ========== _nft_batch error handling ========== + + +@patch("asyncio.create_subprocess_exec") +async def test_nft_batch_raises_on_failure(mock_exec): + """_nft_batch raises RuntimeError on non-zero exit code.""" + mock_proc = AsyncMock() + mock_proc.communicate.return_value = (b"", b"Error: syntax error") + mock_proc.returncode = 1 + mock_exec.return_value = mock_proc + + with pytest.raises(RuntimeError, match="nft batch failed"): + await _nft_batch(["add table inet wiregui"]) + + +@patch("asyncio.create_subprocess_exec") +async def test_nft_batch_sends_commands_via_stdin(mock_exec): + """_nft_batch sends all commands via stdin to nft -f -.""" + mock_proc = AsyncMock() + mock_proc.communicate.return_value = (b"", b"") + mock_proc.returncode = 0 + mock_exec.return_value = mock_proc + + cmds = ["add table inet wiregui", "add chain inet wiregui test"] + await _nft_batch(cmds) + + mock_exec.assert_awaited_once() + # Verify nft -f - was called + call_args = mock_exec.call_args[0] + assert call_args == ("nft", "-f", "-") + # Verify stdin data + stdin_data = mock_proc.communicate.call_args[0][0] + assert b"add table inet wiregui" in stdin_data + assert b"add chain inet wiregui test" in stdin_data + + +# ========== add_device_jump_rule edge cases ========== + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_add_device_jump_rule_ipv4_only(mock_batch): + """Only IPv4 — generates single IPv4 jump rule.""" + await add_device_jump_rule("user-id-1", "10.0.0.5", None) + mock_batch.assert_awaited_once() + cmds = mock_batch.call_args[0][0] + assert len(cmds) == 1 + assert "ip saddr 10.0.0.5" in cmds[0] + assert "jump" in cmds[0] + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_add_device_jump_rule_ipv6_only(mock_batch): + """Only IPv6 — generates single IPv6 jump rule.""" + await add_device_jump_rule("user-id-2", None, "fd00::5") + mock_batch.assert_awaited_once() + cmds = mock_batch.call_args[0][0] + assert len(cmds) == 1 + assert "ip6 saddr fd00::5" in cmds[0] + assert "jump" in cmds[0] + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_add_device_jump_rule_no_ips(mock_batch): + """Neither IPv4 nor IPv6 — no nft commands issued.""" + await add_device_jump_rule("user-id-3", None, None) + mock_batch.assert_not_awaited() + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_add_device_jump_rule_both_ips(mock_batch): + """Both IPv4 and IPv6 — generates two jump rules.""" + await add_device_jump_rule("user-id-4", "10.0.0.7", "fd00::7") + mock_batch.assert_awaited_once() + cmds = mock_batch.call_args[0][0] + assert len(cmds) == 2 + assert any("ip saddr 10.0.0.7" in c for c in cmds) + assert any("ip6 saddr fd00::7" in c for c in cmds) + + +# ========== setup_base_tables — already exists ========== + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_setup_base_tables_already_exists(mock_batch): + """If table already exists (File exists error), don't raise.""" + mock_batch.side_effect = RuntimeError("File exists") + await setup_base_tables() # should not raise + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_setup_base_tables_other_error_raises(mock_batch): + """Other nft errors should propagate.""" + mock_batch.side_effect = RuntimeError("Permission denied") + with pytest.raises(RuntimeError, match="Permission denied"): + await setup_base_tables() + + +# ========== setup_masquerade — error handling ========== + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_setup_masquerade_error_swallowed(mock_batch): + """Masquerade errors are logged but not raised.""" + mock_batch.side_effect = RuntimeError("nft error") + await setup_masquerade(iface="wg0") # should not raise + + +# ========== policy functions — command verification ========== + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_peer_to_peer_enabled(mock_batch): + """Enabling peer-to-peer generates accept rules.""" + await apply_peer_to_peer_policy(True) + cmds = mock_batch.call_args[0][0] + assert any("accept" in c for c in cmds) + assert any("peer_to_peer" in c for c in cmds) + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_peer_to_peer_disabled(mock_batch): + """Disabling peer-to-peer generates drop rules.""" + await apply_peer_to_peer_policy(False) + cmds = mock_batch.call_args[0][0] + assert any("drop" in c for c in cmds) + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_lan_to_peers_enabled(mock_batch): + """Enabling LAN-to-peers generates accept rules.""" + await apply_lan_to_peers_policy(True) + cmds = mock_batch.call_args[0][0] + assert any("accept" in c for c in cmds) + assert any("lan_to_peers" in c for c in cmds) + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +async def test_lan_to_peers_disabled(mock_batch): + """Disabling LAN-to-peers generates drop rules.""" + await apply_lan_to_peers_policy(False) + cmds = mock_batch.call_args[0][0] + assert any("drop" in c for c in cmds) + + +# ========== get_ruleset — error handling ========== + + +@patch("wiregui.services.firewall._nft", new_callable=AsyncMock) +async def test_get_ruleset_returns_output(mock_nft): + """get_ruleset returns nft list ruleset output.""" + mock_nft.return_value = "table inet wiregui { ... }" + result = await get_ruleset() + assert "wiregui" in result + + +@patch("wiregui.services.firewall._nft", new_callable=AsyncMock) +async def test_get_ruleset_returns_fallback_on_error(mock_nft): + """get_ruleset returns friendly message when nft not available.""" + mock_nft.side_effect = RuntimeError("nft not found") + result = await get_ruleset() + assert "not available" in result \ No newline at end of file diff --git a/tests/test_wireguard_extended.py b/tests/test_wireguard_extended.py new file mode 100644 index 0000000..ab848df --- /dev/null +++ b/tests/test_wireguard_extended.py @@ -0,0 +1,114 @@ +"""Tests for WireGuard service — ensure_interface, set_private_key, set_listen_port, configure_interface.""" + +from unittest.mock import AsyncMock, patch, call + +from wiregui.services.wireguard import ( + ensure_interface, + set_private_key, + set_listen_port, + configure_interface, +) + + +# ========== ensure_interface ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_ensure_interface_already_exists(mock_run): + """If interface exists (ip link show succeeds), do nothing.""" + mock_run.return_value = "" + await ensure_interface(iface="wg-test") + # Only called once for ip link show + mock_run.assert_awaited_once_with(["ip", "link", "show", "wg-test"]) + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_ensure_interface_creates_new(mock_run): + """If interface doesn't exist, create it, assign IPs, bring up.""" + call_count = 0 + + async def side_effect(args, input_data=None): + nonlocal call_count + call_count += 1 + if call_count == 1 and args == ["ip", "link", "show", "wg-test"]: + raise RuntimeError("Device not found") + return "" + + mock_run.side_effect = side_effect + await ensure_interface(iface="wg-test") + + # Should have called: ip link show (fails), ip link add, ip addr add x2, ip link set up + assert mock_run.await_count == 5 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[1] == ["ip", "link", "add", "wg-test", "type", "wireguard"] + assert calls[2][0:3] == ["ip", "address", "add"] + assert calls[3][0:3] == ["ip", "address", "add"] + assert calls[4] == ["ip", "link", "set", "wg-test", "up"] + + +# ========== set_private_key ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_set_private_key(mock_run): + """set_private_key calls wg set with private-key path.""" + mock_run.return_value = "" + await set_private_key("/tmp/test.key", iface="wg-test") + mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "private-key", "/tmp/test.key"]) + + +# ========== set_listen_port ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_set_listen_port(mock_run): + """set_listen_port calls wg set with listen-port.""" + mock_run.return_value = "" + await set_listen_port(51820, iface="wg-test") + mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "listen-port", "51820"]) + + +# ========== configure_interface ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +@patch("wiregui.db.async_session") +async def test_configure_interface_no_config(mock_session_cls, mock_run): + """If no Configuration row exists, do not call wg set.""" + from unittest.mock import MagicMock + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + await configure_interface(iface="wg-test") + mock_run.assert_not_awaited() + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +@patch("wiregui.db.async_session") +async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run): + """With valid config, writes key to temp file and calls wg set.""" + from unittest.mock import MagicMock + + mock_config = MagicMock() + mock_config.server_private_key = "test-private-key-value" + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_config + mock_session.execute.return_value = mock_result + mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + mock_run.return_value = "" + await configure_interface(iface="wg-test") + + mock_run.assert_awaited_once() + args = mock_run.call_args[0][0] + assert args[0:3] == ["wg", "set", "wg-test"] + assert "private-key" in args + assert "listen-port" in args \ No newline at end of file diff --git a/wiregui/auth/saml.py b/wiregui/auth/saml.py index c624a35..c71054f 100644 --- a/wiregui/auth/saml.py +++ b/wiregui/auth/saml.py @@ -17,7 +17,7 @@ def _build_saml_settings(provider_config: dict) -> dict: idp_settings = idp_data.get("idp", {}) return { - "strict": True, + "strict": provider_config.get("strict", True), "debug": False, "sp": { "entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata", diff --git a/wiregui/pages/admin/settings.py b/wiregui/pages/admin/settings.py index 72ec470..7a868c1 100644 --- a/wiregui/pages/admin/settings.py +++ b/wiregui/pages/admin/settings.py @@ -6,6 +6,7 @@ from loguru import logger from nicegui import app, ui from sqlmodel import select +from wiregui.config import get_settings from wiregui.db import async_session from wiregui.models.configuration import Configuration from wiregui.pages.layout import layout diff --git a/wiregui/pages/auth_saml.py b/wiregui/pages/auth_saml.py index c9dccc2..9183f2b 100644 --- a/wiregui/pages/auth_saml.py +++ b/wiregui/pages/auth_saml.py @@ -101,14 +101,13 @@ async def saml_callback(provider_id: str, request: Request): session.add(user) await session.commit() - request.session["authenticated"] = True - request.session["user_id"] = str(user.id) - request.session["email"] = user.email - request.session["role"] = user.role - request.session["theme_preference"] = user.theme_preference + # Store auth data in Starlette session — picked up by /auth/complete + request.session["oidc_user_id"] = str(user.id) + request.session["oidc_email"] = user.email + request.session["oidc_role"] = user.role logger.info("SAML login: {} via {}", email, provider_id) - return RedirectResponse(url="/", status_code=303) + return RedirectResponse(url="/auth/complete", status_code=303) except Exception as e: logger.error("SAML callback failed for {}: {}", provider_id, e) diff --git a/wiregui/pages/login.py b/wiregui/pages/login.py index f1b2110..35781d9 100644 --- a/wiregui/pages/login.py +++ b/wiregui/pages/login.py @@ -1,4 +1,4 @@ -"""Login page — email/password, MFA redirect, OIDC provider buttons.""" +"""Login page — email/password, MFA redirect, OIDC/SAML provider buttons.""" from nicegui import app, ui from sqlmodel import select @@ -6,6 +6,7 @@ from sqlmodel import select from wiregui.auth.oidc import load_providers from wiregui.auth.session import authenticate_user from wiregui.db import async_session +from wiregui.models.configuration import Configuration from wiregui.models.mfa_method import MFAMethod from wiregui.pages.style import apply_style from wiregui.utils.time import utcnow @@ -18,9 +19,13 @@ async def login_page(): apply_style() - # Load OIDC providers for SSO buttons + # Load SSO providers for login buttons oidc_providers = await load_providers() + async with async_session() as session: + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + saml_providers = config.saml_identity_providers if config else [] + async def try_login(): user = await authenticate_user(email.value, password.value) if user is None: @@ -76,8 +81,8 @@ async def login_page(): password.on("keydown.enter", try_login) - # OIDC provider buttons - if oidc_providers: + # SSO provider buttons + if oidc_providers or saml_providers: ui.separator().classes("q-my-md") ui.label("Or sign in with").classes("text-caption text-center w-full") for provider in oidc_providers: @@ -87,3 +92,10 @@ async def login_page(): label, on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/oidc/{p}'"), ).props("color=primary unelevated").classes("w-full q-mt-xs") + for provider in saml_providers: + pid = provider.get("id", "") + label = provider.get("label", pid) + ui.button( + label, + on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/saml/{p}'"), + ).props("color=primary unelevated").classes("w-full q-mt-xs")