feat: comprehensive test suite + SAML auth fixes + mock SAML IdP
Tests (198 unit + 70 e2e = 268 total): - Add test_api_deps.py: Bearer token auth, get_current_api_user, require_admin - Add test_wireguard_extended.py: ensure_interface, set_private_key, set_listen_port - Add test_firewall_extended.py: _nft/_nft_batch errors, jump rules, policies - Add test_mfa_login.py: MFA redirect, TOTP verify, invalid code, cancel - Add test_magic_link_page.py: page render, submit, empty email, back to login - Add test_admin_devices.py: list, filter, create, edit, delete, config dialog - Add test_admin_rules.py: list, create, edit, delete (all DB-verified) - Add test_admin_settings.py: defaults, security, OIDC/SAML providers - Add test_saml_login.py: button visible, redirect, metadata, full login flow Bug fixes: - Fix SAML callback to use /auth/complete bridge (same fix as OIDC) - Fix missing get_settings import in admin settings page - Add SAML provider buttons to login page - Make SAML strict mode configurable per-provider Infrastructure: - Add mock SimpleSAMLphp IdP to compose.yml with SP config - Add mock-saml service to CI workflows (release + dev)
This commit is contained in:
parent
25cff5e4d9
commit
06b5a3dc12
18 changed files with 1768 additions and 47 deletions
239
tests/e2e/test_admin_devices.py
Normal file
239
tests/e2e/test_admin_devices.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""E2E tests for admin device management page."""
|
||||
|
||||
import pytest_asyncio
|
||||
from playwright.async_api import Page, expect
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||
from tests.e2e.conftest import (
|
||||
TEST_APP_BASE,
|
||||
TEST_EMAIL,
|
||||
TEST_PASSWORD,
|
||||
_cleanup_user_by_email,
|
||||
login,
|
||||
)
|
||||
|
||||
SECOND_USER_EMAIL = "e2e-device-user2@example.com"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def second_user(test_user):
|
||||
"""Create a second user with a device for filtering tests."""
|
||||
await _cleanup_user_by_email(SECOND_USER_EMAIL)
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(
|
||||
email=SECOND_USER_EMAIL,
|
||||
password_hash=hash_password("pass12345"),
|
||||
role="unprivileged",
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
yield user
|
||||
|
||||
await _cleanup_user_by_email(SECOND_USER_EMAIL)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def devices_for_both_users(test_user, second_user):
|
||||
"""Create one device per user for table/filter tests."""
|
||||
_, pub1 = generate_keypair()
|
||||
_, pub2 = generate_keypair()
|
||||
psk1 = generate_preshared_key()
|
||||
psk2 = generate_preshared_key()
|
||||
|
||||
async with async_session() as session:
|
||||
d1 = Device(
|
||||
name="admin-laptop",
|
||||
public_key=pub1,
|
||||
preshared_key=psk1,
|
||||
ipv4="10.0.0.10",
|
||||
user_id=test_user.id,
|
||||
)
|
||||
d2 = Device(
|
||||
name="user2-phone",
|
||||
public_key=pub2,
|
||||
preshared_key=psk2,
|
||||
ipv4="10.0.0.11",
|
||||
user_id=second_user.id,
|
||||
)
|
||||
session.add_all([d1, d2])
|
||||
await session.commit()
|
||||
|
||||
yield d1, d2
|
||||
|
||||
# Cleanup handled by user fixture cascade
|
||||
|
||||
|
||||
async def _go_to_admin_devices(page: Page):
|
||||
"""Login as admin and navigate to admin devices page."""
|
||||
await login(page)
|
||||
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||
await page.goto(f"{TEST_APP_BASE}/admin/devices")
|
||||
await expect(page.locator("role=main").get_by_text("All Devices")).to_be_visible(timeout=10_000)
|
||||
|
||||
|
||||
async def test_list_all_devices(page: Page, devices_for_both_users):
|
||||
"""Admin devices page lists devices from all users."""
|
||||
await _go_to_admin_devices(page)
|
||||
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||
|
||||
|
||||
async def test_filter_by_user(page: Page, second_user, devices_for_both_users):
|
||||
"""Filtering by user shows only that user's devices."""
|
||||
await _go_to_admin_devices(page)
|
||||
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Filter to second user
|
||||
await page.locator("label:has-text('Filter by User')").click()
|
||||
await page.get_by_role("option", name=SECOND_USER_EMAIL).click()
|
||||
await page.wait_for_timeout(1000)
|
||||
|
||||
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("admin-laptop")).not_to_be_visible()
|
||||
|
||||
# Filter back to all
|
||||
await page.locator("label:has-text('Filter by User')").click()
|
||||
await page.get_by_role("option", name="All Users").click()
|
||||
await page.wait_for_timeout(1000)
|
||||
|
||||
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("user2-phone")).to_be_visible(timeout=5_000)
|
||||
|
||||
|
||||
async def test_create_device_with_defaults(page: Page, test_user):
|
||||
"""Create device with all defaults — config dialog appears."""
|
||||
await _go_to_admin_devices(page)
|
||||
await page.get_by_role("button", name="Add Device").click()
|
||||
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
|
||||
|
||||
await page.locator("input[aria-label='Device Name']").fill("default-test-device")
|
||||
await page.get_by_role("button", name="Create").click()
|
||||
|
||||
# Config dialog should appear with WireGuard config
|
||||
await expect(page.get_by_text("Config for default-test-device")).to_be_visible(timeout=10_000)
|
||||
await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
|
||||
await page.get_by_role("button", name="Close").click()
|
||||
await page.wait_for_timeout(500)
|
||||
|
||||
# Device should be in the table
|
||||
await expect(page.get_by_role("cell", name="default-test-device").first).to_be_visible(timeout=5_000)
|
||||
|
||||
|
||||
async def test_create_device_with_overrides(page: Page, test_user):
|
||||
"""Create device with custom config overrides."""
|
||||
await _go_to_admin_devices(page)
|
||||
await page.get_by_role("button", name="Add Device").click()
|
||||
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
|
||||
|
||||
await page.locator("input[aria-label='Device Name']").fill("custom-override-dev")
|
||||
await page.locator("input[aria-label='Description (optional)']").fill("Custom overrides test")
|
||||
|
||||
# Toggle off DNS default and set custom — Quasar switches use .q-toggle
|
||||
await page.locator(".q-toggle", has_text="Use default DNS").click()
|
||||
dns_input = page.locator("input[aria-label='DNS Servers']")
|
||||
await dns_input.clear()
|
||||
await dns_input.fill("8.8.8.8, 8.8.4.4")
|
||||
|
||||
# Toggle off MTU default and set custom
|
||||
await page.locator(".q-toggle", has_text="Use default MTU").click()
|
||||
mtu_input = page.locator("input[aria-label='MTU']")
|
||||
await mtu_input.clear()
|
||||
await mtu_input.fill("1400")
|
||||
|
||||
await page.get_by_role("button", name="Create").click()
|
||||
|
||||
await expect(page.get_by_text("Config for custom-override-dev")).to_be_visible(timeout=10_000)
|
||||
await page.get_by_role("button", name="Close").click()
|
||||
await page.wait_for_timeout(500)
|
||||
|
||||
await expect(page.get_by_role("cell", name="custom-override-dev").first).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Device).where(Device.name == "custom-override-dev")
|
||||
.order_by(Device.inserted_at.desc()).limit(1)
|
||||
)
|
||||
device = result.scalar_one()
|
||||
assert device.use_default_dns is False
|
||||
assert "8.8.8.8" in device.dns
|
||||
assert device.use_default_mtu is False
|
||||
assert device.mtu == 1400
|
||||
|
||||
|
||||
async def test_edit_device_name_and_description(page: Page, devices_for_both_users):
|
||||
"""Edit a device name and description via the edit dialog."""
|
||||
await _go_to_admin_devices(page)
|
||||
await expect(page.get_by_text("admin-laptop")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Click edit button on admin-laptop row — Quasar slot buttons with icon
|
||||
row = page.locator("tr", has_text="admin-laptop")
|
||||
await row.locator(".q-btn").first.click()
|
||||
|
||||
await expect(page.get_by_text("Edit Device")).to_be_visible(timeout=5_000)
|
||||
|
||||
name_input = page.locator(".q-dialog input[aria-label='Device Name']")
|
||||
await name_input.clear()
|
||||
await name_input.fill("admin-laptop-renamed")
|
||||
|
||||
desc_input = page.locator(".q-dialog input[aria-label='Description']")
|
||||
await desc_input.clear()
|
||||
await desc_input.fill("Updated description")
|
||||
|
||||
await page.get_by_role("button", name="Save").click()
|
||||
await expect(page.get_by_text("Device updated")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("admin-laptop-renamed")).to_be_visible(timeout=5_000)
|
||||
|
||||
|
||||
async def test_delete_device(page: Page, test_user):
|
||||
"""Delete a device — removed from table."""
|
||||
_, pub = generate_keypair()
|
||||
async with async_session() as session:
|
||||
d = Device(
|
||||
name="delete-me-device",
|
||||
public_key=pub,
|
||||
preshared_key=generate_preshared_key(),
|
||||
ipv4="10.0.0.99",
|
||||
user_id=test_user.id,
|
||||
)
|
||||
session.add(d)
|
||||
await session.commit()
|
||||
|
||||
await _go_to_admin_devices(page)
|
||||
await expect(page.get_by_role("cell", name="delete-me-device")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Click the delete (second) button in the row
|
||||
row = page.locator("tr", has_text="delete-me-device")
|
||||
await row.locator(".q-btn").nth(1).click()
|
||||
|
||||
await expect(page.get_by_text("Deleted delete-me-device")).to_be_visible(timeout=5_000)
|
||||
await page.wait_for_timeout(1000)
|
||||
await expect(page.get_by_role("cell", name="delete-me-device")).not_to_be_visible()
|
||||
|
||||
|
||||
async def test_config_dialog_shows_wg_config(page: Page, test_user):
|
||||
"""Config dialog after device creation shows valid WireGuard config."""
|
||||
await _go_to_admin_devices(page)
|
||||
await page.get_by_role("button", name="Add Device").click()
|
||||
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
|
||||
|
||||
await page.locator("input[aria-label='Device Name']").fill("config-test-device")
|
||||
await page.get_by_role("button", name="Create").click()
|
||||
|
||||
await expect(page.get_by_text("Config for config-test-device")).to_be_visible(timeout=10_000)
|
||||
await expect(page.get_by_text("[Interface]")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("[Peer]")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("PrivateKey")).to_be_visible()
|
||||
await expect(page.get_by_role("button", name="Download .conf")).to_be_visible()
|
||||
|
||||
# QR code should be rendered
|
||||
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)
|
||||
227
tests/e2e/test_admin_rules.py
Normal file
227
tests/e2e/test_admin_rules.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
"""E2E tests for admin firewall rules management page."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import pytest_asyncio
|
||||
from playwright.async_api import Page, expect
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL, login
|
||||
|
||||
|
||||
async def _cleanup_test_rules():
|
||||
"""Remove rules created by tests (identified by test-specific destinations)."""
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Rule).where(Rule.destination.in_([
|
||||
"10.99.0.0/16", "10.88.0.0/16", "10.77.0.0/16",
|
||||
"10.66.0.0/16", "10.55.0.0/16",
|
||||
]))
|
||||
)
|
||||
for rule in result.scalars().all():
|
||||
await session.delete(rule)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def clean_rules(app_server):
|
||||
"""Clean up test rules before and after each test."""
|
||||
await _cleanup_test_rules()
|
||||
yield
|
||||
await _cleanup_test_rules()
|
||||
|
||||
|
||||
async def _go_to_rules(page: Page):
|
||||
"""Login and navigate to admin rules page."""
|
||||
await login(page)
|
||||
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||
await page.goto(f"{TEST_APP_BASE}/admin/rules")
|
||||
await expect(page.locator("role=main").get_by_text("Firewall Rules")).to_be_visible(timeout=10_000)
|
||||
|
||||
|
||||
async def _create_rule_via_dialog(
|
||||
page: Page, *, action: str = "accept", destination: str = "10.99.0.0/16",
|
||||
protocol: str = "any", port_range: str = "", user: str = "global",
|
||||
):
|
||||
"""Open create dialog and fill in a rule."""
|
||||
await page.get_by_role("button", name="Add Rule").click()
|
||||
await expect(page.get_by_text("New Firewall Rule")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Action select
|
||||
if action != "accept":
|
||||
await page.locator(".q-dialog label:has-text('Action')").click()
|
||||
await page.get_by_role("option", name=action).click()
|
||||
|
||||
# Destination
|
||||
await page.locator(".q-dialog input[aria-label='Destination (CIDR)']").fill(destination)
|
||||
|
||||
# Protocol
|
||||
if protocol != "any":
|
||||
await page.locator(".q-dialog label:has-text('Protocol')").click()
|
||||
await page.get_by_role("option", name=protocol).click()
|
||||
|
||||
# Port range
|
||||
if port_range:
|
||||
await page.locator(".q-dialog input[aria-label='Port Range']").fill(port_range)
|
||||
|
||||
# User
|
||||
if user != "global":
|
||||
await page.locator(".q-dialog label:has-text('Applies to')").click()
|
||||
await page.get_by_role("option", name=user).click()
|
||||
|
||||
await page.get_by_role("button", name="Create").click()
|
||||
await page.wait_for_timeout(500)
|
||||
|
||||
|
||||
async def test_list_rules_table(page: Page, test_user: User):
|
||||
"""Rules page renders table with correct columns."""
|
||||
# Seed a rule in DB
|
||||
async with async_session() as session:
|
||||
rule = Rule(action="accept", destination="10.99.0.0/16", port_type="tcp",
|
||||
port_range="443", user_id=test_user.id)
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
|
||||
await _go_to_rules(page)
|
||||
|
||||
await expect(page.get_by_role("cell", name="accept")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible()
|
||||
await expect(page.get_by_role("cell", name="tcp")).to_be_visible()
|
||||
await expect(page.get_by_role("cell", name="443")).to_be_visible()
|
||||
await expect(page.get_by_role("cell", name=TEST_EMAIL)).to_be_visible()
|
||||
|
||||
|
||||
async def test_create_accept_rule_with_cidr(page: Page, test_user: User):
|
||||
"""Create an accept rule with CIDR — verify in table and DB."""
|
||||
await _go_to_rules(page)
|
||||
await _create_rule_via_dialog(page, action="accept", destination="10.88.0.0/16")
|
||||
|
||||
await expect(page.get_by_role("cell", name="10.88.0.0/16")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Rule).where(Rule.destination == "10.88.0.0/16"))
|
||||
rule = result.scalar_one()
|
||||
assert rule.action == "accept"
|
||||
assert rule.port_type is None
|
||||
assert rule.port_range is None
|
||||
assert rule.user_id is None
|
||||
|
||||
|
||||
async def test_create_drop_rule_with_tcp_port_range(page: Page, test_user: User):
|
||||
"""Create a drop rule with TCP port range — verify in table and DB."""
|
||||
await _go_to_rules(page)
|
||||
await _create_rule_via_dialog(
|
||||
page, action="drop", destination="10.77.0.0/16",
|
||||
protocol="tcp", port_range="80-443",
|
||||
)
|
||||
|
||||
await expect(page.get_by_role("cell", name="10.77.0.0/16")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_role("cell", name="drop").first).to_be_visible()
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Rule).where(Rule.destination == "10.77.0.0/16"))
|
||||
rule = result.scalar_one()
|
||||
assert rule.action == "drop"
|
||||
assert rule.port_type == "tcp"
|
||||
assert rule.port_range == "80-443"
|
||||
|
||||
|
||||
async def test_create_global_rule(page: Page, test_user: User):
|
||||
"""Create a global rule (no user) — shows 'Global' in table and DB has null user_id."""
|
||||
await _go_to_rules(page)
|
||||
await _create_rule_via_dialog(page, destination="10.66.0.0/16", user="global")
|
||||
|
||||
await expect(page.get_by_role("cell", name="10.66.0.0/16")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_role("cell", name="Global")).to_be_visible()
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Rule).where(Rule.destination == "10.66.0.0/16"))
|
||||
rule = result.scalar_one()
|
||||
assert rule.user_id is None
|
||||
|
||||
|
||||
async def test_edit_rule_action(page: Page, test_user: User):
|
||||
"""Edit rule action from accept to drop — verify in table and DB."""
|
||||
async with async_session() as session:
|
||||
rule = Rule(action="accept", destination="10.55.0.0/16")
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
rule_id = rule.id
|
||||
|
||||
await _go_to_rules(page)
|
||||
await expect(page.get_by_role("cell", name="10.55.0.0/16")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Click edit (first button in the row)
|
||||
row = page.locator("tr", has_text="10.55.0.0/16")
|
||||
await row.locator(".q-btn").first.click()
|
||||
await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Change action to drop
|
||||
await page.locator(".q-dialog label:has-text('Action')").click()
|
||||
await page.get_by_role("option", name="drop").click()
|
||||
|
||||
await page.get_by_role("button", name="Save").click()
|
||||
await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
rule = await session.get(Rule, rule_id)
|
||||
assert rule.action == "drop"
|
||||
|
||||
|
||||
async def test_edit_rule_destination(page: Page, test_user: User):
|
||||
"""Edit rule destination — verify in table and DB."""
|
||||
async with async_session() as session:
|
||||
rule = Rule(action="accept", destination="10.99.0.0/16")
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
rule_id = rule.id
|
||||
|
||||
await _go_to_rules(page)
|
||||
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
|
||||
|
||||
row = page.locator("tr", has_text="10.99.0.0/16")
|
||||
await row.locator(".q-btn").first.click()
|
||||
await expect(page.get_by_text("Edit Firewall Rule")).to_be_visible(timeout=5_000)
|
||||
|
||||
dest_input = page.locator(".q-dialog input[aria-label='Destination (CIDR)']")
|
||||
await dest_input.clear()
|
||||
await dest_input.fill("10.88.0.0/16")
|
||||
|
||||
await page.get_by_role("button", name="Save").click()
|
||||
await expect(page.get_by_text("Rule updated")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
rule = await session.get(Rule, rule_id)
|
||||
assert rule.destination == "10.88.0.0/16"
|
||||
|
||||
|
||||
async def test_delete_rule(page: Page, test_user: User):
|
||||
"""Delete a rule — removed from table and DB."""
|
||||
async with async_session() as session:
|
||||
rule = Rule(action="accept", destination="10.99.0.0/16")
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
rule_id = rule.id
|
||||
|
||||
await _go_to_rules(page)
|
||||
await expect(page.get_by_role("cell", name="10.99.0.0/16")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Click delete (second button in the row)
|
||||
row = page.locator("tr", has_text="10.99.0.0/16")
|
||||
await row.locator(".q-btn").nth(1).click()
|
||||
await page.wait_for_timeout(1000)
|
||||
|
||||
await expect(page.get_by_role("cell", name="10.99.0.0/16")).not_to_be_visible()
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
rule = await session.get(Rule, rule_id)
|
||||
assert rule is None
|
||||
281
tests/e2e/test_admin_settings.py
Normal file
281
tests/e2e/test_admin_settings.py
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
"""E2E tests for admin settings page — client defaults, security, OIDC/SAML providers."""
|
||||
|
||||
import pytest_asyncio
|
||||
from playwright.async_api import Page, expect
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.user import User
|
||||
from tests.e2e.conftest import TEST_APP_BASE, login
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def reset_config(app_server):
|
||||
"""Snapshot config before test, restore after."""
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
if not c:
|
||||
yield
|
||||
return
|
||||
snap = {
|
||||
"default_client_endpoint": c.default_client_endpoint,
|
||||
"default_client_dns": list(c.default_client_dns),
|
||||
"default_client_mtu": c.default_client_mtu,
|
||||
"default_client_persistent_keepalive": c.default_client_persistent_keepalive,
|
||||
"default_client_allowed_ips": list(c.default_client_allowed_ips),
|
||||
"vpn_session_duration": c.vpn_session_duration,
|
||||
"local_auth_enabled": c.local_auth_enabled,
|
||||
"allow_unprivileged_device_management": c.allow_unprivileged_device_management,
|
||||
"allow_unprivileged_device_configuration": c.allow_unprivileged_device_configuration,
|
||||
"openid_connect_providers": list(c.openid_connect_providers or []),
|
||||
"saml_identity_providers": list(c.saml_identity_providers or []),
|
||||
}
|
||||
cid = c.id
|
||||
|
||||
yield
|
||||
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, cid)
|
||||
if c:
|
||||
for k, v in snap.items():
|
||||
setattr(c, k, v)
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _go_to_settings(page: Page):
|
||||
await login(page)
|
||||
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||
await page.goto(f"{TEST_APP_BASE}/admin/settings")
|
||||
await expect(page.get_by_text("Default Client Configuration")).to_be_visible(timeout=10_000)
|
||||
|
||||
|
||||
# --- Client Defaults ---
|
||||
|
||||
|
||||
async def test_save_client_defaults(page: Page, test_user: User):
|
||||
"""Save endpoint, DNS, MTU, keepalive, allowed IPs — verify persists in DB."""
|
||||
await _go_to_settings(page)
|
||||
|
||||
endpoint = page.locator("input[aria-label='Endpoint']")
|
||||
await endpoint.clear()
|
||||
await endpoint.fill("vpn.test.local")
|
||||
|
||||
dns = page.locator("input[aria-label='DNS Servers']")
|
||||
await dns.clear()
|
||||
await dns.fill("9.9.9.9, 149.112.112.112")
|
||||
|
||||
mtu = page.locator("input[aria-label='MTU']")
|
||||
await mtu.clear()
|
||||
await mtu.fill("1420")
|
||||
|
||||
keepalive = page.locator("input[aria-label='Persistent Keepalive']")
|
||||
await keepalive.clear()
|
||||
await keepalive.fill("30")
|
||||
|
||||
allowed = page.locator("input[aria-label='Allowed IPs']")
|
||||
await allowed.clear()
|
||||
await allowed.fill("10.0.0.0/8, 192.168.0.0/16")
|
||||
|
||||
await page.get_by_role("button", name="Save Defaults").click()
|
||||
await expect(page.get_by_text("Client defaults saved")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
assert c.default_client_endpoint == "vpn.test.local"
|
||||
assert c.default_client_dns == ["9.9.9.9", "149.112.112.112"]
|
||||
assert c.default_client_mtu == 1420
|
||||
assert c.default_client_persistent_keepalive == 30
|
||||
assert c.default_client_allowed_ips == ["10.0.0.0/8", "192.168.0.0/16"]
|
||||
|
||||
|
||||
async def test_client_defaults_persist_on_reload(page: Page, test_user: User):
|
||||
"""Saved defaults are reflected after page reload."""
|
||||
# Set values via DB
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
c.default_client_endpoint = "reload-test.vpn"
|
||||
c.default_client_dns = ["8.8.8.8"]
|
||||
c.default_client_mtu = 1500
|
||||
c.default_client_persistent_keepalive = 15
|
||||
c.default_client_allowed_ips = ["172.16.0.0/12"]
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
await _go_to_settings(page)
|
||||
|
||||
await expect(page.locator("input[aria-label='Endpoint']")).to_have_value("reload-test.vpn")
|
||||
await expect(page.locator("input[aria-label='DNS Servers']")).to_have_value("8.8.8.8")
|
||||
await expect(page.locator("input[aria-label='MTU']")).to_have_value("1500")
|
||||
await expect(page.locator("input[aria-label='Persistent Keepalive']")).to_have_value("15")
|
||||
await expect(page.locator("input[aria-label='Allowed IPs']")).to_have_value("172.16.0.0/12")
|
||||
|
||||
|
||||
# --- Security ---
|
||||
|
||||
|
||||
async def test_save_security_local_auth_toggle(page: Page, test_user: User):
|
||||
"""Toggle local auth off — verify in DB."""
|
||||
await _go_to_settings(page)
|
||||
|
||||
# Find the local auth switch and toggle it off
|
||||
switch = page.locator(".q-toggle", has_text="Local Authentication")
|
||||
await switch.click()
|
||||
|
||||
await page.get_by_role("button", name="Save Security Settings").click()
|
||||
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
|
||||
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
assert c.local_auth_enabled is False
|
||||
|
||||
|
||||
async def test_save_vpn_session_duration(page: Page, test_user: User):
|
||||
"""Change VPN session duration — verify in DB."""
|
||||
await _go_to_settings(page)
|
||||
|
||||
await page.locator("label:has-text('VPN Session Duration')").click()
|
||||
await page.get_by_role("option", name="Every Day").click()
|
||||
|
||||
await page.get_by_role("button", name="Save Security Settings").click()
|
||||
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
|
||||
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
assert c.vpn_session_duration == 86400
|
||||
|
||||
|
||||
async def test_save_unprivileged_toggles(page: Page, test_user: User):
|
||||
"""Toggle unprivileged device management/configuration — verify in DB."""
|
||||
await _go_to_settings(page)
|
||||
|
||||
await page.locator(".q-toggle", has_text="Allow Unprivileged Device Management").click()
|
||||
await page.locator(".q-toggle", has_text="Allow Unprivileged Device Configuration").click()
|
||||
|
||||
await page.get_by_role("button", name="Save Security Settings").click()
|
||||
await expect(page.get_by_text("Security settings saved")).to_be_visible(timeout=5_000)
|
||||
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
# Toggled from default (True) to False
|
||||
assert c.allow_unprivileged_device_management is False
|
||||
assert c.allow_unprivileged_device_configuration is False
|
||||
|
||||
|
||||
# --- OIDC Providers ---
|
||||
|
||||
|
||||
async def test_add_oidc_provider(page: Page, test_user: User):
|
||||
"""Add an OIDC provider — appears in table and DB."""
|
||||
await _go_to_settings(page)
|
||||
|
||||
await page.get_by_role("button", name="Add OIDC Provider").click()
|
||||
await expect(page.get_by_text("OIDC Provider", exact=True)).to_be_visible(timeout=5_000)
|
||||
|
||||
await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-oidc")
|
||||
await page.locator(".q-dialog input[aria-label='Label']").fill("E2E Test IdP")
|
||||
await page.locator(".q-dialog input[aria-label='Client ID']").fill("test-client-id")
|
||||
await page.locator(".q-dialog input[aria-label='Client Secret']").fill("test-client-secret")
|
||||
await page.locator(".q-dialog input[aria-label='Discovery Document URI']").fill("https://idp.test/.well-known/openid-configuration")
|
||||
|
||||
await page.locator(".q-dialog").get_by_role("button", name="Save").click()
|
||||
await expect(page.get_by_text("OIDC provider 'E2E Test IdP' saved")).to_be_visible(timeout=5_000)
|
||||
|
||||
await expect(page.get_by_role("cell", name="e2e-test-oidc")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
provider = next((p for p in c.openid_connect_providers if p["id"] == "e2e-test-oidc"), None)
|
||||
assert provider is not None
|
||||
assert provider["label"] == "E2E Test IdP"
|
||||
assert provider["client_id"] == "test-client-id"
|
||||
|
||||
|
||||
async def test_delete_oidc_provider(page: Page, test_user: User):
|
||||
"""Delete an OIDC provider — removed from table and DB."""
|
||||
# Seed a provider
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
providers = list(c.openid_connect_providers or [])
|
||||
providers.append({
|
||||
"id": "delete-me-oidc", "label": "Delete Me", "scope": "openid",
|
||||
"client_id": "x", "client_secret": "x",
|
||||
"discovery_document_uri": "https://x/.well-known/openid-configuration",
|
||||
})
|
||||
c.openid_connect_providers = providers
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
await _go_to_settings(page)
|
||||
await expect(page.get_by_role("cell", name="delete-me-oidc")).to_be_visible(timeout=5_000)
|
||||
|
||||
row = page.locator("tr", has_text="delete-me-oidc")
|
||||
await row.locator(".q-btn").first.click()
|
||||
|
||||
await expect(page.get_by_text("OIDC provider deleted")).to_be_visible(timeout=5_000)
|
||||
await page.wait_for_timeout(500)
|
||||
await expect(page.get_by_role("cell", name="delete-me-oidc")).not_to_be_visible()
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
assert not any(p["id"] == "delete-me-oidc" for p in c.openid_connect_providers)
|
||||
|
||||
|
||||
# --- SAML Providers ---
|
||||
|
||||
|
||||
async def test_add_saml_provider(page: Page, test_user: User):
|
||||
"""Add a SAML provider — appears in table and DB."""
|
||||
await _go_to_settings(page)
|
||||
|
||||
await page.get_by_role("button", name="Add SAML Provider").click()
|
||||
await expect(page.get_by_text("SAML Identity Provider", exact=True)).to_be_visible(timeout=5_000)
|
||||
|
||||
await page.locator(".q-dialog input[aria-label='Config ID']").fill("e2e-test-saml")
|
||||
await page.locator(".q-dialog input[aria-label='Label']").fill("E2E SAML IdP")
|
||||
await page.locator(".q-dialog textarea").fill("<EntityDescriptor>test</EntityDescriptor>")
|
||||
|
||||
await page.locator(".q-dialog").get_by_role("button", name="Save").click()
|
||||
await expect(page.get_by_text("SAML provider 'E2E SAML IdP' saved")).to_be_visible(timeout=5_000)
|
||||
|
||||
await expect(page.get_by_role("cell", name="e2e-test-saml")).to_be_visible(timeout=5_000)
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
provider = next((p for p in c.saml_identity_providers if p["id"] == "e2e-test-saml"), None)
|
||||
assert provider is not None
|
||||
assert provider["label"] == "E2E SAML IdP"
|
||||
|
||||
|
||||
async def test_delete_saml_provider(page: Page, test_user: User):
|
||||
"""Delete a SAML provider — removed from table and DB."""
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
providers = list(c.saml_identity_providers or [])
|
||||
providers.append({
|
||||
"id": "delete-me-saml", "label": "Delete Me SAML",
|
||||
"metadata": "<EntityDescriptor/>",
|
||||
})
|
||||
c.saml_identity_providers = providers
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
await _go_to_settings(page)
|
||||
await expect(page.get_by_role("cell", name="delete-me-saml")).to_be_visible(timeout=5_000)
|
||||
|
||||
row = page.locator("tr", has_text="delete-me-saml")
|
||||
await row.locator(".q-btn").first.click()
|
||||
|
||||
await expect(page.get_by_text("SAML provider deleted")).to_be_visible(timeout=5_000)
|
||||
await page.wait_for_timeout(500)
|
||||
await expect(page.get_by_role("cell", name="delete-me-saml")).not_to_be_visible()
|
||||
|
||||
# Verify in DB
|
||||
async with async_session() as session:
|
||||
c = (await session.execute(select(Configuration).limit(1))).scalar_one()
|
||||
assert not any(p["id"] == "delete-me-saml" for p in c.saml_identity_providers)
|
||||
41
tests/e2e/test_magic_link_page.py
Normal file
41
tests/e2e/test_magic_link_page.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""E2E tests for magic link request page."""
|
||||
|
||||
from playwright.async_api import Page, expect
|
||||
|
||||
from tests.e2e.conftest import TEST_APP_BASE, TEST_EMAIL
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def test_magic_link_page_renders(page: Page, test_user: User):
|
||||
"""Magic link request page renders with email input and submit button."""
|
||||
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await expect(page.get_by_text("Sign in with magic link")).to_be_visible(timeout=10_000)
|
||||
await expect(page.locator("input[aria-label='Email']")).to_be_visible()
|
||||
await expect(page.get_by_role("button", name="Send Magic Link")).to_be_visible()
|
||||
await expect(page.get_by_role("button", name="Back to login")).to_be_visible()
|
||||
|
||||
|
||||
async def test_magic_link_shows_success_on_submit(page: Page, test_user: User):
|
||||
"""Submitting an email shows success message (regardless of whether account exists)."""
|
||||
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await page.locator("input[aria-label='Email']").fill(TEST_EMAIL)
|
||||
await page.get_by_role("button", name="Send Magic Link").click()
|
||||
await expect(page.get_by_text("a sign-in link has been sent")).to_be_visible(timeout=5_000)
|
||||
|
||||
|
||||
async def test_magic_link_empty_email_shows_error(page: Page, test_user: User):
|
||||
"""Submitting without email shows error."""
|
||||
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await page.get_by_role("button", name="Send Magic Link").click()
|
||||
await expect(page.get_by_text("Enter your email")).to_be_visible(timeout=5_000)
|
||||
|
||||
|
||||
async def test_magic_link_back_to_login(page: Page, test_user: User):
|
||||
"""Back to login button navigates to login page."""
|
||||
await page.goto(f"{TEST_APP_BASE}/auth/magic-link")
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await page.get_by_role("button", name="Back to login").click()
|
||||
await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)
|
||||
111
tests/e2e/test_mfa_login.py
Normal file
111
tests/e2e/test_mfa_login.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""E2E tests for MFA login flow — login with TOTP redirects to /mfa challenge page."""
|
||||
|
||||
import pyotp
|
||||
import pytest_asyncio
|
||||
from playwright.async_api import Page, expect
|
||||
|
||||
from wiregui.auth.mfa import generate_totp_secret
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.user import User
|
||||
from tests.e2e.conftest import (
|
||||
FAKE_SERVER_KEY,
|
||||
TEST_APP_BASE,
|
||||
TEST_PASSWORD,
|
||||
_cleanup_user_by_email,
|
||||
)
|
||||
|
||||
MFA_EMAIL = "e2e-mfa@example.com"
|
||||
MFA_PASSWORD = "mfapass123"
|
||||
TOTP_SECRET = generate_totp_secret()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mfa_user(app_server):
|
||||
"""Create a user with a TOTP MFA method, clean up after."""
|
||||
await _cleanup_user_by_email(MFA_EMAIL)
|
||||
|
||||
async with async_session() as session:
|
||||
from sqlmodel import select
|
||||
from wiregui.models.configuration import Configuration
|
||||
|
||||
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
if config:
|
||||
if not config.server_public_key:
|
||||
config.server_public_key = FAKE_SERVER_KEY
|
||||
session.add(config)
|
||||
else:
|
||||
config = Configuration(server_public_key=FAKE_SERVER_KEY)
|
||||
session.add(config)
|
||||
|
||||
user = User(
|
||||
email=MFA_EMAIL,
|
||||
password_hash=hash_password(MFA_PASSWORD),
|
||||
role="admin",
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
mfa = MFAMethod(
|
||||
name="Test TOTP",
|
||||
type="totp",
|
||||
payload={"secret": TOTP_SECRET},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(mfa)
|
||||
await session.commit()
|
||||
|
||||
yield user
|
||||
|
||||
await _cleanup_user_by_email(MFA_EMAIL)
|
||||
|
||||
|
||||
async def _login_mfa_user(page: Page):
|
||||
"""Fill login form for the MFA user and submit."""
|
||||
await page.goto(f"{TEST_APP_BASE}/login")
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await page.locator("input[aria-label='Email']").fill(MFA_EMAIL)
|
||||
await page.locator("input[aria-label='Password']").fill(MFA_PASSWORD)
|
||||
await page.get_by_role("button", name="Sign in", exact=True).click()
|
||||
|
||||
|
||||
async def test_mfa_login_redirects_to_challenge(page: Page, mfa_user: User):
|
||||
"""Login with MFA-enabled user redirects to /mfa challenge page."""
|
||||
await _login_mfa_user(page)
|
||||
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||
await expect(page.locator("input[aria-label='Authentication Code']")).to_be_visible()
|
||||
|
||||
|
||||
async def test_mfa_valid_totp_completes_login(page: Page, mfa_user: User):
|
||||
"""Entering a valid TOTP code on /mfa completes login."""
|
||||
await _login_mfa_user(page)
|
||||
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||
|
||||
code = pyotp.TOTP(TOTP_SECRET).now()
|
||||
await page.locator("input[aria-label='Authentication Code']").fill(code)
|
||||
await page.get_by_role("button", name="Verify").click()
|
||||
|
||||
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)
|
||||
|
||||
|
||||
async def test_mfa_invalid_code_shows_error(page: Page, mfa_user: User):
|
||||
"""Entering an invalid TOTP code shows error and stays on /mfa."""
|
||||
await _login_mfa_user(page)
|
||||
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||
|
||||
await page.locator("input[aria-label='Authentication Code']").fill("000000")
|
||||
await page.get_by_role("button", name="Verify").click()
|
||||
|
||||
await expect(page.get_by_text("Invalid code")).to_be_visible(timeout=5_000)
|
||||
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible()
|
||||
|
||||
|
||||
async def test_mfa_cancel_returns_to_login(page: Page, mfa_user: User):
|
||||
"""Clicking Cancel on /mfa clears session and returns to login."""
|
||||
await _login_mfa_user(page)
|
||||
await expect(page.get_by_text("Two-Factor Authentication")).to_be_visible(timeout=10_000)
|
||||
|
||||
await page.get_by_role("button", name="Cancel").click()
|
||||
await expect(page.get_by_role("button", name="Sign in", exact=True)).to_be_visible(timeout=10_000)
|
||||
177
tests/e2e/test_saml_login.py
Normal file
177
tests/e2e/test_saml_login.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""E2E tests for SAML authentication — mock SimpleSAMLphp IdP.
|
||||
|
||||
Requires mock-saml service running (docker compose up -d mock-saml).
|
||||
IdP metadata: http://localhost:8080/simplesaml/saml2/idp/metadata.php
|
||||
Test users: user1/user1pass, user2/user2pass
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from playwright.async_api import Page, expect
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.user import User
|
||||
from tests.e2e.conftest import FAKE_SERVER_KEY, _cleanup_user_by_email
|
||||
|
||||
MOCK_SAML_HOST = os.environ.get("MOCK_SAML_HOST", "localhost")
|
||||
MOCK_SAML_METADATA_URL = f"http://{MOCK_SAML_HOST}:8080/simplesaml/saml2/idp/metadata.php"
|
||||
|
||||
# Separate app port for SAML tests (like OIDC IdP tests)
|
||||
SAML_APP_PORT = 13003
|
||||
SAML_APP_BASE = f"http://localhost:{SAML_APP_PORT}"
|
||||
|
||||
SAML_TEST_EMAIL = "user1@example.com"
|
||||
|
||||
|
||||
def _fetch_idp_metadata() -> str:
|
||||
"""Fetch IdP metadata XML from the mock SAML server."""
|
||||
try:
|
||||
r = httpx.get(MOCK_SAML_METADATA_URL, timeout=5)
|
||||
r.raise_for_status()
|
||||
return r.text
|
||||
except Exception:
|
||||
pytest.skip(f"Mock SAML IdP not available at {MOCK_SAML_METADATA_URL}")
|
||||
|
||||
|
||||
def _saml_provider_config(metadata: str) -> dict:
|
||||
return {
|
||||
"id": "test-saml",
|
||||
"label": "Sign in with Mock SAML",
|
||||
"metadata": metadata,
|
||||
"sign_requests": False,
|
||||
"sign_metadata": False,
|
||||
"signed_assertion_in_resp": False,
|
||||
"signed_envelopes_in_resp": False,
|
||||
"auto_create_users": True,
|
||||
"strict": False, # Relaxed for test IdP with expired certs
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def saml_metadata():
|
||||
return _fetch_idp_metadata()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app_with_saml(saml_metadata):
|
||||
"""Start a WireGUI instance with a SAML provider seeded in the DB."""
|
||||
import asyncio
|
||||
|
||||
# Seed the SAML provider config into the database
|
||||
async def _seed():
|
||||
async with async_session() as session:
|
||||
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
if config is None:
|
||||
config = Configuration(server_public_key=FAKE_SERVER_KEY)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
providers = list(config.saml_identity_providers or [])
|
||||
providers = [p for p in providers if p.get("id") != "test-saml"]
|
||||
providers.append(_saml_provider_config(saml_metadata))
|
||||
config.saml_identity_providers = providers
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(_seed())
|
||||
|
||||
env = os.environ.copy()
|
||||
env["WG_LOG_TO_FILE"] = "false"
|
||||
env["WG_PORT"] = str(SAML_APP_PORT)
|
||||
env["WG_EXTERNAL_URL"] = SAML_APP_BASE
|
||||
env.pop("PYTEST_CURRENT_TEST", None)
|
||||
env.pop("NICEGUI_SCREEN_TEST_PORT", None)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
["uv", "run", "python", "-m", "wiregui.main"],
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
for _ in range(30):
|
||||
try:
|
||||
r = httpx.get(f"{SAML_APP_BASE}/api/health", timeout=1)
|
||||
if r.status_code == 200:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(1)
|
||||
else:
|
||||
proc.kill()
|
||||
out = proc.stdout.read().decode() if proc.stdout else ""
|
||||
pytest.fail(f"App did not start in time. Output:\n{out}")
|
||||
|
||||
yield proc
|
||||
|
||||
proc.terminate()
|
||||
proc.wait(timeout=10)
|
||||
|
||||
# Clean up seeded provider and test user
|
||||
async def _cleanup():
|
||||
await _cleanup_user_by_email(SAML_TEST_EMAIL)
|
||||
async with async_session() as session:
|
||||
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
if config:
|
||||
config.saml_identity_providers = [
|
||||
p for p in (config.saml_identity_providers or []) if p.get("id") != "test-saml"
|
||||
]
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(_cleanup())
|
||||
|
||||
|
||||
async def test_saml_button_visible_on_login(app_with_saml, page: Page):
|
||||
"""Login page shows SAML provider button."""
|
||||
await page.goto(f"{SAML_APP_BASE}/login")
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await expect(page.get_by_text("Sign in with Mock SAML")).to_be_visible(timeout=10_000)
|
||||
|
||||
|
||||
async def test_saml_redirect_to_idp(app_with_saml, page: Page):
|
||||
"""Clicking SAML login redirects to the SimpleSAMLphp IdP login page."""
|
||||
await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
|
||||
# Should redirect to the SimpleSAMLphp SSO service
|
||||
await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
|
||||
|
||||
|
||||
async def test_saml_sp_metadata_endpoint(app_with_saml, page: Page):
|
||||
"""SP metadata endpoint returns valid XML."""
|
||||
response = await page.request.get(f"{SAML_APP_BASE}/auth/saml/test-saml/metadata")
|
||||
assert response.status == 200
|
||||
body = await response.text()
|
||||
assert "EntityDescriptor" in body
|
||||
assert "AssertionConsumerService" in body
|
||||
|
||||
|
||||
async def test_full_saml_login_flow(app_with_saml, page: Page):
|
||||
"""Full SAML SSO flow: app → IdP login → callback → authenticated."""
|
||||
await page.goto(f"{SAML_APP_BASE}/auth/saml/test-saml")
|
||||
await page.wait_for_url(f"**{MOCK_SAML_HOST}:8080/simplesaml/**", timeout=10_000)
|
||||
|
||||
# SimpleSAMLphp login form
|
||||
await page.locator("input[name='username']").fill("user1")
|
||||
await page.locator("input[name='password']").fill("password")
|
||||
await page.locator("button[type='submit'], input[type='submit']").first.click()
|
||||
|
||||
# Should redirect back to the app after SAML response
|
||||
await page.wait_for_url(f"{SAML_APP_BASE}/**", timeout=15_000)
|
||||
await page.wait_for_load_state("networkidle")
|
||||
await page.wait_for_timeout(3000)
|
||||
|
||||
assert "/login" not in page.url, f"SAML login failed — still on login page: {page.url}"
|
||||
|
||||
# Verify user was auto-created in DB
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(User).where(User.email == SAML_TEST_EMAIL))
|
||||
user = result.scalar_one_or_none()
|
||||
assert user is not None, f"Expected user {SAML_TEST_EMAIL} to be auto-created"
|
||||
assert user.last_signed_in_method == "saml:test-saml"
|
||||
263
tests/test_api_deps.py
Normal file
263
tests/test_api_deps.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
"""Tests for API dependency injection — Bearer token auth and admin guard."""
|
||||
|
||||
import hashlib
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# ========== resolve_bearer_token ==========
|
||||
|
||||
|
||||
async def test_resolve_valid_token():
|
||||
"""Valid, non-expired token resolves to user."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-test@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() + timedelta(hours=1),
|
||||
)
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is not None
|
||||
assert resolved.id == user.id
|
||||
assert resolved.email == "api-test@test.com"
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_resolve_expired_token():
|
||||
"""Expired token returns None."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-expired@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() - timedelta(hours=1), # already expired
|
||||
)
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_resolve_invalid_token():
|
||||
"""Nonexistent token returns None."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, "totally-bogus-token")
|
||||
assert resolved is None
|
||||
|
||||
|
||||
async def test_resolve_token_disabled_user():
|
||||
"""Token for disabled user returns None."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(
|
||||
email="api-disabled@test.com", password_hash=hash_password("x"),
|
||||
role="admin", disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() + timedelta(hours=1),
|
||||
)
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_resolve_token_no_expiry():
|
||||
"""Token without expires_at (never expires) resolves successfully."""
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-noexp@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=None,
|
||||
)
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
async with async_session() as session:
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is not None
|
||||
assert resolved.id == user.id
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ========== get_current_api_user (via FastAPI deps) ==========
|
||||
|
||||
|
||||
async def test_get_current_api_user_missing_header():
|
||||
"""Missing Authorization header raises 401."""
|
||||
from fastapi import HTTPException
|
||||
from wiregui.api.deps import get_current_api_user
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_user(request, session=AsyncMock())
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing" in exc_info.value.detail
|
||||
|
||||
|
||||
async def test_get_current_api_user_bad_scheme():
|
||||
"""Non-Bearer auth scheme raises 401."""
|
||||
from fastapi import HTTPException
|
||||
from wiregui.api.deps import get_current_api_user
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Basic dXNlcjpwYXNz"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_user(request, session=AsyncMock())
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
async def test_get_current_api_user_invalid_token():
|
||||
"""Valid Bearer scheme but bogus token raises 401."""
|
||||
from fastapi import HTTPException
|
||||
from wiregui.api.deps import get_current_api_user
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Bearer bogus-token-value"}
|
||||
|
||||
async with async_session() as session:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_user(request, session=session)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid" in exc_info.value.detail
|
||||
|
||||
|
||||
async def test_get_current_api_user_valid_token():
|
||||
"""Valid Bearer token resolves to user."""
|
||||
from wiregui.api.deps import get_current_api_user
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
|
||||
async with async_session() as session:
|
||||
user = User(email="api-dep-test@test.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
api_token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() + timedelta(hours=1),
|
||||
)
|
||||
session.add(api_token)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": f"Bearer {plaintext}"}
|
||||
|
||||
async with async_session() as session:
|
||||
resolved = await get_current_api_user(request, session=session)
|
||||
assert resolved.id == user.id
|
||||
finally:
|
||||
async with async_session() as session:
|
||||
await session.delete(await session.get(ApiToken, api_token.id))
|
||||
await session.delete(await session.get(User, user.id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ========== require_admin ==========
|
||||
|
||||
|
||||
async def test_require_admin_allows_admin():
|
||||
"""Admin user passes require_admin."""
|
||||
from wiregui.api.deps import require_admin
|
||||
|
||||
admin_user = MagicMock(spec=User)
|
||||
admin_user.role = "admin"
|
||||
result = await require_admin(user=admin_user)
|
||||
assert result == admin_user
|
||||
|
||||
|
||||
async def test_require_admin_rejects_unprivileged():
|
||||
"""Non-admin user gets 403."""
|
||||
from fastapi import HTTPException
|
||||
from wiregui.api.deps import require_admin
|
||||
|
||||
regular_user = MagicMock(spec=User)
|
||||
regular_user.role = "unprivileged"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await require_admin(user=regular_user)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin" in exc_info.value.detail
|
||||
206
tests/test_firewall_extended.py
Normal file
206
tests/test_firewall_extended.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""Extended firewall tests — _nft/_nft_batch error handling, add_device_jump_rule edge cases, policies."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from wiregui.services.firewall import (
|
||||
_nft,
|
||||
_nft_batch,
|
||||
add_device_jump_rule,
|
||||
setup_base_tables,
|
||||
setup_masquerade,
|
||||
apply_peer_to_peer_policy,
|
||||
apply_lan_to_peers_policy,
|
||||
get_ruleset,
|
||||
)
|
||||
|
||||
|
||||
# ========== _nft error handling ==========
|
||||
|
||||
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_nft_raises_on_failure(mock_exec):
|
||||
"""_nft raises RuntimeError on non-zero exit code."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"", b"nft: error message")
|
||||
mock_proc.returncode = 1
|
||||
mock_exec.return_value = mock_proc
|
||||
|
||||
with pytest.raises(RuntimeError, match="nft.*failed"):
|
||||
await _nft("list ruleset")
|
||||
|
||||
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_nft_returns_stdout_on_success(mock_exec):
|
||||
"""_nft returns stdout on success."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"table inet wiregui {}", b"")
|
||||
mock_proc.returncode = 0
|
||||
mock_exec.return_value = mock_proc
|
||||
|
||||
result = await _nft("list ruleset")
|
||||
assert "wiregui" in result
|
||||
|
||||
|
||||
# ========== _nft_batch error handling ==========
|
||||
|
||||
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_nft_batch_raises_on_failure(mock_exec):
|
||||
"""_nft_batch raises RuntimeError on non-zero exit code."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"", b"Error: syntax error")
|
||||
mock_proc.returncode = 1
|
||||
mock_exec.return_value = mock_proc
|
||||
|
||||
with pytest.raises(RuntimeError, match="nft batch failed"):
|
||||
await _nft_batch(["add table inet wiregui"])
|
||||
|
||||
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_nft_batch_sends_commands_via_stdin(mock_exec):
|
||||
"""_nft_batch sends all commands via stdin to nft -f -."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"", b"")
|
||||
mock_proc.returncode = 0
|
||||
mock_exec.return_value = mock_proc
|
||||
|
||||
cmds = ["add table inet wiregui", "add chain inet wiregui test"]
|
||||
await _nft_batch(cmds)
|
||||
|
||||
mock_exec.assert_awaited_once()
|
||||
# Verify nft -f - was called
|
||||
call_args = mock_exec.call_args[0]
|
||||
assert call_args == ("nft", "-f", "-")
|
||||
# Verify stdin data
|
||||
stdin_data = mock_proc.communicate.call_args[0][0]
|
||||
assert b"add table inet wiregui" in stdin_data
|
||||
assert b"add chain inet wiregui test" in stdin_data
|
||||
|
||||
|
||||
# ========== add_device_jump_rule edge cases ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_device_jump_rule_ipv4_only(mock_batch):
|
||||
"""Only IPv4 — generates single IPv4 jump rule."""
|
||||
await add_device_jump_rule("user-id-1", "10.0.0.5", None)
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert len(cmds) == 1
|
||||
assert "ip saddr 10.0.0.5" in cmds[0]
|
||||
assert "jump" in cmds[0]
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_device_jump_rule_ipv6_only(mock_batch):
|
||||
"""Only IPv6 — generates single IPv6 jump rule."""
|
||||
await add_device_jump_rule("user-id-2", None, "fd00::5")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert len(cmds) == 1
|
||||
assert "ip6 saddr fd00::5" in cmds[0]
|
||||
assert "jump" in cmds[0]
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_device_jump_rule_no_ips(mock_batch):
|
||||
"""Neither IPv4 nor IPv6 — no nft commands issued."""
|
||||
await add_device_jump_rule("user-id-3", None, None)
|
||||
mock_batch.assert_not_awaited()
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_device_jump_rule_both_ips(mock_batch):
|
||||
"""Both IPv4 and IPv6 — generates two jump rules."""
|
||||
await add_device_jump_rule("user-id-4", "10.0.0.7", "fd00::7")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert len(cmds) == 2
|
||||
assert any("ip saddr 10.0.0.7" in c for c in cmds)
|
||||
assert any("ip6 saddr fd00::7" in c for c in cmds)
|
||||
|
||||
|
||||
# ========== setup_base_tables — already exists ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_base_tables_already_exists(mock_batch):
|
||||
"""If table already exists (File exists error), don't raise."""
|
||||
mock_batch.side_effect = RuntimeError("File exists")
|
||||
await setup_base_tables() # should not raise
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_base_tables_other_error_raises(mock_batch):
|
||||
"""Other nft errors should propagate."""
|
||||
mock_batch.side_effect = RuntimeError("Permission denied")
|
||||
with pytest.raises(RuntimeError, match="Permission denied"):
|
||||
await setup_base_tables()
|
||||
|
||||
|
||||
# ========== setup_masquerade — error handling ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_masquerade_error_swallowed(mock_batch):
|
||||
"""Masquerade errors are logged but not raised."""
|
||||
mock_batch.side_effect = RuntimeError("nft error")
|
||||
await setup_masquerade(iface="wg0") # should not raise
|
||||
|
||||
|
||||
# ========== policy functions — command verification ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_peer_to_peer_enabled(mock_batch):
|
||||
"""Enabling peer-to-peer generates accept rules."""
|
||||
await apply_peer_to_peer_policy(True)
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("accept" in c for c in cmds)
|
||||
assert any("peer_to_peer" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_peer_to_peer_disabled(mock_batch):
|
||||
"""Disabling peer-to-peer generates drop rules."""
|
||||
await apply_peer_to_peer_policy(False)
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("drop" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_lan_to_peers_enabled(mock_batch):
|
||||
"""Enabling LAN-to-peers generates accept rules."""
|
||||
await apply_lan_to_peers_policy(True)
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("accept" in c for c in cmds)
|
||||
assert any("lan_to_peers" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_lan_to_peers_disabled(mock_batch):
|
||||
"""Disabling LAN-to-peers generates drop rules."""
|
||||
await apply_lan_to_peers_policy(False)
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("drop" in c for c in cmds)
|
||||
|
||||
|
||||
# ========== get_ruleset — error handling ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
|
||||
async def test_get_ruleset_returns_output(mock_nft):
|
||||
"""get_ruleset returns nft list ruleset output."""
|
||||
mock_nft.return_value = "table inet wiregui { ... }"
|
||||
result = await get_ruleset()
|
||||
assert "wiregui" in result
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
|
||||
async def test_get_ruleset_returns_fallback_on_error(mock_nft):
|
||||
"""get_ruleset returns friendly message when nft not available."""
|
||||
mock_nft.side_effect = RuntimeError("nft not found")
|
||||
result = await get_ruleset()
|
||||
assert "not available" in result
|
||||
114
tests/test_wireguard_extended.py
Normal file
114
tests/test_wireguard_extended.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""Tests for WireGuard service — ensure_interface, set_private_key, set_listen_port, configure_interface."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch, call
|
||||
|
||||
from wiregui.services.wireguard import (
|
||||
ensure_interface,
|
||||
set_private_key,
|
||||
set_listen_port,
|
||||
configure_interface,
|
||||
)
|
||||
|
||||
|
||||
# ========== ensure_interface ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_ensure_interface_already_exists(mock_run):
|
||||
"""If interface exists (ip link show succeeds), do nothing."""
|
||||
mock_run.return_value = ""
|
||||
await ensure_interface(iface="wg-test")
|
||||
# Only called once for ip link show
|
||||
mock_run.assert_awaited_once_with(["ip", "link", "show", "wg-test"])
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_ensure_interface_creates_new(mock_run):
|
||||
"""If interface doesn't exist, create it, assign IPs, bring up."""
|
||||
call_count = 0
|
||||
|
||||
async def side_effect(args, input_data=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1 and args == ["ip", "link", "show", "wg-test"]:
|
||||
raise RuntimeError("Device not found")
|
||||
return ""
|
||||
|
||||
mock_run.side_effect = side_effect
|
||||
await ensure_interface(iface="wg-test")
|
||||
|
||||
# Should have called: ip link show (fails), ip link add, ip addr add x2, ip link set up
|
||||
assert mock_run.await_count == 5
|
||||
calls = [c[0][0] for c in mock_run.call_args_list]
|
||||
assert calls[1] == ["ip", "link", "add", "wg-test", "type", "wireguard"]
|
||||
assert calls[2][0:3] == ["ip", "address", "add"]
|
||||
assert calls[3][0:3] == ["ip", "address", "add"]
|
||||
assert calls[4] == ["ip", "link", "set", "wg-test", "up"]
|
||||
|
||||
|
||||
# ========== set_private_key ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_set_private_key(mock_run):
|
||||
"""set_private_key calls wg set with private-key path."""
|
||||
mock_run.return_value = ""
|
||||
await set_private_key("/tmp/test.key", iface="wg-test")
|
||||
mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "private-key", "/tmp/test.key"])
|
||||
|
||||
|
||||
# ========== set_listen_port ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_set_listen_port(mock_run):
|
||||
"""set_listen_port calls wg set with listen-port."""
|
||||
mock_run.return_value = ""
|
||||
await set_listen_port(51820, iface="wg-test")
|
||||
mock_run.assert_awaited_once_with(["wg", "set", "wg-test", "listen-port", "51820"])
|
||||
|
||||
|
||||
# ========== configure_interface ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
@patch("wiregui.db.async_session")
|
||||
async def test_configure_interface_no_config(mock_session_cls, mock_run):
|
||||
"""If no Configuration row exists, do not call wg set."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
await configure_interface(iface="wg-test")
|
||||
mock_run.assert_not_awaited()
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
@patch("wiregui.db.async_session")
|
||||
async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run):
|
||||
"""With valid config, writes key to temp file and calls wg set."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.server_private_key = "test-private-key-value"
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_config
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_run.return_value = ""
|
||||
await configure_interface(iface="wg-test")
|
||||
|
||||
mock_run.assert_awaited_once()
|
||||
args = mock_run.call_args[0][0]
|
||||
assert args[0:3] == ["wg", "set", "wg-test"]
|
||||
assert "private-key" in args
|
||||
assert "listen-port" in args
|
||||
Loading…
Add table
Add a link
Reference in a new issue