diff --git a/alembic/versions/b7e2f4a1c903_add_firewall_policy_fields.py b/alembic/versions/b7e2f4a1c903_add_firewall_policy_fields.py new file mode 100644 index 0000000..3ff2359 --- /dev/null +++ b/alembic/versions/b7e2f4a1c903_add_firewall_policy_fields.py @@ -0,0 +1,28 @@ +"""add firewall policy fields to configurations + +Revision ID: b7e2f4a1c903 +Revises: a3f1d8e92b01 +Create Date: 2026-03-31 00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b7e2f4a1c903' +down_revision: Union[str, None] = 'a3f1d8e92b01' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('configurations', sa.Column('allow_peer_to_peer', sa.Boolean(), nullable=False, server_default='false')) + op.add_column('configurations', sa.Column('allow_lan_to_peers', sa.Boolean(), nullable=False, server_default='false')) + + +def downgrade() -> None: + op.drop_column('configurations', 'allow_lan_to_peers') + op.drop_column('configurations', 'allow_peer_to_peer') diff --git a/wiregui/models/configuration.py b/wiregui/models/configuration.py index 68b228a..6186d5d 100644 --- a/wiregui/models/configuration.py +++ b/wiregui/models/configuration.py @@ -32,6 +32,10 @@ class Configuration(SQLModel, table=True): sa_column=Column(JSON, default=["0.0.0.0/0", "::/0"]), ) + # Firewall policies + allow_peer_to_peer: bool = Field(default=False) + allow_lan_to_peers: bool = Field(default=False) + # Server WireGuard keypair (generated on first startup) server_private_key: str | None = None server_public_key: str | None = None diff --git a/wiregui/pages/admin/rules.py b/wiregui/pages/admin/rules.py index 982b71d..e992867 100644 --- a/wiregui/pages/admin/rules.py +++ b/wiregui/pages/admin/rules.py @@ -1,5 +1,6 @@ """Admin firewall rules management page.""" +import asyncio from uuid import UUID from loguru import logger @@ -7,10 +8,13 @@ from nicegui import app, ui from sqlmodel import select from wiregui.db import async_session +from wiregui.models.configuration import Configuration from wiregui.models.rule import Rule from wiregui.models.user import User from wiregui.pages.layout import layout from wiregui.services.events import on_rule_created, on_rule_deleted, on_rule_updated +from wiregui.services.firewall import apply_lan_to_peers_policy, apply_peer_to_peer_policy, get_ruleset +from wiregui.utils.time import utcnow @ui.page("/admin/rules") @@ -23,6 +27,7 @@ async def rules_page(): # Load users for the dropdown async with async_session() as session: users = (await session.execute(select(User).order_by(User.email))).scalars().all() + config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() user_options = {str(u.id): u.email for u in users} async def load_rules() -> list[dict]: @@ -69,7 +74,7 @@ async def rules_page(): await session.refresh(rule) logger.info("Rule created: {} {} -> {}", rule.action, rule.destination, user_id_val or "global") - await on_rule_created(rule) + asyncio.create_task(on_rule_created(rule)) create_dialog.close() _reset_form() @@ -110,7 +115,7 @@ async def rules_page(): session.add(rule) await session.commit() await session.refresh(rule) - await on_rule_updated(rule) + asyncio.create_task(on_rule_updated(rule)) logger.info("Rule updated: {} {}", edit_action.value, edit_dest.value) ui.notify("Rule updated") @@ -124,7 +129,7 @@ async def rules_page(): await session.delete(rule) await session.commit() logger.info("Rule deleted: {} {}", rule.action, rule.destination) - await on_rule_deleted(rule) + asyncio.create_task(on_rule_deleted(rule)) await refresh_table() def _reset_form(): @@ -134,34 +139,103 @@ async def rules_page(): port_range_input.value = "" user_select.value = "global" - # Page content - with ui.column().classes("w-full p-4"): - with ui.row().classes("w-full items-center justify-between"): - ui.label("Firewall Rules").classes("text-h5") - ui.button("Add Rule", icon="add", on_click=lambda: create_dialog.open()).props("color=primary") + # --- Firewall policy toggles --- + async def toggle_peer_to_peer(e): + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if c: + c.allow_peer_to_peer = e.value + c.updated_at = utcnow() + session.add(c) + await session.commit() + asyncio.create_task(apply_peer_to_peer_policy(e.value)) + ui.notify(f"Peer-to-peer: {'allowed' if e.value else 'denied'}") - columns = [ - {"name": "action", "label": "Action", "field": "action", "align": "left", "sortable": True}, - {"name": "destination", "label": "Destination", "field": "destination", "align": "left", "sortable": True}, - {"name": "port_type", "label": "Protocol", "field": "port_type", "align": "left"}, - {"name": "port_range", "label": "Port(s)", "field": "port_range", "align": "left"}, - {"name": "user", "label": "User", "field": "user", "align": "left"}, - {"name": "actions", "label": "", "field": "id", "align": "center"}, - ] - table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full") - table.add_slot( - "body-cell-actions", - ''' - - - - - ''', - ) - table.on("edit", lambda e: open_edit(e.args)) - table.on("delete", lambda e: delete_rule(e.args)) + async def toggle_lan_to_peers(e): + async with async_session() as session: + c = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none() + if c: + c.allow_lan_to_peers = e.value + c.updated_at = utcnow() + session.add(c) + await session.commit() + asyncio.create_task(apply_lan_to_peers_policy(e.value)) + ui.notify(f"LAN-to-peers: {'allowed' if e.value else 'denied'}") + + # --- Troubleshooting --- + async def show_nft_rules(): + ruleset = await get_ruleset() + with ui.dialog(value=True) as dlg: + with ui.card().classes("w-[800px]"): + ui.label("nftables Ruleset").classes("text-subtitle1 text-bold") + ui.label("Current system firewall rules for troubleshooting.").classes("text-caption text-grey") + ui.separator() + ui.textarea(value=ruleset).props("readonly outlined").classes( + "w-full font-mono text-xs" + ).style("min-height: 400px; white-space: pre") + with ui.row().classes("w-full justify-end q-mt-sm"): + ui.button("Close", on_click=dlg.close).props("flat") + + # --- Page content --- + with ui.column().classes("w-full p-4"): + ui.label("Firewall Rules").classes("text-h5 q-mb-md") + + # Policy switches + with ui.card().classes("w-full"): + ui.label("Network Policies").classes("text-subtitle1 text-bold") + ui.label("Control how traffic flows between peers and the local network.").classes("text-caption text-grey") + ui.separator() + + ui.switch( + "Allow peer-to-peer communication", + value=config.allow_peer_to_peer if config else False, + on_change=toggle_peer_to_peer, + ) + ui.label("Peers can communicate with each other through the WireGuard server (hub-and-spoke).").classes("text-caption text-grey q-ml-xl") + + ui.switch( + "Allow local network to reach peers", + value=config.allow_lan_to_peers if config else False, + on_change=toggle_lan_to_peers, + ).classes("q-mt-sm") + ui.label("Devices on the server's LAN can initiate connections to VPN peers.").classes("text-caption text-grey q-ml-xl") + + # Rules table + with ui.card().classes("w-full q-mt-md"): + with ui.row().classes("w-full items-center justify-between"): + ui.label("Per-User Rules").classes("text-subtitle1 text-bold") + ui.button("Add Rule", icon="add", on_click=lambda: create_dialog.open()).props("color=primary unelevated") + ui.separator() + + columns = [ + {"name": "action", "label": "Action", "field": "action", "align": "left", "sortable": True}, + {"name": "destination", "label": "Destination", "field": "destination", "align": "left", "sortable": True}, + {"name": "port_type", "label": "Protocol", "field": "port_type", "align": "left"}, + {"name": "port_range", "label": "Port(s)", "field": "port_range", "align": "left"}, + {"name": "user", "label": "User", "field": "user", "align": "left"}, + {"name": "actions", "label": "", "field": "id", "align": "center"}, + ] + table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full") + table.add_slot( + "body-cell-actions", + ''' + + + + + ''', + ) + table.on("edit", lambda e: open_edit(e.args)) + table.on("delete", lambda e: delete_rule(e.args)) + + # Troubleshooting + with ui.card().classes("w-full q-mt-md"): + ui.label("Troubleshooting").classes("text-subtitle1 text-bold") + ui.label("Inspect the raw nftables ruleset configured on this system.").classes("text-caption text-grey") + ui.separator() + ui.button("View nftables Rules", icon="terminal", on_click=show_nft_rules).props("color=primary unelevated") # Create rule dialog with ui.dialog() as create_dialog: @@ -195,7 +269,7 @@ async def rules_page(): with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=create_dialog.close).props("flat") - ui.button("Create", on_click=create_rule).props("color=primary") + ui.button("Create", on_click=create_rule).props("color=primary unelevated") # Edit rule dialog user_options_map = {"global": "Global (all users)"} @@ -223,6 +297,6 @@ async def rules_page(): with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=edit_dialog.close).props("flat") - ui.button("Save", on_click=save_edit).props("color=primary") + ui.button("Save", on_click=save_edit).props("color=primary unelevated") await refresh_table() diff --git a/wiregui/services/firewall.py b/wiregui/services/firewall.py index 924b6f1..2707615 100644 --- a/wiregui/services/firewall.py +++ b/wiregui/services/firewall.py @@ -167,6 +167,74 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules)) +async def apply_peer_to_peer_policy(enabled: bool) -> None: + """Allow or deny traffic between WireGuard peers (peer-to-peer through the server).""" + settings = get_settings() + iface = settings.wg_interface + v4_net = settings.wg_ipv4_network + v6_net = settings.wg_ipv6_network + chain = "peer_to_peer" + + commands = [ + f"add chain inet {TABLE_NAME} {chain}", + f"flush chain inet {TABLE_NAME} {chain}", + ] + + if enabled: + # Allow traffic from WG subnet destined to WG subnet (both directions through the interface) + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr {v4_net} ip daddr {v4_net} accept') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr {v6_net} ip6 daddr {v6_net} accept') + else: + # Drop inter-peer traffic + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr {v4_net} ip daddr {v4_net} drop') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr {v6_net} ip6 daddr {v6_net} drop') + + try: + await _nft_batch(commands) + # Ensure the forward chain jumps to peer_to_peer before user chains + # We flush and re-add to keep ordering correct + logger.info("Peer-to-peer policy: {}", "allow" if enabled else "deny") + except RuntimeError as e: + logger.error("Failed to apply peer-to-peer policy: {}", e) + + +async def apply_lan_to_peers_policy(enabled: bool) -> None: + """Allow or deny traffic from the local network to WireGuard peers.""" + settings = get_settings() + iface = settings.wg_interface + v4_net = settings.wg_ipv4_network + v6_net = settings.wg_ipv6_network + chain = "lan_to_peers" + + commands = [ + f"add chain inet {TABLE_NAME} {chain}", + f"flush chain inet {TABLE_NAME} {chain}", + ] + + if enabled: + # Allow traffic from non-WG sources destined to WG subnet (LAN → peers) + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr != {v4_net} ip daddr {v4_net} accept') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr != {v6_net} ip6 daddr {v6_net} accept') + else: + # Drop LAN → peer traffic + commands.append(f'add rule inet {TABLE_NAME} {chain} ip saddr != {v4_net} ip daddr {v4_net} drop') + commands.append(f'add rule inet {TABLE_NAME} {chain} ip6 saddr != {v6_net} ip6 daddr {v6_net} drop') + + try: + await _nft_batch(commands) + logger.info("LAN-to-peers policy: {}", "allow" if enabled else "deny") + except RuntimeError as e: + logger.error("Failed to apply LAN-to-peers policy: {}", e) + + +async def get_ruleset() -> str: + """Dump the current nftables ruleset for troubleshooting.""" + try: + return await _nft("list ruleset") + except RuntimeError as e: + return f"Error: {e}" + + def _user_chain_name(user_id: str) -> str: """Generate a deterministic chain name from a user ID.""" # Use first 12 chars of UUID (without hyphens) to keep names short