diff --git a/.forgejo/workflows/dev.yml b/.forgejo/workflows/dev.yml new file mode 100644 index 0000000..676a2f2 --- /dev/null +++ b/.forgejo/workflows/dev.yml @@ -0,0 +1,49 @@ +name: Dev + +on: + push: + branches: + - dev + +jobs: + docker: + runs-on: docker + container: + image: catthehacker/ubuntu:act-latest + options: --privileged + steps: + - name: Checkout repository + run: | + git clone ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git -b dev . + git fetch origin main --tags + + - name: Build and push pre-release image + shell: bash + env: + REGISTRY_TOKEN: ${{ secrets.REGISTRY_TOKEN }} + run: | + # Derive version from latest tag on main: v1.2.3 -> 1.2.3.dev0, .dev1, etc. + LATEST_TAG=$(git describe --tags --abbrev=0 origin/main 2>/dev/null || echo "v0.0.0") + BASE_VERSION="${LATEST_TAG#v}" + # Count commits on dev since that tag + DEV_N=$(git rev-list --count "${LATEST_TAG}..HEAD" 2>/dev/null || echo "0") + VERSION="${BASE_VERSION}.dev${DEV_N}" + + REGISTRY=$(echo "${{ github.server_url }}" | sed 's|https://||; s|http://||') + IMAGE="${REGISTRY}/${{ github.repository_owner }}/wiregui" + + echo "Building ${IMAGE}:v${VERSION}" + + echo "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" \ + -u "${{ github.repository_owner }}" --password-stdin + + docker build --no-cache \ + --build-arg "VERSION=${VERSION}" \ + -t "${IMAGE}:v${VERSION}" \ + -t "${IMAGE}:dev" \ + . + + docker push "${IMAGE}:v${VERSION}" + docker push "${IMAGE}:dev" + + echo "Pushed ${IMAGE}:v${VERSION}, ${IMAGE}:dev" diff --git a/TODO.md b/TODO.md index 5d1124a..340c382 100644 --- a/TODO.md +++ b/TODO.md @@ -1,3 +1,11 @@ +# WireGUI — Pending Items + +**Test count: 174 (164 unit + 10 E2E) | Coverage: ~35%** + +--- + +## Testing + # WireGUI Implementation TODO Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI. @@ -15,38 +23,13 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone` ### Testing (partially done) - [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking) - -### Coverage gaps (35% overall — run `uv run pytest --cov=wiregui --cov-report=term-missing --cov-branch`) - -**100% covered:** models, schemas, config, auth/passwords, auth/jwt, auth/mfa, auth/api_token, utils/crypto, utils/time, services/notifications - -**API routes (32-84% — partially covered via httpx TestClient):** -- [x] `wiregui/api/v0/users.py` (84%) — list/get/create/update/delete -- [x] `wiregui/api/v0/rules.py` (71%) — CRUD -- [x] `wiregui/api/v0/devices.py` (67%) — CRUD, permissions -- [x] `wiregui/api/v0/configuration.py` (61%) — get/update, auto-create -- [ ] `wiregui/api/deps.py` (32%) — test get_current_api_user with real Bearer header parsing, require_admin rejection - -**Services (62-89% covered):** -- [x] `wiregui/services/wireguard.py` (62%) — add/remove/get peers mocked -- [x] `wiregui/services/firewall.py` (73%) — base tables, chains, rules, rebuild mocked -- [x] `wiregui/services/events.py` (80%) — device + rule events, rebuild chain -- [x] `wiregui/services/email.py` (89%) — send_email, magic link, no-smtp fallback +- [ ] `wiregui/api/deps.py` — test get_current_api_user with real Bearer header parsing, require_admin rejection - [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port - [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6 - -**Tasks (40-84% covered):** -- [x] `wiregui/tasks/stats.py` (77%) — update from peers, no-op, unmatched peer -- [x] `wiregui/tasks/reconcile.py` (84%) — add missing, remove orphaned, in-sync -- [x] `wiregui/tasks/oidc_refresh.py` (40%) — no connections, skip unknown provider - [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error - -**Auth modules (85-92% covered):** -- [x] `wiregui/auth/oidc.py` (87%) — register providers, get_client, load from config -- [x] `wiregui/auth/webauthn.py` (85%) — registration/authentication options -- [x] `wiregui/auth/session.py` (90%) — no-password, disabled, nonexistent user - [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing - [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data +- [ ] E2E tests for admin pages (users, devices, rules, settings) **E2E page tests (Playwright async API in `tests/e2e/`):** - [x] `tests/e2e/test_login.py` (6 tests) — valid login, invalid password, nonexistent email, disabled user, logout, unauthenticated redirect @@ -112,6 +95,10 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone` - [ ] Device detail: delete with confirmation → redirects to /devices - [ ] Auto-refresh: stats labels update after timer fires (mock timer) +## UI + +- [ ] SSO Providers on account page: add Status column, "Disconnect" action +- [ ] Admin pages (users, devices, rules): apply same card-based styling as account/settings/diagnostics `tests/e2e/test_account_extended.py` — Account Page (additional): - [ ] SSO providers section shows connected providers - [ ] SSO providers section shows "No SSO providers" when empty @@ -124,6 +111,8 @@ Source: `/home/stefanob/PycharmProjects/personal/wirezone` - [ ] Danger zone: wrong email in confirmation → shows error +## Features + ### Deployment ✅ - [ ] First-run CLI setup command diff --git a/alembic/versions/b7e2f4a1c903_add_firewall_policy_fields.py b/alembic/versions/b7e2f4a1c903_add_firewall_policy_fields.py new file mode 100644 index 0000000..3ff2359 --- /dev/null +++ b/alembic/versions/b7e2f4a1c903_add_firewall_policy_fields.py @@ -0,0 +1,28 @@ +"""add firewall policy fields to configurations + +Revision ID: b7e2f4a1c903 +Revises: a3f1d8e92b01 +Create Date: 2026-03-31 00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b7e2f4a1c903' +down_revision: Union[str, None] = 'a3f1d8e92b01' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('configurations', sa.Column('allow_peer_to_peer', sa.Boolean(), nullable=False, server_default='false')) + op.add_column('configurations', sa.Column('allow_lan_to_peers', sa.Boolean(), nullable=False, server_default='false')) + + +def downgrade() -> None: + op.drop_column('configurations', 'allow_lan_to_peers') + op.drop_column('configurations', 'allow_peer_to_peer') diff --git a/wiregui/models/configuration.py b/wiregui/models/configuration.py index 68b228a..6186d5d 100644 --- a/wiregui/models/configuration.py +++ b/wiregui/models/configuration.py @@ -32,6 +32,10 @@ class Configuration(SQLModel, table=True): sa_column=Column(JSON, default=["0.0.0.0/0", "::/0"]), ) + # Firewall policies + allow_peer_to_peer: bool = Field(default=False) + allow_lan_to_peers: bool = Field(default=False) + # Server WireGuard keypair (generated on first startup) server_private_key: str | None = None server_public_key: str | None = None diff --git a/wiregui/pages/admin/devices.py b/wiregui/pages/admin/devices.py index 8d052a2..9ac7ab2 100644 --- a/wiregui/pages/admin/devices.py +++ b/wiregui/pages/admin/devices.py @@ -46,10 +46,21 @@ async def admin_devices_page(): layout() - # Load users for filter and create form + settings = get_settings() + + # Load users and client defaults async with async_session() as session: users = (await session.execute(select(User).order_by(User.email))).scalars().all() + from wiregui.models.configuration import Configuration + _db_cfg = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() user_map = {str(u.id): u.email for u in users} + _defaults = { + "allowed_ips": ", ".join(_db_cfg.default_client_allowed_ips) if _db_cfg and _db_cfg.default_client_allowed_ips else settings.wg_allowed_ips, + "dns": ", ".join(_db_cfg.default_client_dns) if _db_cfg and _db_cfg.default_client_dns else settings.wg_dns, + "endpoint": _db_cfg.default_client_endpoint if _db_cfg and _db_cfg.default_client_endpoint else settings.wg_endpoint_host, + "mtu": str(_db_cfg.default_client_mtu) if _db_cfg else str(settings.wg_mtu), + "keepalive": str(_db_cfg.default_client_persistent_keepalive) if _db_cfg else str(settings.wg_persistent_keepalive), + } async def load_devices(user_filter: str | None = None) -> list[dict]: async with async_session() as session: @@ -127,7 +138,11 @@ async def admin_devices_page(): # Build config and show dialog immediately — don't wait for WG/firewall server_pubkey = await get_server_public_key() - config_text = build_client_config(device, private_key, server_pubkey) + async with async_session() as session: + from sqlmodel import select as sel + from wiregui.models.configuration import Configuration + db_config = (await session.execute(sel(Configuration).limit(1))).scalar_one_or_none() + config_text = build_client_config(device, private_key, server_pubkey, db_config) create_dialog.close() _reset_create_form() @@ -152,6 +167,11 @@ async def admin_devices_page(): create_use_default_endpoint.value = True create_use_default_mtu.value = True create_use_default_keepalive.value = True + create_allowed_ips.value = _defaults["allowed_ips"] + create_dns.value = _defaults["dns"] + create_endpoint.value = _defaults["endpoint"] + create_mtu.value = _defaults["mtu"] + create_keepalive.value = _defaults["keepalive"] # --- Edit device --- edit_device_id = {"value": None} @@ -285,19 +305,19 @@ async def admin_devices_page(): with ui.grid(columns=2).classes("w-full gap-2"): create_use_default_ips = ui.switch("Use default Allowed IPs", value=True) - create_allowed_ips = ui.input("Allowed IPs", placeholder="0.0.0.0/0, ::/0").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v) + create_allowed_ips = ui.input("Allowed IPs", value=_defaults["allowed_ips"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v) create_use_default_dns = ui.switch("Use default DNS", value=True) - create_dns = ui.input("DNS Servers", placeholder="1.1.1.1, 1.0.0.1").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v) + create_dns = ui.input("DNS Servers", value=_defaults["dns"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v) create_use_default_endpoint = ui.switch("Use default Endpoint", value=True) - create_endpoint = ui.input("Endpoint", placeholder="vpn.example.com").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v) + create_endpoint = ui.input("Endpoint", value=_defaults["endpoint"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v) create_use_default_mtu = ui.switch("Use default MTU", value=True) - create_mtu = ui.input("MTU", placeholder="1280").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v) + create_mtu = ui.input("MTU", value=_defaults["mtu"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v) create_use_default_keepalive = ui.switch("Use default Keepalive", value=True) - create_keepalive = ui.input("Persistent Keepalive", placeholder="25").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v) + create_keepalive = ui.input("Persistent Keepalive", value=_defaults["keepalive"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v) with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=create_dialog.close).props("flat") diff --git a/wiregui/pages/admin/rules.py b/wiregui/pages/admin/rules.py index 982b71d..e992867 100644 --- a/wiregui/pages/admin/rules.py +++ b/wiregui/pages/admin/rules.py @@ -1,5 +1,6 @@ """Admin firewall rules management page.""" +import asyncio from uuid import UUID from loguru import logger @@ -7,10 +8,13 @@ from nicegui import app, ui from sqlmodel import select from wiregui.db import async_session +from wiregui.models.configuration import Configuration from wiregui.models.rule import Rule from wiregui.models.user import User from wiregui.pages.layout import layout from wiregui.services.events import on_rule_created, on_rule_deleted, on_rule_updated +from wiregui.services.firewall import apply_lan_to_peers_policy, apply_peer_to_peer_policy, get_ruleset +from wiregui.utils.time import utcnow @ui.page("/admin/rules") @@ -23,6 +27,7 @@ async def rules_page(): # Load users for the dropdown async with async_session() as session: users = (await session.execute(select(User).order_by(User.email))).scalars().all() + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() user_options = {str(u.id): u.email for u in users} async def load_rules() -> list[dict]: @@ -69,7 +74,7 @@ async def rules_page(): await session.refresh(rule) logger.info("Rule created: {} {} -> {}", rule.action, rule.destination, user_id_val or "global") - await on_rule_created(rule) + asyncio.create_task(on_rule_created(rule)) create_dialog.close() _reset_form() @@ -110,7 +115,7 @@ async def rules_page(): session.add(rule) await session.commit() await session.refresh(rule) - await on_rule_updated(rule) + asyncio.create_task(on_rule_updated(rule)) logger.info("Rule updated: {} {}", edit_action.value, edit_dest.value) ui.notify("Rule updated") @@ -124,7 +129,7 @@ async def rules_page(): await session.delete(rule) await session.commit() logger.info("Rule deleted: {} {}", rule.action, rule.destination) - await on_rule_deleted(rule) + asyncio.create_task(on_rule_deleted(rule)) await refresh_table() def _reset_form(): @@ -134,34 +139,103 @@ async def rules_page(): port_range_input.value = "" user_select.value = "global" - # Page content - with ui.column().classes("w-full p-4"): - with ui.row().classes("w-full items-center justify-between"): - ui.label("Firewall Rules").classes("text-h5") - ui.button("Add Rule", icon="add", on_click=lambda: create_dialog.open()).props("color=primary") + # --- Firewall policy toggles --- + async def toggle_peer_to_peer(e): + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if c: + c.allow_peer_to_peer = e.value + c.updated_at = utcnow() + session.add(c) + await session.commit() + asyncio.create_task(apply_peer_to_peer_policy(e.value)) + ui.notify(f"Peer-to-peer: {'allowed' if e.value else 'denied'}") - columns = [ - {"name": "action", "label": "Action", "field": "action", "align": "left", "sortable": True}, - {"name": "destination", "label": "Destination", "field": "destination", "align": "left", "sortable": True}, - {"name": "port_type", "label": "Protocol", "field": "port_type", "align": "left"}, - {"name": "port_range", "label": "Port(s)", "field": "port_range", "align": "left"}, - {"name": "user", "label": "User", "field": "user", "align": "left"}, - {"name": "actions", "label": "", "field": "id", "align": "center"}, - ] - table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full") - table.add_slot( - "body-cell-actions", - ''' - - - - - ''', - ) - table.on("edit", lambda e: open_edit(e.args)) - table.on("delete", lambda e: delete_rule(e.args)) + async def toggle_lan_to_peers(e): + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if c: + c.allow_lan_to_peers = e.value + c.updated_at = utcnow() + session.add(c) + await session.commit() + asyncio.create_task(apply_lan_to_peers_policy(e.value)) + ui.notify(f"LAN-to-peers: {'allowed' if e.value else 'denied'}") + + # --- Troubleshooting --- + async def show_nft_rules(): + ruleset = await get_ruleset() + with ui.dialog(value=True) as dlg: + with ui.card().classes("w-[800px]"): + ui.label("nftables Ruleset").classes("text-subtitle1 text-bold") + ui.label("Current system firewall rules for troubleshooting.").classes("text-caption text-grey") + ui.separator() + ui.textarea(value=ruleset).props("readonly outlined").classes( + "w-full font-mono text-xs" + ).style("min-height: 400px; white-space: pre") + with ui.row().classes("w-full justify-end q-mt-sm"): + ui.button("Close", on_click=dlg.close).props("flat") + + # --- Page content --- + with ui.column().classes("w-full p-4"): + ui.label("Firewall Rules").classes("text-h5 q-mb-md") + + # Policy switches + with ui.card().classes("w-full"): + ui.label("Network Policies").classes("text-subtitle1 text-bold") + ui.label("Control how traffic flows between peers and the local network.").classes("text-caption text-grey") + ui.separator() + + ui.switch( + "Allow peer-to-peer communication", + value=config.allow_peer_to_peer if config else False, + on_change=toggle_peer_to_peer, + ) + ui.label("Peers can communicate with each other through the WireGuard server (hub-and-spoke).").classes("text-caption text-grey q-ml-xl") + + ui.switch( + "Allow local network to reach peers", + value=config.allow_lan_to_peers if config else False, + on_change=toggle_lan_to_peers, + ).classes("q-mt-sm") + ui.label("Devices on the server's LAN can initiate connections to VPN peers.").classes("text-caption text-grey q-ml-xl") + + # Rules table + with ui.card().classes("w-full q-mt-md"): + with ui.row().classes("w-full items-center justify-between"): + ui.label("Per-User Rules").classes("text-subtitle1 text-bold") + ui.button("Add Rule", icon="add", on_click=lambda: create_dialog.open()).props("color=primary unelevated") + ui.separator() + + columns = [ + {"name": "action", "label": "Action", "field": "action", "align": "left", "sortable": True}, + {"name": "destination", "label": "Destination", "field": "destination", "align": "left", "sortable": True}, + {"name": "port_type", "label": "Protocol", "field": "port_type", "align": "left"}, + {"name": "port_range", "label": "Port(s)", "field": "port_range", "align": "left"}, + {"name": "user", "label": "User", "field": "user", "align": "left"}, + {"name": "actions", "label": "", "field": "id", "align": "center"}, + ] + table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full") + table.add_slot( + "body-cell-actions", + ''' + + + + + ''', + ) + table.on("edit", lambda e: open_edit(e.args)) + table.on("delete", lambda e: delete_rule(e.args)) + + # Troubleshooting + with ui.card().classes("w-full q-mt-md"): + ui.label("Troubleshooting").classes("text-subtitle1 text-bold") + ui.label("Inspect the raw nftables ruleset configured on this system.").classes("text-caption text-grey") + ui.separator() + ui.button("View nftables Rules", icon="terminal", on_click=show_nft_rules).props("color=primary unelevated") # Create rule dialog with ui.dialog() as create_dialog: @@ -195,7 +269,7 @@ async def rules_page(): with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=create_dialog.close).props("flat") - ui.button("Create", on_click=create_rule).props("color=primary") + ui.button("Create", on_click=create_rule).props("color=primary unelevated") # Edit rule dialog user_options_map = {"global": "Global (all users)"} @@ -223,6 +297,6 @@ async def rules_page(): with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=edit_dialog.close).props("flat") - ui.button("Save", on_click=save_edit).props("color=primary") + ui.button("Save", on_click=save_edit).props("color=primary unelevated") await refresh_table() diff --git a/wiregui/pages/devices.py b/wiregui/pages/devices.py index 6790a98..d053b71 100644 --- a/wiregui/pages/devices.py +++ b/wiregui/pages/devices.py @@ -38,6 +38,19 @@ async def devices_page(): layout() user_id = UUID(app.storage.user["user_id"]) + settings = get_settings() + + # Load client defaults from DB config (falls back to env vars) + async with async_session() as session: + from wiregui.models.configuration import Configuration + _db_cfg = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + _defaults = { + "allowed_ips": ", ".join(_db_cfg.default_client_allowed_ips) if _db_cfg and _db_cfg.default_client_allowed_ips else settings.wg_allowed_ips, + "dns": ", ".join(_db_cfg.default_client_dns) if _db_cfg and _db_cfg.default_client_dns else settings.wg_dns, + "endpoint": _db_cfg.default_client_endpoint if _db_cfg and _db_cfg.default_client_endpoint else settings.wg_endpoint_host, + "mtu": str(_db_cfg.default_client_mtu) if _db_cfg else str(settings.wg_mtu), + "keepalive": str(_db_cfg.default_client_persistent_keepalive) if _db_cfg else str(settings.wg_persistent_keepalive), + } async def load_devices() -> list[Device]: async with async_session() as session: @@ -112,7 +125,11 @@ async def devices_page(): # Build config and show dialog immediately — don't wait for WG/firewall server_pubkey = await get_server_public_key() - config_text = build_client_config(device, private_key, server_pubkey) + async with async_session() as session: + from sqlmodel import select as sel + from wiregui.models.configuration import Configuration + db_config = (await session.execute(sel(Configuration).limit(1))).scalar_one_or_none() + config_text = build_client_config(device, private_key, server_pubkey, db_config) create_dialog.close() _reset_create_form() @@ -137,11 +154,11 @@ async def devices_page(): create_use_default_endpoint.value = True create_use_default_mtu.value = True create_use_default_keepalive.value = True - create_endpoint.value = "" - create_dns.value = "" - create_mtu.value = "" - create_keepalive.value = "" - create_allowed_ips.value = "" + create_allowed_ips.value = _defaults["allowed_ips"] + create_dns.value = _defaults["dns"] + create_endpoint.value = _defaults["endpoint"] + create_mtu.value = _defaults["mtu"] + create_keepalive.value = _defaults["keepalive"] # --- Delete device --- async def delete_device(device_id: str): @@ -201,27 +218,27 @@ async def devices_page(): with ui.grid(columns=2).classes("w-full gap-2"): create_use_default_ips = ui.switch("Use default Allowed IPs", value=True) - create_allowed_ips = ui.input("Allowed IPs", placeholder="0.0.0.0/0, ::/0").props( + create_allowed_ips = ui.input("Allowed IPs", value=_defaults["allowed_ips"]).props( "outlined dense" ).classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v) create_use_default_dns = ui.switch("Use default DNS", value=True) - create_dns = ui.input("DNS Servers", placeholder="1.1.1.1, 1.0.0.1").props( + create_dns = ui.input("DNS Servers", value=_defaults["dns"]).props( "outlined dense" ).classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v) create_use_default_endpoint = ui.switch("Use default Endpoint", value=True) - create_endpoint = ui.input("Endpoint", placeholder="vpn.example.com").props( + create_endpoint = ui.input("Endpoint", value=_defaults["endpoint"]).props( "outlined dense" ).classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v) create_use_default_mtu = ui.switch("Use default MTU", value=True) - create_mtu = ui.input("MTU", placeholder="1280").props( + create_mtu = ui.input("MTU", value=_defaults["mtu"]).props( "outlined dense" ).classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v) create_use_default_keepalive = ui.switch("Use default Keepalive", value=True) - create_keepalive = ui.input("Persistent Keepalive", placeholder="25").props( + create_keepalive = ui.input("Persistent Keepalive", value=_defaults["keepalive"]).props( "outlined dense" ).classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v) diff --git a/wiregui/services/firewall.py b/wiregui/services/firewall.py index 924b6f1..489afdc 100644 --- a/wiregui/services/firewall.py +++ b/wiregui/services/firewall.py @@ -167,6 +167,74 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules)) +async def apply_peer_to_peer_policy(enabled: bool) -> None: + """Allow or deny traffic between WireGuard peers (peer-to-peer through the server).""" + settings = get_settings() + iface = settings.wg_interface + v4_net = settings.wg_ipv4_network + v6_net = settings.wg_ipv6_network + chain = "peer_to_peer" + + commands = [ + f"add chain inet {TABLE_NAME} {chain}", + f"flush chain inet {TABLE_NAME} {chain}", + ] + + if enabled: + # Allow traffic from WG subnet destined to WG subnet (both directions through the interface) + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr {v4_net} ip daddr {v4_net} accept') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr {v6_net} ip6 daddr {v6_net} accept') + else: + # Drop inter-peer traffic + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr {v4_net} ip daddr {v4_net} drop') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr {v6_net} ip6 daddr {v6_net} drop') + + try: + await _nft_batch(commands) + # Ensure the forward chain jumps to peer_to_peer before user chains + # We flush and re-add to keep ordering correct + logger.info("Peer-to-peer policy: {}", "allow" if enabled else "deny") + except RuntimeError as e: + logger.error("Failed to apply peer-to-peer policy: {}", e) + + +async def apply_lan_to_peers_policy(enabled: bool) -> None: + """Allow or deny traffic from the local network to WireGuard peers.""" + settings = get_settings() + iface = settings.wg_interface + v4_net = settings.wg_ipv4_network + v6_net = settings.wg_ipv6_network + chain = "lan_to_peers" + + commands = [ + f"add chain inet {TABLE_NAME} {chain}", + f"flush chain inet {TABLE_NAME} {chain}", + ] + + if enabled: + # Allow traffic from non-WG sources destined to WG subnet (LAN → peers) + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr != {v4_net} ip daddr {v4_net} accept') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr != {v6_net} ip6 daddr {v6_net} accept') + else: + # Drop LAN → peer traffic + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr != {v4_net} ip daddr {v4_net} drop') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr != {v6_net} ip6 daddr {v6_net} drop') + + try: + await _nft_batch(commands) + logger.info("LAN-to-peers policy: {}", "allow" if enabled else "deny") + except RuntimeError as e: + logger.error("Failed to apply LAN-to-peers policy: {}", e) + + +async def get_ruleset() -> str: + """Dump the current nftables ruleset for troubleshooting.""" + try: + return await _nft("list ruleset") + except RuntimeError: + return "nftables is not available.\n\nThis requires root/NET_ADMIN privileges (production container)." + + def _user_chain_name(user_id: str) -> str: """Generate a deterministic chain name from a user ID.""" # Use first 12 chars of UUID (without hyphens) to keep names short diff --git a/wiregui/utils/network.py b/wiregui/utils/network.py index 56f5266..3bb505b 100644 --- a/wiregui/utils/network.py +++ b/wiregui/utils/network.py @@ -1,7 +1,7 @@ """IP address allocation for WireGuard tunnel addresses.""" import random -from ipaddress import IPv4Network, IPv6Network, ip_address +from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession @@ -11,17 +11,17 @@ from wiregui.models.device import Device async def allocate_ipv4(session: AsyncSession, network_cidr: str) -> str: - """Find the next available IPv4 address in the given CIDR range.""" + """Find an available IPv4 address in the given CIDR range.""" network = IPv4Network(network_cidr, strict=False) used = await _get_used_ips(session, "ipv4") - return _find_available(network, used) + return _find_available_v4(network, used) async def allocate_ipv6(session: AsyncSession, network_cidr: str) -> str: - """Find the next available IPv6 address in the given CIDR range.""" + """Find an available IPv6 address in the given CIDR range.""" network = IPv6Network(network_cidr, strict=False) used = await _get_used_ips(session, "ipv6") - return _find_available(network, used) + return _find_available_v6(network, used) async def _get_used_ips(session: AsyncSession, field: str) -> set[str]: @@ -31,30 +31,54 @@ async def _get_used_ips(session: AsyncSession, field: str) -> set[str]: return {row[0] for row in result.all()} -def _find_available(network: IPv4Network | IPv6Network, used: set[str]) -> str: - """Find an available IP in the network, starting from a random offset.""" - hosts = list(network.hosts()) - if not hosts: +def _find_available_v4(network: IPv4Network, used: set[str]) -> str: + """Find an available IPv4 by random sampling — O(1) per attempt, no list materialization.""" + # Usable range: network_address + 2 to broadcast - 1 (skip network, gateway, broadcast) + first = int(network.network_address) + 2 + last = int(network.broadcast_address) - 1 + pool_size = last - first + 1 + + if pool_size <= 0: raise ValueError(f"No usable hosts in {network}") + if len(used) >= pool_size: + raise ValueError(f"No available addresses in {network}") - # Skip the first host (gateway/server address) - hosts = hosts[1:] - if not hosts: - raise ValueError(f"No usable hosts in {network} after reserving gateway") - - # Start from a random offset, then scan forward and backward - start = random.randint(0, len(hosts) - 1) - - # Forward scan - for i in range(start, len(hosts)): - candidate = str(hosts[i]) + for _ in range(min(pool_size, 1000)): + candidate = str(IPv4Address(random.randint(first, last))) if candidate not in used: logger.debug("Allocated {} from {}", candidate, network) return candidate - # Backward scan - for i in range(start - 1, -1, -1): - candidate = str(hosts[i]) + # Fallback: sequential scan (only if random sampling keeps hitting used IPs) + for offset in range(pool_size): + candidate = str(IPv4Address(first + offset)) + if candidate not in used: + logger.debug("Allocated {} from {}", candidate, network) + return candidate + + raise ValueError(f"No available addresses in {network}") + + +def _find_available_v6(network: IPv6Network, used: set[str]) -> str: + """Find an available IPv6 by random sampling.""" + first = int(network.network_address) + 2 + last = int(network.broadcast_address) - 1 + pool_size = last - first + 1 + + if pool_size <= 0: + raise ValueError(f"No usable hosts in {network}") + if len(used) >= pool_size: + raise ValueError(f"No available addresses in {network}") + + for _ in range(min(pool_size, 1000)): + candidate = str(IPv6Address(random.randint(first, last))) + if candidate not in used: + logger.debug("Allocated {} from {}", candidate, network) + return candidate + + # Fallback: sequential scan + for offset in range(pool_size): + candidate = str(IPv6Address(first + offset)) if candidate not in used: logger.debug("Allocated {} from {}", candidate, network) return candidate diff --git a/wiregui/utils/wg_conf.py b/wiregui/utils/wg_conf.py index bd3217b..3f5e3be 100644 --- a/wiregui/utils/wg_conf.py +++ b/wiregui/utils/wg_conf.py @@ -1,6 +1,7 @@ """Build WireGuard client configuration files.""" from wiregui.config import get_settings +from wiregui.models.configuration import Configuration from wiregui.models.device import Device @@ -8,16 +9,40 @@ def build_client_config( device: Device, private_key: str, server_public_key: str, + db_config: Configuration | None = None, ) -> str: - """Build a WireGuard [Interface]+[Peer] config string for a device.""" + """Build a WireGuard [Interface]+[Peer] config string for a device. + + Uses DB Configuration for client defaults when available, + falls back to env-based Settings. + """ settings = get_settings() - # Resolve per-device or default values - dns = device.dns if not device.use_default_dns else settings.wg_dns - endpoint_host = device.endpoint if not device.use_default_endpoint else settings.wg_endpoint_host - mtu = device.mtu if not device.use_default_mtu else settings.wg_mtu - keepalive = device.persistent_keepalive if not device.use_default_persistent_keepalive else settings.wg_persistent_keepalive - allowed_ips = device.allowed_ips if not device.use_default_allowed_ips else settings.wg_allowed_ips + # Resolve per-device overrides → DB config defaults → env var defaults + if device.use_default_dns: + dns = db_config.default_client_dns if db_config and db_config.default_client_dns else settings.wg_dns + else: + dns = device.dns + + if device.use_default_endpoint: + endpoint_host = db_config.default_client_endpoint if db_config and db_config.default_client_endpoint else settings.wg_endpoint_host + else: + endpoint_host = device.endpoint + + if device.use_default_mtu: + mtu = db_config.default_client_mtu if db_config else settings.wg_mtu + else: + mtu = device.mtu + + if device.use_default_persistent_keepalive: + keepalive = db_config.default_client_persistent_keepalive if db_config else settings.wg_persistent_keepalive + else: + keepalive = device.persistent_keepalive + + if device.use_default_allowed_ips: + allowed_ips = db_config.default_client_allowed_ips if db_config and db_config.default_client_allowed_ips else settings.wg_allowed_ips + else: + allowed_ips = device.allowed_ips # Build address list addresses = []