feat: comprehensive test suite + SAML auth fixes + mock SAML IdP
Tests (198 unit + 70 e2e = 268 total): - Add test_api_deps.py: Bearer token auth, get_current_api_user, require_admin - Add test_wireguard_extended.py: ensure_interface, set_private_key, set_listen_port - Add test_firewall_extended.py: _nft/_nft_batch errors, jump rules, policies - Add test_mfa_login.py: MFA redirect, TOTP verify, invalid code, cancel - Add test_magic_link_page.py: page render, submit, empty email, back to login - Add test_admin_devices.py: list, filter, create, edit, delete, config dialog - Add test_admin_rules.py: list, create, edit, delete (all DB-verified) - Add test_admin_settings.py: defaults, security, OIDC/SAML providers - Add test_saml_login.py: button visible, redirect, metadata, full login flow Bug fixes: - Fix SAML callback to use /auth/complete bridge (same fix as OIDC) - Fix missing get_settings import in admin settings page - Add SAML provider buttons to login page - Make SAML strict mode configurable per-provider Infrastructure: - Add mock SimpleSAMLphp IdP to compose.yml with SP config - Add mock-saml service to CI workflows (release + dev)
This commit is contained in:
parent
25cff5e4d9
commit
06b5a3dc12
18 changed files with 1768 additions and 47 deletions
|
|
@ -34,11 +34,18 @@ jobs:
|
||||||
env:
|
env:
|
||||||
SERVER_PORT: "9000"
|
SERVER_PORT: "9000"
|
||||||
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
|
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
|
||||||
|
mock-saml:
|
||||||
|
image: kenchan0130/simplesamlphp
|
||||||
|
env:
|
||||||
|
SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata
|
||||||
|
SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback
|
||||||
|
SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/
|
||||||
env:
|
env:
|
||||||
CI: "true"
|
CI: "true"
|
||||||
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
|
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
|
||||||
WG_REDIS_URL: redis://valkey:6379/0
|
WG_REDIS_URL: redis://valkey:6379/0
|
||||||
MOCK_OIDC_HOST: mock-oidc
|
MOCK_OIDC_HOST: mock-oidc
|
||||||
|
MOCK_SAML_HOST: mock-saml
|
||||||
steps:
|
steps:
|
||||||
- name: Install system dependencies and checkout
|
- name: Install system dependencies and checkout
|
||||||
run: |
|
run: |
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,18 @@ jobs:
|
||||||
env:
|
env:
|
||||||
SERVER_PORT: "9000"
|
SERVER_PORT: "9000"
|
||||||
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
|
JSON_CONFIG: '{"interactiveLogin":true,"httpServer":"NettyWrapper","tokenCallbacks":[{"issuerId":"test-idp","tokenExpiry":3600,"requestMappings":[{"requestParam":"scope","match":"*","claims":{"sub":"$${claim:sub}","email":"$${claim:sub}@test.local","name":"Test User"}}]}]}'
|
||||||
|
mock-saml:
|
||||||
|
image: kenchan0130/simplesamlphp
|
||||||
|
env:
|
||||||
|
SIMPLESAMLPHP_SP_ENTITY_ID: http://localhost:13003/auth/saml/test-saml/metadata
|
||||||
|
SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: http://localhost:13003/auth/saml/test-saml/callback
|
||||||
|
SIMPLESAMLPHP_IDP_BASE_URL: http://mock-saml:8080/simplesaml/
|
||||||
env:
|
env:
|
||||||
CI: "true"
|
CI: "true"
|
||||||
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
|
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
|
||||||
WG_REDIS_URL: redis://valkey:6379/0
|
WG_REDIS_URL: redis://valkey:6379/0
|
||||||
MOCK_OIDC_HOST: mock-oidc
|
MOCK_OIDC_HOST: mock-oidc
|
||||||
|
MOCK_SAML_HOST: mock-saml
|
||||||
steps:
|
steps:
|
||||||
- name: Install system dependencies and checkout
|
- name: Install system dependencies and checkout
|
||||||
run: |
|
run: |
|
||||||
|
|
|
||||||
78
TODO.md
78
TODO.md
|
|
@ -1,6 +1,6 @@
|
||||||
# WireGUI — Pending Items
|
# WireGUI — Pending Items
|
||||||
|
|
||||||
**Test count: 174 (164 unit + 10 E2E) | Coverage: ~35%**
|
**Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@ -11,7 +11,7 @@
|
||||||
Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI.
|
Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI.
|
||||||
Source: `/home/stefanob/PycharmProjects/personal/wirezone`
|
Source: `/home/stefanob/PycharmProjects/personal/wirezone`
|
||||||
|
|
||||||
**Test count: 199 (164 unit + 35 E2E) | Coverage: 35%**
|
**Test count: 268 (198 unit + 70 E2E) | Coverage: 36% unit, ~63% effective (incl. E2E)**
|
||||||
**Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright)
|
**Run:** `uv run pytest` (unit) / `uv run pytest tests/e2e/` (E2E via Playwright)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -23,11 +23,11 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone`
|
||||||
|
|
||||||
### Testing (partially done)
|
### Testing (partially done)
|
||||||
- [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking)
|
- [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking)
|
||||||
- [ ] `wiregui/api/deps.py` — test get_current_api_user with real Bearer header parsing, require_admin rejection
|
- [x] `wiregui/api/deps.py` (11 tests) — resolve_bearer_token (valid/expired/invalid/disabled/no-expiry), get_current_api_user (missing header/bad scheme/invalid token/valid token), require_admin (admin/unprivileged)
|
||||||
- [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port
|
- [x] `wiregui/services/wireguard.py` (6 tests) — ensure_interface (exists/creates new), set_private_key, set_listen_port, configure_interface (no config/sets key+port)
|
||||||
- [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6
|
- [x] `wiregui/services/firewall.py` (17 tests) — _nft error/success, _nft_batch error/stdin, add_device_jump_rule (ipv4-only/ipv6-only/no-ips/both), setup_base_tables error handling, masquerade error, peer-to-peer/lan-to-peers policies, get_ruleset fallback
|
||||||
- [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error
|
- [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error
|
||||||
- [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing
|
- [x] `wiregui/auth/saml.py` — full SAML flow tested via mock SimpleSAMLphp IdP (e2e)
|
||||||
- [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data
|
- [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data
|
||||||
- [ ] E2E tests for admin pages (users, devices, rules, settings)
|
- [ ] E2E tests for admin pages (users, devices, rules, settings)
|
||||||
|
|
||||||
|
|
@ -37,46 +37,52 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone`
|
||||||
- [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion
|
- [x] `tests/e2e/test_account.py` (8 tests) — change password (success/wrong/mismatch/short), create API token, TOTP registration + invalid code, account deletion
|
||||||
- [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard
|
- [x] `tests/e2e/test_admin_users.py` (10 tests) — page renders, create user, duplicate email, edit role/password, disable/enable, delete, cascade delete, self-delete guard
|
||||||
- [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc
|
- [x] `tests/e2e/test_idp_seed.py` (9 tests) — IdP YAML seeding (noop/missing/invalid, OIDC/SAML add, upsert, preserve), OIDC button visible, full OIDC login flow via mock-oidc
|
||||||
|
- [x] `tests/e2e/test_mfa_login.py` (4 tests) — MFA redirect on login, valid TOTP completes login, invalid code error, cancel returns to login
|
||||||
|
- [x] `tests/e2e/test_magic_link_page.py` (4 tests) — page renders, success on submit, empty email error, back to login
|
||||||
|
- [x] `tests/e2e/test_admin_devices.py` (7 tests) — list all devices, filter by user, create with defaults, create with overrides, edit name/description, delete, config dialog with QR
|
||||||
|
- [x] `tests/e2e/test_admin_rules.py` (7 tests) — list rules table, create accept/drop/global rules, edit action/destination, delete rule (all verified in DB)
|
||||||
|
- [x] `tests/e2e/test_admin_settings.py` (9 tests) — client defaults save/reload, security toggles (local auth, VPN session, unprivileged), OIDC add/delete, SAML add/delete (all verified in DB)
|
||||||
|
- [x] `tests/e2e/test_saml_login.py` (4 tests) — SAML button visible, redirect to IdP, SP metadata endpoint, full SAML login flow via mock SimpleSAMLphp
|
||||||
|
|
||||||
**E2E tests still needed:**
|
**E2E tests still needed:**
|
||||||
|
|
||||||
`tests/e2e/test_login.py` — Login & Auth flows (remaining):
|
`tests/e2e/test_login.py` — Login & Auth flows (remaining):
|
||||||
- [ ] Login with MFA → redirects to /mfa challenge page
|
- [x] Login with MFA → redirects to /mfa challenge page
|
||||||
- [ ] MFA challenge: valid TOTP code → completes login
|
- [x] MFA challenge: valid TOTP code → completes login
|
||||||
- [ ] MFA challenge: invalid code → shows error, stays on /mfa
|
- [x] MFA challenge: invalid code → shows error, stays on /mfa
|
||||||
- [ ] MFA challenge: cancel → returns to /login
|
- [x] MFA challenge: cancel → returns to /login
|
||||||
- [ ] Magic link request page renders, shows success on submit
|
- [x] Magic link request page renders, shows success on submit
|
||||||
|
|
||||||
`tests/e2e/test_admin_devices.py` — Admin Device Management:
|
`tests/e2e/test_admin_devices.py` — Admin Device Management:
|
||||||
- [ ] List all devices across users
|
- [x] List all devices across users
|
||||||
- [ ] Filter by user → shows only that user's devices
|
- [x] Filter by user → shows only that user's devices
|
||||||
- [ ] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs)
|
- [x] Create device with full config overrides (DNS, endpoint, MTU, keepalive, allowed IPs)
|
||||||
- [ ] Create device with defaults → use_default flags all True
|
- [x] Create device with defaults → use_default flags all True
|
||||||
- [ ] Edit device name and description → persists
|
- [x] Edit device name and description → persists
|
||||||
- [ ] Edit device config overrides (toggle use_default off, set custom values)
|
- [x] Edit device config overrides (toggle use_default off, set custom values)
|
||||||
- [ ] Delete device → removed from table
|
- [x] Delete device → removed from table
|
||||||
- [ ] Config dialog shows valid WireGuard config with real server public key
|
- [x] Config dialog shows valid WireGuard config with real server public key
|
||||||
- [ ] QR code renders in config dialog
|
- [x] QR code renders in config dialog
|
||||||
|
|
||||||
`tests/e2e/test_admin_rules.py` — Admin Firewall Rules:
|
`tests/e2e/test_admin_rules.py` — Admin Firewall Rules:
|
||||||
- [ ] List rules → table shows action, destination, protocol, port, user
|
- [x] List rules → table shows action, destination, protocol, port, user
|
||||||
- [ ] Create accept rule with CIDR → appears in table
|
- [x] Create accept rule with CIDR → appears in table
|
||||||
- [ ] Create drop rule with TCP port range → appears correctly
|
- [x] Create drop rule with TCP port range → appears correctly
|
||||||
- [ ] Create global rule (no user) → shows "Global"
|
- [x] Create global rule (no user) → shows "Global"
|
||||||
- [ ] Edit rule action (accept → drop) → persists
|
- [x] Edit rule action (accept → drop) → persists
|
||||||
- [ ] Edit rule destination → persists
|
- [x] Edit rule destination → persists
|
||||||
- [ ] Delete rule → removed from table
|
- [x] Delete rule → removed from table
|
||||||
|
|
||||||
`tests/e2e/test_admin_settings.py` — Admin Settings:
|
`tests/e2e/test_admin_settings.py` — Admin Settings:
|
||||||
- [ ] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB
|
- [x] Client defaults: save endpoint, DNS, MTU, keepalive, allowed IPs → persists in DB
|
||||||
- [ ] Client defaults: saved values reflected on page reload
|
- [x] Client defaults: saved values reflected on page reload
|
||||||
- [ ] Security: toggle local auth → persists
|
- [x] Security: toggle local auth → persists
|
||||||
- [ ] Security: change VPN session duration → persists
|
- [x] Security: change VPN session duration → persists
|
||||||
- [ ] Security: toggle unprivileged device management/configuration → persists
|
- [x] Security: toggle unprivileged device management/configuration → persists
|
||||||
- [ ] OIDC: add provider → appears in table
|
- [x] OIDC: add provider → appears in table
|
||||||
- [ ] OIDC: delete provider → removed from table
|
- [x] OIDC: delete provider → removed from table
|
||||||
- [ ] SAML: add provider → appears in table
|
- [x] SAML: add provider → appears in table
|
||||||
- [ ] SAML: delete provider → removed from table
|
- [x] SAML: delete provider → removed from table
|
||||||
|
|
||||||
`tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics:
|
`tests/e2e/test_admin_diagnostics.py` — Admin Diagnostics:
|
||||||
- [ ] Page renders WireGuard interface status
|
- [ ] Page renders WireGuard interface status
|
||||||
|
|
|
||||||
15
compose.yml
15
compose.yml
|
|
@ -49,6 +49,21 @@ services:
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Test SAML Identity Provider — SimpleSAMLphp as IdP
|
||||||
|
# IdP Metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
|
||||||
|
# Admin UI: http://localhost:8080/simplesaml (admin / secret)
|
||||||
|
# Test users: user1/password, user2/password
|
||||||
|
mock-saml:
|
||||||
|
image: kenchan0130/simplesamlphp
|
||||||
|
ports:
|
||||||
|
- "8080:8080"
|
||||||
|
environment:
|
||||||
|
SIMPLESAMLPHP_SP_ENTITY_ID: "http://localhost:13000/auth/saml/test-saml/metadata"
|
||||||
|
SIMPLESAMLPHP_SP_ASSERTION_CONSUMER_SERVICE: "http://localhost:13000/auth/saml/test-saml/callback"
|
||||||
|
SIMPLESAMLPHP_IDP_BASE_URL: http://localhost:8080/simplesaml/
|
||||||
|
volumes:
|
||||||
|
- ./docker/mock-saml/saml20-sp-remote.php:/var/www/simplesamlphp/metadata/saml20-sp-remote.php:ro
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
valkey_data:
|
valkey_data:
|
||||||
|
|
|
||||||
15
docker/mock-saml/saml20-sp-remote.php
Normal file
15
docker/mock-saml/saml20-sp-remote.php
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
<?php
|
||||||
|
/**
|
||||||
|
* SAML 2.0 remote SP metadata for WireGUI testing.
|
||||||
|
* Registers SPs for dev (port 13000) and e2e test (port 13003).
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Dev instance
|
||||||
|
$metadata['http://localhost:13000/auth/saml/test-saml/metadata'] = [
|
||||||
|
'AssertionConsumerService' => 'http://localhost:13000/auth/saml/test-saml/callback',
|
||||||
|
];
|
||||||
|
|
||||||
|
// E2E test instance
|
||||||
|
$metadata['http://localhost:13003/auth/saml/test-saml/metadata'] = [
|
||||||
|
'AssertionConsumerService' => 'http://localhost:13003/auth/saml/test-saml/callback',
|
||||||
|
];
|
||||||
239
tests/e2e/test_admin_devices.py
Normal file
239
tests/e2e/test_admin_devices.py
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
"""E2E tests for admin device management page."""
|
||||||
|
|
||||||
|
import pytest_asyncio
|
||||||
|
from playwright.async_api import Page, expect
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from wiregui.auth.passwords import hash_password
|
||||||
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.device import Device
|
||||||
|
from wiregui.models.user import User
|
||||||
|
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||||
|
from tests.e2e.conftest import (
|
||||||
|
TEST_APP_BASE,
|
||||||
|
TEST_EMAIL,
|
||||||
|
TEST_PASSWORD,
|
||||||
|
_cleanup_user_by_email,
|
||||||
|
login,
|
||||||
|
)
|
||||||
|
|
||||||
|
SECOND_USER_EMAIL = "e2e-device-user2@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def second_user(test_user):
|
||||||
|
"""Create a second user with a device for filtering tests."""
|
||||||
|
await _cleanup_user_by_email(SECOND_USER_EMAIL)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user = User(
|
||||||
|
email=SECOND_USER_EMAIL,
|
||||||
|
password_hash=hash_password("pass12345"),
|
||||||
|
role="unprivileged",
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
yield user
|
||||||
|
|
||||||
|
await _cleanup_user_by_email(SECOND_USER_EMAIL)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def devices_for_both_users(test_user, second_user):
|
||||||
|
"""Create one device per user for table/filter tests."""
|
||||||
|
_, pub1 = generate_keypair()
|
||||||
|
_, pub2 = generate_keypair()
|
||||||
|
psk1 = generate_preshared_key()
|
||||||
|
psk2 = generate_preshared_key()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
d1 = Device(
|
||||||
|
name="admin-laptop",
|
||||||
|
public_key=pub1,
|
||||||
|
preshared_key=psk1,
|
||||||
|
ipv4="10.0.0.10",
|
||||||
|
user_id=test_user.id,
|
||||||
|
)
|
||||||
|
d2 = Device(
|
||||||
|
name="user2-phone",
|
||||||
|
public_key=pub2,
|
||||||
|
preshared_key=psk2,
|
||||||
|
ipv4="10.0.0.11",
|
||||||
|
user_id=second_user.id,
|
||||||
|
)
|
||||||
|
session.add_all([d1, d2])
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
yield d1, d2
|
||||||
|
|
||||||
|
# Cleanup handled by user fixture cascade
|
||||||
|
|
||||||
|
|
||||||
|
async def _go_to_admin_devices(page: Page):
|
||||||
|
"""Login as admin and navigate to admin devices page."""
|
||||||
|
await login(page)
|
||||||
|
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/admin/devices")
|
||||||
|
await expect(page.locator("role=main").get_by_text("All Devices")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_all_devices(page: Page, devices_for_both_users):
|
||||||
|
"""Admin devices page lists devices from all users."""
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_filter_by_user(page: Page, second_user, devices_for_both_users):
|
||||||
|
"""Filtering by user shows only that user's devices."""
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Filter to second user
|
||||||
|
await page.locator("label:has-text('Filter by User')").click()
|
||||||
|
await page.get_by_role("option", name=SECOND_USER_EMAIL).click()
|
||||||
|
await page.wait_for_timeout(1000)
|
||||||
|
|
||||||
|
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("admin-laptop")).not_to_be_visible()
|
||||||
|
|
||||||
|
# Filter back to all
|
||||||
|
await page.locator("label:has-text('Filter by User')").click()
|
||||||
|
await page.get_by_role("option", name="All Users").click()
|
||||||
|
await page.wait_for_timeout(1000)
|
||||||
|
|
||||||
|
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_device_with_defaults(page: Page, test_user):
|
||||||
|
"""Create device with all defaults — config dialog appears."""
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await page.get_by_role("button", name="Add Device").click()
|
||||||
|
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await page.locator("input[aria-label='Device Name']").fill("default-test-device")
|
||||||
|
await page.get_by_role("button", name="Create").click()
|
||||||
|
|
||||||
|
# Config dialog should appear with WireGuard config
|
||||||
|
await expect(page.get_by_text("Config for default-test-device")).to_be_visible(timeout=10_000)
|
||||||
|
await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
|
||||||
|
await page.get_by_role("button", name="Close").click()
|
||||||
|
await page.wait_for_timeout(500)
|
||||||
|
|
||||||
|
# Device should be in the table
|
||||||
|
await expect(page.get_by_role("cell", name="default-test-device").first).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_device_with_overrides(page: Page, test_user):
|
||||||
|
"""Create device with custom config overrides."""
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await page.get_by_role("button", name="Add Device").click()
|
||||||
|
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await page.locator("input[aria-label='Device Name']").fill("custom-override-dev")
|
||||||
|
await page.locator("input[aria-label='Description (optional)']").fill("Custom overrides test")
|
||||||
|
|
||||||
|
# Toggle off DNS default and set custom — Quasar switches use .q-toggle
|
||||||
|
await page.locator(".q-toggle", has_text="Use default DNS").click()
|
||||||
|
dns_input = page.locator("input[aria-label='DNS Servers']")
|
||||||
|
await dns_input.clear()
|
||||||
|
await dns_input.fill("8.8.8.8, 8.8.4.4")
|
||||||
|
|
||||||
|
# Toggle off MTU default and set custom
|
||||||
|
await page.locator(".q-toggle", has_text="Use default MTU").click()
|
||||||
|
mtu_input = page.locator("input[aria-label='MTU']")
|
||||||
|
await mtu_input.clear()
|
||||||
|
await mtu_input.fill("1400")
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Create").click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("Config for custom-override-dev")).to_be_visible(timeout=10_000)
|
||||||
|
await page.get_by_role("button", name="Close").click()
|
||||||
|
await page.wait_for_timeout(500)
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="custom-override-dev").first).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Device).where(Device.name == "custom-override-dev")
|
||||||
|
.order_by(Device.inserted_at.desc()).limit(1)
|
||||||
|
)
|
||||||
|
device = result.scalar_one()
|
||||||
|
assert device.use_default_dns is False
|
||||||
|
assert "8.8.8.8" in device.dns
|
||||||
|
assert device.use_default_mtu is False
|
||||||
|
assert device.mtu == 1400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_edit_device_name_and_description(page: Page, devices_for_both_users):
|
||||||
|
"""Edit a device name and description via the edit dialog."""
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Click edit button on admin-laptop row — Quasar slot buttons with icon
|
||||||
|
row = page.locator("tr", has_text="admin-laptop")
|
||||||
|
await row.locator(".q-btn").first.click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("Edit Device")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
name_input = page.locator(".q-dialog input[aria-label='Device Name']")
|
||||||
|
await name_input.clear()
|
||||||
|
await name_input.fill("admin-laptop-renamed")
|
||||||
|
|
||||||
|
desc_input = page.locator(".q-dialog input[aria-label='Description']")
|
||||||
|
await desc_input.clear()
|
||||||
|
await desc_input.fill("Updated description")
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save").click()
|
||||||
|
await expect(page.get_by_text("Device updated")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("admin-laptop-renamed")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_device(page: Page, test_user):
|
||||||
|
"""Delete a device — removed from table."""
|
||||||
|
_, pub = generate_keypair()
|
||||||
|
async with async_session() as session:
|
||||||
|
d = Device(
|
||||||
|
name="delete-me-device",
|
||||||
|
public_key=pub,
|
||||||
|
preshared_key=generate_preshared_key(),
|
||||||
|
ipv4="10.0.0.99",
|
||||||
|
user_id=test_user.id,
|
||||||
|
)
|
||||||
|
session.add(d)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await expect(page.get_by_role("cell", name="delete-me-device")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Click the delete (second) button in the row
|
||||||
|
row = page.locator("tr", has_text="delete-me-device")
|
||||||
|
await row.locator(".q-btn").nth(1).click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("Deleted delete-me-device")).to_be_visible(timeout=5_000)
|
||||||
|
await page.wait_for_timeout(1000)
|
||||||
|
await expect(page.get_by_role("cell", name="delete-me-device")).not_to_be_visible()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_config_dialog_shows_wg_config(page: Page, test_user):
|
||||||
|
"""Config dialog after device creation shows valid WireGuard config."""
|
||||||
|
await _go_to_admin_devices(page)
|
||||||
|
await page.get_by_role("button", name="Add Device").click()
|
||||||
|
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await page.locator("input[aria-label='Device Name']").fill("config-test-device")
|
||||||
|
await page.get_by_role("button", name="Create").click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("Config for config-test-device")).to_be_visible(timeout=10_000)
|
||||||
|
await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("[Peer]")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("PrivateKey")).to_be_visible()
|
||||||
|
await expect(page.get_by_role("button", name="Download .conf")).to_be_visible()
|
||||||
|
|
||||||
|
# QR code should be rendered
|
||||||
|
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)
|
||||||
227
tests/e2e/test_admin_rules.py
Normal file
227
tests/e2e/test_admin_rules.py
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
"""E2E tests for admin firewall rules management page."""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest_asyncio
|
||||||
|
from playwright.async_api import Page, expect
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.rule import Rule
|
||||||
|
from wiregui.models.user import User
|
||||||
|
from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, login
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_test_rules():
|
||||||
|
"""Remove rules created by tests (identified by test-specific destinations)."""
|
||||||
|
async with async_session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Rule).where(Rule.destination.in_([
|
||||||
|
"10.99.0.0/16", "10.88.0.0/16", "10.77.0.0/16",
|
||||||
|
"10.66.0.0/16", "10.55.0.0/16",
|
||||||
|
]))
|
||||||
|
)
|
||||||
|
for rule in result.scalars().all():
|
||||||
|
await session.delete(rule)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def clean_rules(app_server):
|
||||||
|
"""Clean up test rules before and after each test."""
|
||||||
|
await _cleanup_test_rules()
|
||||||
|
yield
|
||||||
|
await _cleanup_test_rules()
|
||||||
|
|
||||||
|
|
||||||
|
async def _go_to_rules(page: Page):
|
||||||
|
"""Login and navigate to admin rules page."""
|
||||||
|
await login(page)
|
||||||
|
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/admin/rules")
|
||||||
|
await expect(page.locator("role=main").get_by_text("Firewall Rules")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_rule_via_dialog(
|
||||||
|
page: Page, *, action: str = "accept", destination: str = "10.99.0.0/16",
|
||||||
|
protocol: str = "any", port_range: str = "", user: str = "global",
|
||||||
|
):
|
||||||
|
"""Open create dialog and fill in a rule."""
|
||||||
|
await page.get_by_role("button", name="Add Rule").click()
|
||||||
|
await expect(page.get_by_text("New Firewall Rule")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Action select
|
||||||
|
if action != "accept":
|
||||||
|
await page.locator(".q-dialog label:has-text('Action')").click()
|
||||||
|
await page.get_by_role("option", name=action).click()
|
||||||
|
|
||||||
|
# Destination
|
||||||
|
await page.locator(".q-dialog input[aria-label='Destination (CIDR)']").fill(destination)
|
||||||
|
|
||||||
|
# Protocol
|
||||||
|
if protocol != "any":
|
||||||
|
await page.locator(".q-dialog label:has-text('Protocol')").click()
|
||||||
|
await page.get_by_role("option", name=protocol).click()
|
||||||
|
|
||||||
|
# Port range
|
||||||
|
if port_range:
|
||||||
|
await page.locator(".q-dialog input[aria-label='Port Range']").fill(port_range)
|
||||||
|
|
||||||
|
# User
|
||||||
|
if user != "global":
|
||||||
|
await page.locator(".q-dialog label:has-text('Applies to')").click()
|
||||||
|
await page.get_by_role("option", name=user).click()
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Create").click()
|
||||||
|
await page.wait_for_timeout(500)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_rules_table(page: Page, test_user: User):
|
||||||
|
"""Rules page renders table with correct columns."""
|
||||||
|
# Seed a rule in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = Rule(action="accept", destination="10.99.0.0/16", port_type="tcp",
|
||||||
|
port_range="443", user_id=test_user.id)
|
||||||
|
session.add(rule)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await _go_to_rules(page)
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="accept")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible()
|
||||||
|
await expect(page.get_by_role("cell", name="tcp")).to_be_visible()
|
||||||
|
await expect(page.get_by_role("cell", name="443")).to_be_visible()
|
||||||
|
await expect(page.get_by_role("cell", name=TEST_EMAIL)).to_be_visible()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_accept_rule_with_cidr(page: Page, test_user: User):
|
||||||
|
"""Create an accept rule with CIDR — verify in table and DB."""
|
||||||
|
await _go_to_rules(page)
|
||||||
|
await _create_rule_via_dialog(page, action="accept", destination="10.88.0.0/16")
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="10.88.0.0/16")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
result = await session.execute(select(Rule).where(Rule.destination == "10.88.0.0/16"))
|
||||||
|
rule = result.scalar_one()
|
||||||
|
assert rule.action == "accept"
|
||||||
|
assert rule.port_type is None
|
||||||
|
assert rule.port_range is None
|
||||||
|
assert rule.user_id is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_drop_rule_with_tcp_port_range(page: Page, test_user: User):
|
||||||
|
"""Create a drop rule with TCP port range — verify in table and DB."""
|
||||||
|
await _go_to_rules(page)
|
||||||
|
await _create_rule_via_dialog(
|
||||||
|
page, action="drop", destination="10.77.0.0/16",
|
||||||
|
protocol="tcp", port_range="80-443",
|
||||||
|
)
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="10.77.0.0/16")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_role("cell", name="drop").first).to_be_visible()
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
result = await session.execute(select(Rule).where(Rule.destination == "10.77.0.0/16"))
|
||||||
|
rule = result.scalar_one()
|
||||||
|
assert rule.action == "drop"
|
||||||
|
assert rule.port_type == "tcp"
|
||||||
|
assert rule.port_range == "80-443"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_global_rule(page: Page, test_user: User):
|
||||||
|
"""Create a global rule (no user) — shows 'Global' in table and DB has null user_id."""
|
||||||
|
await _go_to_rules(page)
|
||||||
|
await _create_rule_via_dialog(page, destination="10.66.0.0/16", user="global")
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="10.66.0.0/16")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_role("cell", name="Global")).to_be_visible()
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
result = await session.execute(select(Rule).where(Rule.destination == "10.66.0.0/16"))
|
||||||
|
rule = result.scalar_one()
|
||||||
|
assert rule.user_id is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_edit_rule_action(page: Page, test_user: User):
|
||||||
|
"""Edit rule action from accept to drop — verify in table and DB."""
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = Rule(action="accept", destination="10.55.0.0/16")
|
||||||
|
session.add(rule)
|
||||||
|
await session.commit()
|
||||||
|
rule_id = rule.id
|
||||||
|
|
||||||
|
await _go_to_rules(page)
|
||||||
|
await expect(page.get_by_role("cell", name="10.55.0.0/16")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Click edit (first button in the row)
|
||||||
|
row = page.locator("tr", has_text="10.55.0.0/16")
|
||||||
|
await row.locator(".q-btn").first.click()
|
||||||
|
await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Change action to drop
|
||||||
|
await page.locator(".q-dialog label:has-text('Action')").click()
|
||||||
|
await page.get_by_role("option", name="drop").click()
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save").click()
|
||||||
|
await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = await session.get(Rule, rule_id)
|
||||||
|
assert rule.action == "drop"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_edit_rule_destination(page: Page, test_user: User):
|
||||||
|
"""Edit rule destination — verify in table and DB."""
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = Rule(action="accept", destination="10.99.0.0/16")
|
||||||
|
session.add(rule)
|
||||||
|
await session.commit()
|
||||||
|
rule_id = rule.id
|
||||||
|
|
||||||
|
await _go_to_rules(page)
|
||||||
|
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
row = page.locator("tr", has_text="10.99.0.0/16")
|
||||||
|
await row.locator(".q-btn").first.click()
|
||||||
|
await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
dest_input = page.locator(".q-dialog input[aria-label='Destination (CIDR)']")
|
||||||
|
await dest_input.clear()
|
||||||
|
await dest_input.fill("10.88.0.0/16")
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save").click()
|
||||||
|
await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = await session.get(Rule, rule_id)
|
||||||
|
assert rule.destination == "10.88.0.0/16"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_rule(page: Page, test_user: User):
|
||||||
|
"""Delete a rule — removed from table and DB."""
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = Rule(action="accept", destination="10.99.0.0/16")
|
||||||
|
session.add(rule)
|
||||||
|
await session.commit()
|
||||||
|
rule_id = rule.id
|
||||||
|
|
||||||
|
await _go_to_rules(page)
|
||||||
|
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Click delete (second button in the row)
|
||||||
|
row = page.locator("tr", has_text="10.99.0.0/16")
|
||||||
|
await row.locator(".q-btn").nth(1).click()
|
||||||
|
await page.wait_for_timeout(1000)
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="10.99.0.0/16")).not_to_be_visible()
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
rule = await session.get(Rule, rule_id)
|
||||||
|
assert rule is None
|
||||||
281
tests/e2e/test_admin_settings.py
Normal file
281
tests/e2e/test_admin_settings.py
Normal file
|
|
@ -0,0 +1,281 @@
|
||||||
|
"""E2E tests for admin settings page — client defaults, security, OIDC/SAML providers."""
|
||||||
|
|
||||||
|
import pytest_asyncio
|
||||||
|
from playwright.async_api import Page, expect
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.configuration import Configuration
|
||||||
|
from wiregui.models.user import User
|
||||||
|
from tests.e2e.conftest import TEST_APP_BASE, login
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def reset_config(app_server):
|
||||||
|
"""Snapshot config before test, restore after."""
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||||
|
if not c:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
snap = {
|
||||||
|
"default_client_endpoint": c.default_client_endpoint,
|
||||||
|
"default_client_dns": list(c.default_client_dns),
|
||||||
|
"default_client_mtu": c.default_client_mtu,
|
||||||
|
"default_client_persistent_keepalive": c.default_client_persistent_keepalive,
|
||||||
|
"default_client_allowed_ips": list(c.default_client_allowed_ips),
|
||||||
|
"vpn_session_duration": c.vpn_session_duration,
|
||||||
|
"local_auth_enabled": c.local_auth_enabled,
|
||||||
|
"allow_unprivileged_device_management": c.allow_unprivileged_device_management,
|
||||||
|
"allow_unprivileged_device_configuration": c.allow_unprivileged_device_configuration,
|
||||||
|
"openid_connect_providers": list(c.openid_connect_providers or []),
|
||||||
|
"saml_identity_providers": list(c.saml_identity_providers or []),
|
||||||
|
}
|
||||||
|
cid = c.id
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
c = await session.get(Configuration, cid)
|
||||||
|
if c:
|
||||||
|
for k, v in snap.items():
|
||||||
|
setattr(c, k, v)
|
||||||
|
session.add(c)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _go_to_settings(page: Page):
|
||||||
|
await login(page)
|
||||||
|
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/admin/settings")
|
||||||
|
await expect(page.get_by_text("Default Client Configuration")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Client Defaults ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_save_client_defaults(page: Page, test_user: User):
|
||||||
|
"""Save endpoint, DNS, MTU, keepalive, allowed IPs — verify persists in DB."""
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
endpoint = page.locator("input[aria-label='Endpoint']")
|
||||||
|
await endpoint.clear()
|
||||||
|
await endpoint.fill("vpn.test.local")
|
||||||
|
|
||||||
|
dns = page.locator("input[aria-label='DNS Servers']")
|
||||||
|
await dns.clear()
|
||||||
|
await dns.fill("9.9.9.9, 149.112.112.112")
|
||||||
|
|
||||||
|
mtu = page.locator("input[aria-label='MTU']")
|
||||||
|
await mtu.clear()
|
||||||
|
await mtu.fill("1420")
|
||||||
|
|
||||||
|
keepalive = page.locator("input[aria-label='Persistent Keepalive']")
|
||||||
|
await keepalive.clear()
|
||||||
|
await keepalive.fill("30")
|
||||||
|
|
||||||
|
allowed = page.locator("input[aria-label='Allowed IPs']")
|
||||||
|
await allowed.clear()
|
||||||
|
await allowed.fill("10.0.0.0/8, 192.168.0.0/16")
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save Defaults").click()
|
||||||
|
await expect(page.get_by_text("Client defaults saved")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
assert c.default_client_endpoint == "vpn.test.local"
|
||||||
|
assert c.default_client_dns == ["9.9.9.9", "149.112.112.112"]
|
||||||
|
assert c.default_client_mtu == 1420
|
||||||
|
assert c.default_client_persistent_keepalive == 30
|
||||||
|
assert c.default_client_allowed_ips == ["10.0.0.0/8", "192.168.0.0/16"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_client_defaults_persist_on_reload(page: Page, test_user: User):
|
||||||
|
"""Saved defaults are reflected after page reload."""
|
||||||
|
# Set values via DB
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
c.default_client_endpoint = "reload-test.vpn"
|
||||||
|
c.default_client_dns = ["8.8.8.8"]
|
||||||
|
c.default_client_mtu = 1500
|
||||||
|
c.default_client_persistent_keepalive = 15
|
||||||
|
c.default_client_allowed_ips = ["172.16.0.0/12"]
|
||||||
|
session.add(c)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
await expect(page.locator("input[aria-label='Endpoint']")).to_have_value("reload-test.vpn")
|
||||||
|
await expect(page.locator("input[aria-label='DNS Servers']")).to_have_value("8.8.8.8")
|
||||||
|
await expect(page.locator("input[aria-label='MTU']")).to_have_value("1500")
|
||||||
|
await expect(page.locator("input[aria-label='Persistent Keepalive']")).to_have_value("15")
|
||||||
|
await expect(page.locator("input[aria-label='Allowed IPs']")).to_have_value("172.16.0.0/12")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Security ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_save_security_local_auth_toggle(page: Page, test_user: User):
|
||||||
|
"""Toggle local auth off — verify in DB."""
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
# Find the local auth switch and toggle it off
|
||||||
|
switch = page.locator(".q-toggle", has_text="Local Authentication")
|
||||||
|
await switch.click()
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save Security Settings").click()
|
||||||
|
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
assert c.local_auth_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_save_vpn_session_duration(page: Page, test_user: User):
|
||||||
|
"""Change VPN session duration — verify in DB."""
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
await page.locator("label:has-text('VPN Session Duration')").click()
|
||||||
|
await page.get_by_role("option", name="Every Day").click()
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save Security Settings").click()
|
||||||
|
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
assert c.vpn_session_duration == 86400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_save_unprivileged_toggles(page: Page, test_user: User):
|
||||||
|
"""Toggle unprivileged device management/configuration — verify in DB."""
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
await page.locator(".q-toggle", has_text="Allow Unprivileged Device Management").click()
|
||||||
|
await page.locator(".q-toggle", has_text="Allow Unprivileged Device Configuration").click()
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Save Security Settings").click()
|
||||||
|
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
# Toggled from default (True) to False
|
||||||
|
assert c.allow_unprivileged_device_management is False
|
||||||
|
assert c.allow_unprivileged_device_configuration is False
|
||||||
|
|
||||||
|
|
||||||
|
# --- OIDC Providers ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_oidc_provider(page: Page, test_user: User):
|
||||||
|
"""Add an OIDC provider — appears in table and DB."""
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Add OIDC Provider").click()
|
||||||
|
await expect(page.get_by_text("OIDC Provider", exact=True)).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-oidc")
|
||||||
|
await page.locator(".q-dialog input[aria-label='Label']").fill("E2E Test IdP")
|
||||||
|
await page.locator(".q-dialog input[aria-label='Client ID']").fill("test-client-id")
|
||||||
|
await page.locator(".q-dialog input[aria-label='Client Secret']").fill("test-client-secret")
|
||||||
|
await page.locator(".q-dialog input[aria-label='Discovery Document URI']").fill("https://idp.test/.well-known/openid-configuration")
|
||||||
|
|
||||||
|
await page.locator(".q-dialog").get_by_role("button", name="Save").click()
|
||||||
|
await expect(page.get_by_text("OIDC provider 'E2E Test IdP' saved")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="e2e-test-oidc")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
provider = next((p for p in c.openid_connect_providers if p["id"] == "e2e-test-oidc"), None)
|
||||||
|
assert provider is not None
|
||||||
|
assert provider["label"] == "E2E Test IdP"
|
||||||
|
assert provider["client_id"] == "test-client-id"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_oidc_provider(page: Page, test_user: User):
|
||||||
|
"""Delete an OIDC provider — removed from table and DB."""
|
||||||
|
# Seed a provider
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
providers = list(c.openid_connect_providers or [])
|
||||||
|
providers.append({
|
||||||
|
"id": "delete-me-oidc", "label": "Delete Me", "scope": "openid",
|
||||||
|
"client_id": "x", "client_secret": "x",
|
||||||
|
"discovery_document_uri": "https://x/.well-known/openid-configuration",
|
||||||
|
})
|
||||||
|
c.openid_connect_providers = providers
|
||||||
|
session.add(c)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await _go_to_settings(page)
|
||||||
|
await expect(page.get_by_role("cell", name="delete-me-oidc")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
row = page.locator("tr", has_text="delete-me-oidc")
|
||||||
|
await row.locator(".q-btn").first.click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("OIDC provider deleted")).to_be_visible(timeout=5_000)
|
||||||
|
await page.wait_for_timeout(500)
|
||||||
|
await expect(page.get_by_role("cell", name="delete-me-oidc")).not_to_be_visible()
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
assert not any(p["id"] == "delete-me-oidc" for p in c.openid_connect_providers)
|
||||||
|
|
||||||
|
|
||||||
|
# --- SAML Providers ---
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_saml_provider(page: Page, test_user: User):
|
||||||
|
"""Add a SAML provider — appears in table and DB."""
|
||||||
|
await _go_to_settings(page)
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Add SAML Provider").click()
|
||||||
|
await expect(page.get_by_text("SAML Identity Provider", exact=True)).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-saml")
|
||||||
|
await page.locator(".q-dialog input[aria-label='Label']").fill("E2E SAML IdP")
|
||||||
|
await page.locator(".q-dialog textarea").fill("<EntityDescriptor>test</EntityDescriptor>")
|
||||||
|
|
||||||
|
await page.locator(".q-dialog").get_by_role("button", name="Save").click()
|
||||||
|
await expect(page.get_by_text("SAML provider 'E2E SAML IdP' saved")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
await expect(page.get_by_role("cell", name="e2e-test-saml")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
provider = next((p for p in c.saml_identity_providers if p["id"] == "e2e-test-saml"), None)
|
||||||
|
assert provider is not None
|
||||||
|
assert provider["label"] == "E2E SAML IdP"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_saml_provider(page: Page, test_user: User):
|
||||||
|
"""Delete a SAML provider — removed from table and DB."""
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
providers = list(c.saml_identity_providers or [])
|
||||||
|
providers.append({
|
||||||
|
"id": "delete-me-saml", "label": "Delete Me SAML",
|
||||||
|
"metadata": "<EntityDescriptor/>",
|
||||||
|
})
|
||||||
|
c.saml_identity_providers = providers
|
||||||
|
session.add(c)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await _go_to_settings(page)
|
||||||
|
await expect(page.get_by_role("cell", name="delete-me-saml")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
row = page.locator("tr", has_text="delete-me-saml")
|
||||||
|
await row.locator(".q-btn").first.click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("SAML provider deleted")).to_be_visible(timeout=5_000)
|
||||||
|
await page.wait_for_timeout(500)
|
||||||
|
await expect(page.get_by_role("cell", name="delete-me-saml")).not_to_be_visible()
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||||
|
assert not any(p["id"] == "delete-me-saml" for p in c.saml_identity_providers)
|
||||||
41
tests/e2e/test_magic_link_page.py
Normal file
41
tests/e2e/test_magic_link_page.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""E2E tests for magic link request page."""
|
||||||
|
|
||||||
|
from playwright.async_api import Page, expect
|
||||||
|
|
||||||
|
from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL
|
||||||
|
from wiregui.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
async def test_magic_link_page_renders(page: Page, test_user: User):
|
||||||
|
"""Magic link request page renders with email input and submit button."""
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await expect(page.get_by_text("Sign in with magic link")).to_be_visible(timeout=10_000)
|
||||||
|
await expect(page.locator("input[aria-label='Email']")).to_be_visible()
|
||||||
|
await expect(page.get_by_role("button", name="Send Magic Link")).to_be_visible()
|
||||||
|
await expect(page.get_by_role("button", name="Back to login")).to_be_visible()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_magic_link_shows_success_on_submit(page: Page, test_user: User):
|
||||||
|
"""Submitting an email shows success message (regardless of whether account exists)."""
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await page.locator("input[aria-label='Email']").fill(TEST_EMAIL)
|
||||||
|
await page.get_by_role("button", name="Send Magic Link").click()
|
||||||
|
await expect(page.get_by_text("a sign-in link has been sent")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_magic_link_empty_email_shows_error(page: Page, test_user: User):
|
||||||
|
"""Submitting without email shows error."""
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await page.get_by_role("button", name="Send Magic Link").click()
|
||||||
|
await expect(page.get_by_text("Enter your email")).to_be_visible(timeout=5_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_magic_link_back_to_login(page: Page, test_user: User):
|
||||||
|
"""Back to login button navigates to login page."""
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await page.get_by_role("button", name="Back to login").click()
|
||||||
|
await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)
|
||||||
111
tests/e2e/test_mfa_login.py
Normal file
111
tests/e2e/test_mfa_login.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""E2E tests for MFA login flow — login with TOTP redirects to /mfa challenge page."""
|
||||||
|
|
||||||
|
import pyotp
|
||||||
|
import pytest_asyncio
|
||||||
|
from playwright.async_api import Page, expect
|
||||||
|
|
||||||
|
from wiregui.auth.mfa import generate_totp_secret
|
||||||
|
from wiregui.auth.passwords import hash_password
|
||||||
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.mfa_method import MFAMethod
|
||||||
|
from wiregui.models.user import User
|
||||||
|
from tests.e2e.conftest import (
|
||||||
|
FAKE_SERVER_KEY,
|
||||||
|
TEST_APP_BASE,
|
||||||
|
TEST_PASSWORD,
|
||||||
|
_cleanup_user_by_email,
|
||||||
|
)
|
||||||
|
|
||||||
|
MFA_EMAIL = "e2e-mfa@example.com"
|
||||||
|
MFA_PASSWORD = "mfapass123"
|
||||||
|
TOTP_SECRET = generate_totp_secret()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mfa_user(app_server):
|
||||||
|
"""Create a user with a TOTP MFA method, clean up after."""
|
||||||
|
await _cleanup_user_by_email(MFA_EMAIL)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
from sqlmodel import select
|
||||||
|
from wiregui.models.configuration import Configuration
|
||||||
|
|
||||||
|
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||||
|
if config:
|
||||||
|
if not config.server_public_key:
|
||||||
|
config.server_public_key = FAKE_SERVER_KEY
|
||||||
|
session.add(config)
|
||||||
|
else:
|
||||||
|
config = Configuration(server_public_key=FAKE_SERVER_KEY)
|
||||||
|
session.add(config)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
email=MFA_EMAIL,
|
||||||
|
password_hash=hash_password(MFA_PASSWORD),
|
||||||
|
role="admin",
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
mfa = MFAMethod(
|
||||||
|
name="Test TOTP",
|
||||||
|
type="totp",
|
||||||
|
payload={"secret": TOTP_SECRET},
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
session.add(mfa)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
yield user
|
||||||
|
|
||||||
|
await _cleanup_user_by_email(MFA_EMAIL)
|
||||||
|
|
||||||
|
|
||||||
|
async def _login_mfa_user(page: Page):
|
||||||
|
"""Fill login form for the MFA user and submit."""
|
||||||
|
await page.goto(f"{TEST_APP_BASE}/login")
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await page.locator("input[aria-label='Email']").fill(MFA_EMAIL)
|
||||||
|
await page.locator("input[aria-label='Password']").fill(MFA_PASSWORD)
|
||||||
|
await page.get_by_role("button", name="Sign in", exact=True).click()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mfa_login_redirects_to_challenge(page: Page, mfa_user: User):
|
||||||
|
"""Login with MFA-enabled user redirects to /mfa challenge page."""
|
||||||
|
await _login_mfa_user(page)
|
||||||
|
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||||
|
await expect(page.locator("input[aria-label='Authentication Code']")).to_be_visible()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mfa_valid_totp_completes_login(page: Page, mfa_user: User):
|
||||||
|
"""Entering a valid TOTP code on /mfa completes login."""
|
||||||
|
await _login_mfa_user(page)
|
||||||
|
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
code = pyotp.TOTP(TOTP_SECRET).now()
|
||||||
|
await page.locator("input[aria-label='Authentication Code']").fill(code)
|
||||||
|
await page.get_by_role("button", name="Verify").click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mfa_invalid_code_shows_error(page: Page, mfa_user: User):
|
||||||
|
"""Entering an invalid TOTP code shows error and stays on /mfa."""
|
||||||
|
await _login_mfa_user(page)
|
||||||
|
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
await page.locator("input[aria-label='Authentication Code']").fill("000000")
|
||||||
|
await page.get_by_role("button", name="Verify").click()
|
||||||
|
|
||||||
|
await expect(page.get_by_text("Invalid code")).to_be_visible(timeout=5_000)
|
||||||
|
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mfa_cancel_returns_to_login(page: Page, mfa_user: User):
|
||||||
|
"""Clicking Cancel on /mfa clears session and returns to login."""
|
||||||
|
await _login_mfa_user(page)
|
||||||
|
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
await page.get_by_role("button", name="Cancel").click()
|
||||||
|
await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)
|
||||||
177
tests/e2e/test_saml_login.py
Normal file
177
tests/e2e/test_saml_login.py
Normal file
|
|
@ -0,0 +1,177 @@
|
||||||
|
"""E2E tests for SAML authentication — mock SimpleSAMLphp IdP.
|
||||||
|
|
||||||
|
Requires mock-saml service running (docker compose up -d mock-saml).
|
||||||
|
IdP metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
|
||||||
|
Test users: user1/user1pass, user2/user2pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from playwright.async_api import Page, expect
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.configuration import Configuration
|
||||||
|
from wiregui.models.user import User
|
||||||
|
from tests.e2e.conftest import FAKE_SERVER_KEY, _cleanup_user_by_email
|
||||||
|
|
||||||
|
MOCK_SAML_HOST = os.environ.get("MOCK_SAML_HOST", "localhost")
|
||||||
|
MOCK_SAML_METADATA_URL = f"http://{MOCK_SAML_HOST}:8080/simplesaml/saml2/idp/metadata.php"
|
||||||
|
|
||||||
|
# Separate app port for SAML tests (like OIDC IdP tests)
|
||||||
|
SAML_APP_PORT = 13003
|
||||||
|
SAML_APP_BASE = f"http://localhost:{SAML_APP_PORT}"
|
||||||
|
|
||||||
|
SAML_TEST_EMAIL = "user1@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_idp_metadata() -> str:
|
||||||
|
"""Fetch IdP metadata XML from the mock SAML server."""
|
||||||
|
try:
|
||||||
|
r = httpx.get(MOCK_SAML_METADATA_URL, timeout=5)
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.text
|
||||||
|
except Exception:
|
||||||
|
pytest.skip(f"Mock SAML IdP not available at {MOCK_SAML_METADATA_URL}")
|
||||||
|
|
||||||
|
|
||||||
|
def _saml_provider_config(metadata: str) -> dict:
|
||||||
|
return {
|
||||||
|
"id": "test-saml",
|
||||||
|
"label": "Sign in with Mock SAML",
|
||||||
|
"metadata": metadata,
|
||||||
|
"sign_requests": False,
|
||||||
|
"sign_metadata": False,
|
||||||
|
"signed_assertion_in_resp": False,
|
||||||
|
"signed_envelopes_in_resp": False,
|
||||||
|
"auto_create_users": True,
|
||||||
|
"strict": False, # Relaxed for test IdP with expired certs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="module")
|
||||||
|
async def saml_metadata():
|
||||||
|
return _fetch_idp_metadata()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def app_with_saml(saml_metadata):
|
||||||
|
"""Start a WireGUI instance with a SAML provider seeded in the DB."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Seed the SAML provider config into the database
|
||||||
|
async def _seed():
|
||||||
|
async with async_session() as session:
|
||||||
|
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||||
|
if config is None:
|
||||||
|
config = Configuration(server_public_key=FAKE_SERVER_KEY)
|
||||||
|
session.add(config)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
providers = list(config.saml_identity_providers or [])
|
||||||
|
providers = [p for p in providers if p.get("id") != "test-saml"]
|
||||||
|
providers.append(_saml_provider_config(saml_metadata))
|
||||||
|
config.saml_identity_providers = providers
|
||||||
|
session.add(config)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(_seed())
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["WG_LOG_TO_FILE"] = "false"
|
||||||
|
env["WG_PORT"] = str(SAML_APP_PORT)
|
||||||
|
env["WG_EXTERNAL_URL"] = SAML_APP_BASE
|
||||||
|
env.pop("PYTEST_CURRENT_TEST", None)
|
||||||
|
env.pop("NICEGUI_SCREEN_TEST_PORT", None)
|
||||||
|
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
["uv", "run", "python", "-m", "wiregui.main"],
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(30):
|
||||||
|
try:
|
||||||
|
r = httpx.get(f"{SAML_APP_BASE}/api/health", timeout=1)
|
||||||
|
if r.status_code == 200:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
proc.kill()
|
||||||
|
out = proc.stdout.read().decode() if proc.stdout else ""
|
||||||
|
pytest.fail(f"App did not start in time. Output:\n{out}")
|
||||||
|
|
||||||
|
yield proc
|
||||||
|
|
||||||
|
proc.terminate()
|
||||||
|
proc.wait(timeout=10)
|
||||||
|
|
||||||
|
# Clean up seeded provider and test user
|
||||||
|
async def _cleanup():
|
||||||
|
await _cleanup_user_by_email(SAML_TEST_EMAIL)
|
||||||
|
async with async_session() as session:
|
||||||
|
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||||
|
if config:
|
||||||
|
config.saml_identity_providers = [
|
||||||
|
p for p in (config.saml_identity_providers or []) if p.get("id") != "test-saml"
|
||||||
|
]
|
||||||
|
session.add(config)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(_cleanup())
|
||||||
|
|
||||||
|
|
||||||
|
async def test_saml_button_visible_on_login(app_with_saml, page: Page):
|
||||||
|
"""Login page shows SAML provider button."""
|
||||||
|
await page.goto(f"{SAML_APP_BASE}/login")
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await expect(page.get_by_text("Sign in with Mock SAML")).to_be_visible(timeout=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_saml_redirect_to_idp(app_with_saml, page: Page):
|
||||||
|
"""Clicking SAML login redirects to the SimpleSAMLphp IdP login page."""
|
||||||
|
await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
|
||||||
|
# Should redirect to the SimpleSAMLphp SSO service
|
||||||
|
await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_saml_sp_metadata_endpoint(app_with_saml, page: Page):
|
||||||
|
"""SP metadata endpoint returns valid XML."""
|
||||||
|
response = await page.request.get(f"{SAML_APP_BASE}/auth/saml/test-saml/metadata")
|
||||||
|
assert response.status == 200
|
||||||
|
body = await response.text()
|
||||||
|
assert "EntityDescriptor" in body
|
||||||
|
assert "AssertionConsumerService" in body
|
||||||
|
|
||||||
|
|
||||||
|
async def test_full_saml_login_flow(app_with_saml, page: Page):
|
||||||
|
"""Full SAML SSO flow: app → IdP login → callback → authenticated."""
|
||||||
|
await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
|
||||||
|
await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
|
||||||
|
|
||||||
|
# SimpleSAMLphp login form
|
||||||
|
await page.locator("input[name='username']").fill("user1")
|
||||||
|
await page.locator("input[name='password']").fill("password")
|
||||||
|
await page.locator("button[type='submit'], input[type='submit']").first.click()
|
||||||
|
|
||||||
|
# Should redirect back to the app after SAML response
|
||||||
|
await page.wait_for_url(f"{SAML_APP_BASE}/**", timeout=15_000)
|
||||||
|
await page.wait_for_load_state("networkidle")
|
||||||
|
await page.wait_for_timeout(3000)
|
||||||
|
|
||||||
|
assert "/login" not in page.url, f"SAML login failed — still on login page: {page.url}"
|
||||||
|
|
||||||
|
# Verify user was auto-created in DB
|
||||||
|
async with async_session() as session:
|
||||||
|
result = await session.execute(select(User).where(User.email == SAML_TEST_EMAIL))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
assert user is not None, f"Expected user {SAML_TEST_EMAIL} to be auto-created"
|
||||||
|
assert user.last_signed_in_method == "saml:test-saml"
|
||||||
263
tests/test_api_deps.py
Normal file
263
tests/test_api_deps.py
Normal file
|
|
@ -0,0 +1,263 @@
|
||||||
|
"""Tests for API dependency injection — Bearer token auth and admin guard."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from datetime import timedelta
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from wiregui.auth.api_token import generate_api_token
|
||||||
|
from wiregui.auth.passwords import hash_password
|
||||||
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.api_token import ApiToken
|
||||||
|
from wiregui.models.user import User
|
||||||
|
from wiregui.utils.time import utcnow
|
||||||
|
|
||||||
|
|
||||||
|
# ========== resolve_bearer_token ==========
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_valid_token():
|
||||||
|
"""Valid, non-expired token resolves to user."""
|
||||||
|
from wiregui.auth.api_token import resolve_bearer_token
|
||||||
|
|
||||||
|
plaintext, token_hash = generate_api_token()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
api_token = ApiToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
session.add(api_token)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session() as session:
|
||||||
|
resolved = await resolve_bearer_token(session, plaintext)
|
||||||
|
assert resolved is not None
|
||||||
|
assert resolved.id == user.id
|
||||||
|
assert resolved.email == "api-test@test.com"
|
||||||
|
finally:
|
||||||
|
async with async_session() as session:
|
||||||
|
await session.delete(await session.get(ApiToken, api_token.id))
|
||||||
|
await session.delete(await session.get(User, user.id))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_expired_token():
|
||||||
|
"""Expired token returns None."""
|
||||||
|
from wiregui.auth.api_token import resolve_bearer_token
|
||||||
|
|
||||||
|
plaintext, token_hash = generate_api_token()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
api_token = ApiToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=utcnow() - timedelta(hours=1), # already expired
|
||||||
|
)
|
||||||
|
session.add(api_token)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session() as session:
|
||||||
|
resolved = await resolve_bearer_token(session, plaintext)
|
||||||
|
assert resolved is None
|
||||||
|
finally:
|
||||||
|
async with async_session() as session:
|
||||||
|
await session.delete(await session.get(ApiToken, api_token.id))
|
||||||
|
await session.delete(await session.get(User, user.id))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_invalid_token():
|
||||||
|
"""Nonexistent token returns None."""
|
||||||
|
from wiregui.auth.api_token import resolve_bearer_token
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
resolved = await resolve_bearer_token(session, "totally-bogus-token")
|
||||||
|
assert resolved is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_token_disabled_user():
|
||||||
|
"""Token for disabled user returns None."""
|
||||||
|
from wiregui.auth.api_token import resolve_bearer_token
|
||||||
|
|
||||||
|
plaintext, token_hash = generate_api_token()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user = User(
|
||||||
|
email="api-disabled@test.com", password_hash=hash_password("x"),
|
||||||
|
role="admin", disabled_at=utcnow(),
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
api_token = ApiToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
session.add(api_token)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session() as session:
|
||||||
|
resolved = await resolve_bearer_token(session, plaintext)
|
||||||
|
assert resolved is None
|
||||||
|
finally:
|
||||||
|
async with async_session() as session:
|
||||||
|
await session.delete(await session.get(ApiToken, api_token.id))
|
||||||
|
await session.delete(await session.get(User, user.id))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_token_no_expiry():
|
||||||
|
"""Token without expires_at (never expires) resolves successfully."""
|
||||||
|
from wiregui.auth.api_token import resolve_bearer_token
|
||||||
|
|
||||||
|
plaintext, token_hash = generate_api_token()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
api_token = ApiToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
session.add(api_token)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session() as session:
|
||||||
|
resolved = await resolve_bearer_token(session, plaintext)
|
||||||
|
assert resolved is not None
|
||||||
|
assert resolved.id == user.id
|
||||||
|
finally:
|
||||||
|
async with async_session() as session:
|
||||||
|
await session.delete(await session.get(ApiToken, api_token.id))
|
||||||
|
await session.delete(await session.get(User, user.id))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== get_current_api_user (via FastAPI deps) ==========
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_current_api_user_missing_header():
|
||||||
|
"""Missing Authorization header raises 401."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from wiregui.api.deps import get_current_api_user
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.headers = {}
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await get_current_api_user(request, session=AsyncMock())
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "Missing" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_current_api_user_bad_scheme():
|
||||||
|
"""Non-Bearer auth scheme raises 401."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from wiregui.api.deps import get_current_api_user
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.headers = {"Authorization": "Basic dXNlcjpwYXNz"}
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await get_current_api_user(request, session=AsyncMock())
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_current_api_user_invalid_token():
|
||||||
|
"""Valid Bearer scheme but bogus token raises 401."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from wiregui.api.deps import get_current_api_user
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.headers = {"Authorization": "Bearer bogus-token-value"}
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await get_current_api_user(request, session=session)
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "Invalid" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_current_api_user_valid_token():
|
||||||
|
"""Valid Bearer token resolves to user."""
|
||||||
|
from wiregui.api.deps import get_current_api_user
|
||||||
|
|
||||||
|
plaintext, token_hash = generate_api_token()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
api_token = ApiToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
session.add(api_token)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
request = MagicMock()
|
||||||
|
request.headers = {"Authorization": f"Bearer {plaintext}"}
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
resolved = await get_current_api_user(request, session=session)
|
||||||
|
assert resolved.id == user.id
|
||||||
|
finally:
|
||||||
|
async with async_session() as session:
|
||||||
|
await session.delete(await session.get(ApiToken, api_token.id))
|
||||||
|
await session.delete(await session.get(User, user.id))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== require_admin ==========
|
||||||
|
|
||||||
|
|
||||||
|
async def test_require_admin_allows_admin():
|
||||||
|
"""Admin user passes require_admin."""
|
||||||
|
from wiregui.api.deps import require_admin
|
||||||
|
|
||||||
|
admin_user = MagicMock(spec=User)
|
||||||
|
admin_user.role = "admin"
|
||||||
|
result = await require_admin(user=admin_user)
|
||||||
|
assert result == admin_user
|
||||||
|
|
||||||
|
|
||||||
|
async def test_require_admin_rejects_unprivileged():
|
||||||
|
"""Non-admin user gets 403."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from wiregui.api.deps import require_admin
|
||||||
|
|
||||||
|
regular_user = MagicMock(spec=User)
|
||||||
|
regular_user.role = "unprivileged"
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await require_admin(user=regular_user)
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert "Admin" in exc_info.value.detail
|
||||||
206
tests/test_firewall_extended.py
Normal file
206
tests/test_firewall_extended.py
Normal file
|
|
@ -0,0 +1,206 @@
|
||||||
|
"""Extended firewall tests — _nft/_nft_batch error handling, add_device_jump_rule edge cases, policies."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from wiregui.services.firewall import (
|
||||||
|
_nft,
|
||||||
|
_nft_batch,
|
||||||
|
add_device_jump_rule,
|
||||||
|
setup_base_tables,
|
||||||
|
setup_masquerade,
|
||||||
|
apply_peer_to_peer_policy,
|
||||||
|
apply_lan_to_peers_policy,
|
||||||
|
get_ruleset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== _nft error handling ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("asyncio.create_subprocess_exec")
|
||||||
|
async def test_nft_raises_on_failure(mock_exec):
|
||||||
|
"""_nft raises RuntimeError on non-zero exit code."""
|
||||||
|
mock_proc = AsyncMock()
|
||||||
|
mock_proc.communicate.return_value = (b"", b"nft: error message")
|
||||||
|
mock_proc.returncode = 1
|
||||||
|
mock_exec.return_value = mock_proc
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="nft.*failed"):
|
||||||
|
await _nft("list ruleset")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("asyncio.create_subprocess_exec")
|
||||||
|
async def test_nft_returns_stdout_on_success(mock_exec):
|
||||||
|
"""_nft returns stdout on success."""
|
||||||
|
mock_proc = AsyncMock()
|
||||||
|
mock_proc.communicate.return_value = (b"table inet wiregui {}", b"")
|
||||||
|
mock_proc.returncode = 0
|
||||||
|
mock_exec.return_value = mock_proc
|
||||||
|
|
||||||
|
result = await _nft("list ruleset")
|
||||||
|
assert "wiregui" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ========== _nft_batch error handling ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("asyncio.create_subprocess_exec")
|
||||||
|
async def test_nft_batch_raises_on_failure(mock_exec):
|
||||||
|
"""_nft_batch raises RuntimeError on non-zero exit code."""
|
||||||
|
mock_proc = AsyncMock()
|
||||||
|
mock_proc.communicate.return_value = (b"", b"Error: syntax error")
|
||||||
|
mock_proc.returncode = 1
|
||||||
|
mock_exec.return_value = mock_proc
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="nft batch failed"):
|
||||||
|
await _nft_batch(["add table inet wiregui"])
|
||||||
|
|
||||||
|
|
||||||
|
@patch("asyncio.create_subprocess_exec")
|
||||||
|
async def test_nft_batch_sends_commands_via_stdin(mock_exec):
|
||||||
|
"""_nft_batch sends all commands via stdin to nft -f -."""
|
||||||
|
mock_proc = AsyncMock()
|
||||||
|
mock_proc.communicate.return_value = (b"", b"")
|
||||||
|
mock_proc.returncode = 0
|
||||||
|
mock_exec.return_value = mock_proc
|
||||||
|
|
||||||
|
cmds = ["add table inet wiregui", "add chain inet wiregui test"]
|
||||||
|
await _nft_batch(cmds)
|
||||||
|
|
||||||
|
mock_exec.assert_awaited_once()
|
||||||
|
# Verify nft -f - was called
|
||||||
|
call_args = mock_exec.call_args[0]
|
||||||
|
assert call_args == ("nft", "-f", "-")
|
||||||
|
# Verify stdin data
|
||||||
|
stdin_data = mock_proc.communicate.call_args[0][0]
|
||||||
|
assert b"add table inet wiregui" in stdin_data
|
||||||
|
assert b"add chain inet wiregui test" in stdin_data
|
||||||
|
|
||||||
|
|
||||||
|
# ========== add_device_jump_rule edge cases ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_add_device_jump_rule_ipv4_only(mock_batch):
|
||||||
|
"""Only IPv4 — generates single IPv4 jump rule."""
|
||||||
|
await add_device_jump_rule("user-id-1", "10.0.0.5", None)
|
||||||
|
mock_batch.assert_awaited_once()
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert len(cmds) == 1
|
||||||
|
assert "ip saddr 10.0.0.5" in cmds[0]
|
||||||
|
assert "jump" in cmds[0]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_add_device_jump_rule_ipv6_only(mock_batch):
|
||||||
|
"""Only IPv6 — generates single IPv6 jump rule."""
|
||||||
|
await add_device_jump_rule("user-id-2", None, "fd00::5")
|
||||||
|
mock_batch.assert_awaited_once()
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert len(cmds) == 1
|
||||||
|
assert "ip6 saddr fd00::5" in cmds[0]
|
||||||
|
assert "jump" in cmds[0]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_add_device_jump_rule_no_ips(mock_batch):
|
||||||
|
"""Neither IPv4 nor IPv6 — no nft commands issued."""
|
||||||
|
await add_device_jump_rule("user-id-3", None, None)
|
||||||
|
mock_batch.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_add_device_jump_rule_both_ips(mock_batch):
|
||||||
|
"""Both IPv4 and IPv6 — generates two jump rules."""
|
||||||
|
await add_device_jump_rule("user-id-4", "10.0.0.7", "fd00::7")
|
||||||
|
mock_batch.assert_awaited_once()
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert len(cmds) == 2
|
||||||
|
assert any("ip saddr 10.0.0.7" in c for c in cmds)
|
||||||
|
assert any("ip6 saddr fd00::7" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== setup_base_tables — already exists ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_setup_base_tables_already_exists(mock_batch):
|
||||||
|
"""If table already exists (File exists error), don't raise."""
|
||||||
|
mock_batch.side_effect = RuntimeError("File exists")
|
||||||
|
await setup_base_tables() # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_setup_base_tables_other_error_raises(mock_batch):
|
||||||
|
"""Other nft errors should propagate."""
|
||||||
|
mock_batch.side_effect = RuntimeError("Permission denied")
|
||||||
|
with pytest.raises(RuntimeError, match="Permission denied"):
|
||||||
|
await setup_base_tables()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== setup_masquerade — error handling ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_setup_masquerade_error_swallowed(mock_batch):
|
||||||
|
"""Masquerade errors are logged but not raised."""
|
||||||
|
mock_batch.side_effect = RuntimeError("nft error")
|
||||||
|
await setup_masquerade(iface="wg0") # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ========== policy functions — command verification ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_peer_to_peer_enabled(mock_batch):
|
||||||
|
"""Enabling peer-to-peer generates accept rules."""
|
||||||
|
await apply_peer_to_peer_policy(True)
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert any("accept" in c for c in cmds)
|
||||||
|
assert any("peer_to_peer" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_peer_to_peer_disabled(mock_batch):
|
||||||
|
"""Disabling peer-to-peer generates drop rules."""
|
||||||
|
await apply_peer_to_peer_policy(False)
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert any("drop" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_lan_to_peers_enabled(mock_batch):
|
||||||
|
"""Enabling LAN-to-peers generates accept rules."""
|
||||||
|
await apply_lan_to_peers_policy(True)
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert any("accept" in c for c in cmds)
|
||||||
|
assert any("lan_to_peers" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
async def test_lan_to_peers_disabled(mock_batch):
|
||||||
|
"""Disabling LAN-to-peers generates drop rules."""
|
||||||
|
await apply_lan_to_peers_policy(False)
|
||||||
|
cmds = mock_batch.call_args[0][0]
|
||||||
|
assert any("drop" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== get_ruleset — error handling ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
|
||||||
|
async def test_get_ruleset_returns_output(mock_nft):
|
||||||
|
"""get_ruleset returns nft list ruleset output."""
|
||||||
|
mock_nft.return_value = "table inet wiregui { ... }"
|
||||||
|
result = await get_ruleset()
|
||||||
|
assert "wiregui" in result
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
|
||||||
|
async def test_get_ruleset_returns_fallback_on_error(mock_nft):
|
||||||
|
"""get_ruleset returns friendly message when nft not available."""
|
||||||
|
mock_nft.side_effect = RuntimeError("nft not found")
|
||||||
|
result = await get_ruleset()
|
||||||
|
assert "not available" in result
|
||||||
114
tests/test_wireguard_extended.py
Normal file
114
tests/test_wireguard_extended.py
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""Tests for WireGuard service — ensure_interface, set_private_key, set_listen_port, configure_interface."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch, call
|
||||||
|
|
||||||
|
from wiregui.services.wireguard import (
|
||||||
|
ensure_interface,
|
||||||
|
set_private_key,
|
||||||
|
set_listen_port,
|
||||||
|
configure_interface,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== ensure_interface ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||||
|
async def test_ensure_interface_already_exists(mock_run):
|
||||||
|
"""If interface exists (ip link show succeeds), do nothing."""
|
||||||
|
mock_run.return_value = ""
|
||||||
|
await ensure_interface(iface="wg-test")
|
||||||
|
# Only called once for ip link show
|
||||||
|
mock_run.assert_awaited_once_with(["ip", "link", "show", "wg-test"])
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||||
|
async def test_ensure_interface_creates_new(mock_run):
|
||||||
|
"""If interface doesn't exist, create it, assign IPs, bring up."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def side_effect(args, input_data=None):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1 and args == ["ip", "link", "show", "wg-test"]:
|
||||||
|
raise RuntimeError("Device not found")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
mock_run.side_effect = side_effect
|
||||||
|
await ensure_interface(iface="wg-test")
|
||||||
|
|
||||||
|
# Should have called: ip link show (fails), ip link add, ip addr add x2, ip link set up
|
||||||
|
assert mock_run.await_count == 5
|
||||||
|
calls = [c[0][0] for c in mock_run.call_args_list]
|
||||||
|
assert calls[1] == ["ip", "link", "add", "wg-test", "type", "wireguard"]
|
||||||
|
assert calls[2][0:3] == ["ip", "address", "add"]
|
||||||
|
assert calls[3][0:3] == ["ip", "address", "add"]
|
||||||
|
assert calls[4] == ["ip", "link", "set", "wg-test", "up"]
|
||||||
|
|
||||||
|
|
||||||
|
# ========== set_private_key ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||||
|
async def test_set_private_key(mock_run):
|
||||||
|
"""set_private_key calls wg set with private-key path."""
|
||||||
|
mock_run.return_value = ""
|
||||||
|
await set_private_key("/tmp/test.key", iface="wg-test")
|
||||||
|
mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "private-key", "/tmp/test.key"])
|
||||||
|
|
||||||
|
|
||||||
|
# ========== set_listen_port ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||||
|
async def test_set_listen_port(mock_run):
|
||||||
|
"""set_listen_port calls wg set with listen-port."""
|
||||||
|
mock_run.return_value = ""
|
||||||
|
await set_listen_port(51820, iface="wg-test")
|
||||||
|
mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "listen-port", "51820"])
|
||||||
|
|
||||||
|
|
||||||
|
# ========== configure_interface ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||||
|
@patch("wiregui.db.async_session")
|
||||||
|
async def test_configure_interface_no_config(mock_session_cls, mock_run):
|
||||||
|
"""If no Configuration row exists, do not call wg set."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = None
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
await configure_interface(iface="wg-test")
|
||||||
|
mock_run.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||||
|
@patch("wiregui.db.async_session")
|
||||||
|
async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run):
|
||||||
|
"""With valid config, writes key to temp file and calls wg set."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.server_private_key = "test-private-key-value"
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = mock_config
|
||||||
|
mock_session.execute.return_value = mock_result
|
||||||
|
mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
mock_run.return_value = ""
|
||||||
|
await configure_interface(iface="wg-test")
|
||||||
|
|
||||||
|
mock_run.assert_awaited_once()
|
||||||
|
args = mock_run.call_args[0][0]
|
||||||
|
assert args[0:3] == ["wg", "set", "wg-test"]
|
||||||
|
assert "private-key" in args
|
||||||
|
assert "listen-port" in args
|
||||||
|
|
@ -17,7 +17,7 @@ def _build_saml_settings(provider_config: dict) -> dict:
|
||||||
idp_settings = idp_data.get("idp", {})
|
idp_settings = idp_data.get("idp", {})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"strict": True,
|
"strict": provider_config.get("strict", True),
|
||||||
"debug": False,
|
"debug": False,
|
||||||
"sp": {
|
"sp": {
|
||||||
"entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata",
|
"entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from loguru import logger
|
||||||
from nicegui import app, ui
|
from nicegui import app, ui
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from wiregui.config import get_settings
|
||||||
from wiregui.db import async_session
|
from wiregui.db import async_session
|
||||||
from wiregui.models.configuration import Configuration
|
from wiregui.models.configuration import Configuration
|
||||||
from wiregui.pages.layout import layout
|
from wiregui.pages.layout import layout
|
||||||
|
|
|
||||||
|
|
@ -101,14 +101,13 @@ async def saml_callback(provider_id: str, request: Request):
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
request.session["authenticated"] = True
|
# Store auth data in Starlette session — picked up by /auth/complete
|
||||||
request.session["user_id"] = str(user.id)
|
request.session["oidc_user_id"] = str(user.id)
|
||||||
request.session["email"] = user.email
|
request.session["oidc_email"] = user.email
|
||||||
request.session["role"] = user.role
|
request.session["oidc_role"] = user.role
|
||||||
request.session["theme_preference"] = user.theme_preference
|
|
||||||
|
|
||||||
logger.info("SAML login: {} via {}", email, provider_id)
|
logger.info("SAML login: {} via {}", email, provider_id)
|
||||||
return RedirectResponse(url="/", status_code=303)
|
return RedirectResponse(url="/auth/complete", status_code=303)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("SAML callback failed for {}: {}", provider_id, e)
|
logger.error("SAML callback failed for {}: {}", provider_id, e)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Login page — email/password, MFA redirect, OIDC provider buttons."""
|
"""Login page — email/password, MFA redirect, OIDC/SAML provider buttons."""
|
||||||
|
|
||||||
from nicegui import app, ui
|
from nicegui import app, ui
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
@ -6,6 +6,7 @@ from sqlmodel import select
|
||||||
from wiregui.auth.oidc import load_providers
|
from wiregui.auth.oidc import load_providers
|
||||||
from wiregui.auth.session import authenticate_user
|
from wiregui.auth.session import authenticate_user
|
||||||
from wiregui.db import async_session
|
from wiregui.db import async_session
|
||||||
|
from wiregui.models.configuration import Configuration
|
||||||
from wiregui.models.mfa_method import MFAMethod
|
from wiregui.models.mfa_method import MFAMethod
|
||||||
from wiregui.pages.style import apply_style
|
from wiregui.pages.style import apply_style
|
||||||
from wiregui.utils.time import utcnow
|
from wiregui.utils.time import utcnow
|
||||||
|
|
@ -18,9 +19,13 @@ async def login_page():
|
||||||
|
|
||||||
apply_style()
|
apply_style()
|
||||||
|
|
||||||
# Load OIDC providers for SSO buttons
|
# Load SSO providers for login buttons
|
||||||
oidc_providers = await load_providers()
|
oidc_providers = await load_providers()
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||||
|
saml_providers = config.saml_identity_providers if config else []
|
||||||
|
|
||||||
async def try_login():
|
async def try_login():
|
||||||
user = await authenticate_user(email.value, password.value)
|
user = await authenticate_user(email.value, password.value)
|
||||||
if user is None:
|
if user is None:
|
||||||
|
|
@ -76,8 +81,8 @@ async def login_page():
|
||||||
|
|
||||||
password.on("keydown.enter", try_login)
|
password.on("keydown.enter", try_login)
|
||||||
|
|
||||||
# OIDC provider buttons
|
# SSO provider buttons
|
||||||
if oidc_providers:
|
if oidc_providers or saml_providers:
|
||||||
ui.separator().classes("q-my-md")
|
ui.separator().classes("q-my-md")
|
||||||
ui.label("Or sign in with").classes("text-caption text-center w-full")
|
ui.label("Or sign in with").classes("text-caption text-center w-full")
|
||||||
for provider in oidc_providers:
|
for provider in oidc_providers:
|
||||||
|
|
@ -87,3 +92,10 @@ async def login_page():
|
||||||
label,
|
label,
|
||||||
on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/oidc/{p}'"),
|
on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/oidc/{p}'"),
|
||||||
).props("color=primary unelevated").classes("w-full q-mt-xs")
|
).props("color=primary unelevated").classes("w-full q-mt-xs")
|
||||||
|
for provider in saml_providers:
|
||||||
|
pid = provider.get("id", "")
|
||||||
|
label = provider.get("label", pid)
|
||||||
|
ui.button(
|
||||||
|
label,
|
||||||
|
on_click=lambda p=pid: ui.run_javascript(f"window.location.href='/auth/saml/{p}'"),
|
||||||
|
).props("color=primary unelevated").classes("w-full q-mt-xs")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue