wiregui/wiregui/services/firewall.py
Stefano Bertelli 0546b44507
Some checks failed
CI / test (push) Failing after 26s
CI / release (push) Has been skipped
CI / docker (push) Has been skipped
feat: initial WireGUI implementation — full VPN management platform
Complete Python/NiceGUI rewrite of the Wirezone (Elixir/Phoenix) VPN
management platform. All 10 implementation phases delivered.

Core stack:
- NiceGUI reactive UI with SQLModel ORM on PostgreSQL (asyncpg)
- Alembic migrations, Valkey/Redis cache, pydantic-settings config
- WireGuard management via subprocess (wg/ip/nft CLIs)
- 164 tests passing, 35% code coverage

Features:
- User/device/rule CRUD with admin and unprivileged roles
- Full device config form with per-device WG overrides
- WireGuard client config generation with QR codes
- REST API (v0) with Bearer token auth for all resources
- TOTP MFA with QR registration and challenge flow
- OIDC SSO with authlib (provider registry, auto-create users)
- Magic link passwordless sign-in via email
- SAML SP-initiated SSO with IdP metadata parsing
- WebAuthn/FIDO2 security key registration
- nftables firewall with per-user chains and masquerade
- Background tasks: WG stats polling, VPN session expiry,
  OIDC token refresh, WAN connectivity checks
- Startup reconciliation (DB ↔ WireGuard state sync)
- In-memory notification system with header badge
- Admin UI: users, devices, rules, settings (3 tabs), diagnostics
- Loguru logging with optional timestamped file output

Deployment:
- Multi-stage Dockerfile (python:3.13-slim)
- Docker Compose prod stack (bridge networking, NET_ADMIN, nftables)
- Forgejo CI: tests → semantic versioning → Docker registry push
- Health endpoint at /api/health
2026-03-30 16:53:46 -05:00

191 lines
6.9 KiB
Python

"""nftables firewall management — per-user chains and sets for device traffic filtering."""
import asyncio
import ipaddress
import json

from loguru import logger

from wiregui.config import get_settings
# Name of the nftables table (always addressed in the ``inet`` family) that holds
# every chain this module creates: ``forward``, ``postrouting`` and the per-user chains.
TABLE_NAME = "wiregui"
async def _nft(cmd: str) -> str:
    """Execute a single ``nft`` invocation and return its stripped standard output.

    Args:
        cmd: Arguments for ``nft`` as one string. It is split with ``str.split``
            (whitespace only, not shell-style), so quote characters reach nft as-is.

    Returns:
        The command's stdout, decoded and with surrounding whitespace removed.

    Raises:
        RuntimeError: nft exited with a non-zero status; the message carries its stderr.
    """
    argv = ["nft", *cmd.split()]
    process = await asyncio.create_subprocess_exec(
        *argv,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    out, err = await process.communicate()
    if process.returncode == 0:
        return out.decode().strip()
    raise RuntimeError(f"nft {cmd} failed: {err.decode().strip()}")
async def _nft_batch(commands: list[str]) -> None:
    """Feed *commands* to ``nft -f -`` as one script so they apply as a single batch.

    Args:
        commands: nft commands, one per line of the generated script.

    Raises:
        RuntimeError: nft exited with a non-zero status; the message carries its stderr.
    """
    script = "\n".join(commands)
    process = await asyncio.create_subprocess_exec(
        "nft", "-f", "-",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # stdout is collected only so the pipe is drained; it is not used.
    _, err = await process.communicate(script.encode())
    if process.returncode == 0:
        return
    raise RuntimeError(f"nft batch failed: {err.decode().strip()}")
async def setup_base_tables() -> None:
    """Create the module's ``inet`` table together with its two base chains.

    ``forward`` is a filter chain on the forward hook (priority 0) and
    ``postrouting`` a nat chain on the postrouting hook (priority 100); both use
    ``policy accept``. A "File exists" error from nft is treated as "already set
    up"; any other nft failure is re-raised as ``RuntimeError``.
    """
    table = f"inet {TABLE_NAME}"
    batch = [
        f"add table {table}",
        # Filter hook where device traffic is dispatched to the per-user chains.
        f"add chain {table} forward {{ type filter hook forward priority 0; policy accept; }}",
        # NAT hook used by the masquerade rules.
        f"add chain {table} postrouting {{ type nat hook postrouting priority 100; policy accept; }}",
    ]
    try:
        await _nft_batch(batch)
    except RuntimeError as err:
        if "File exists" not in str(err):
            raise
        logger.debug("Base nftables table '{}' already exists", TABLE_NAME)
    else:
        logger.info("Base nftables table '{}' created", TABLE_NAME)
async def setup_masquerade(iface: str | None = None) -> None:
    """Rebuild the masquerade (source NAT) rules for traffic leaving the VPN.

    The ``postrouting`` chain is flushed, then for each configured WireGuard
    network (IPv4 and/or IPv6) a rule masquerades packets sourced from that
    network when they leave through any interface other than *iface*.

    Args:
        iface: WireGuard interface name; defaults to ``settings.wg_interface``.

    This is best-effort: an nft failure is logged as a warning and not raised.
    """
    settings = get_settings()
    iface = iface or settings.wg_interface

    commands = [f"flush chain inet {TABLE_NAME} postrouting"]
    # Only emit a rule for families that actually have a network configured. The
    # batch is atomic, so one malformed rule (e.g. "ip6 saddr None" when IPv6 is
    # unset) would otherwise make nft reject the IPv4 rule as well.
    for family, network in (("ip", settings.wg_ipv4_network), ("ip6", settings.wg_ipv6_network)):
        if network:
            commands.append(
                f'add rule inet {TABLE_NAME} postrouting {family} saddr {network} oifname != "{iface}" masquerade'
            )

    try:
        await _nft_batch(commands)
        logger.info("Masquerade rule added for {}", iface)
    except RuntimeError as e:
        # Still best-effort, but logged at warning level: a debug-only message made
        # a missing NAT setup effectively invisible.
        logger.warning("Masquerade setup: {}", e)
async def add_user_chain(user_id: str) -> None:
    """Ensure the regular (non-base) chain holding *user_id*'s rules exists.

    A "File exists" error from nft is ignored; any other nft failure is
    re-raised as ``RuntimeError``.
    """
    chain = _user_chain_name(user_id)
    try:
        await _nft_batch([f"add chain inet {TABLE_NAME} {chain}"])
    except RuntimeError as err:
        # An already-existing chain is fine; anything else is a real failure.
        if "File exists" in str(err):
            return
        raise
    else:
        logger.debug("User chain created: {}", chain)
async def remove_user_chain(user_id: str) -> None:
    """Remove a per-user chain, its rules, and the forward-chain jumps targeting it.

    nftables refuses to delete a chain that is still the target of a ``jump``
    rule, and ``add_device_jump_rule`` creates exactly such rules in ``forward``.
    Those jumps are therefore deleted by handle in the same atomic batch, before
    the chain is flushed and deleted.

    Args:
        user_id: ID of the user whose chain should be removed.

    Cleanup is best-effort: an nft failure (e.g. the chain is already gone) is
    logged at debug level and not raised.
    """
    chain = _user_chain_name(user_id)
    commands = [
        f"delete rule inet {TABLE_NAME} forward handle {handle}"
        for handle in await _forward_jump_handles(chain)
    ]
    commands += [
        f"flush chain inet {TABLE_NAME} {chain}",
        f"delete chain inet {TABLE_NAME} {chain}",
    ]
    try:
        await _nft_batch(commands)
        logger.debug("User chain removed: {}", chain)
    except RuntimeError as e:
        logger.debug("Remove user chain {}: {}", chain, e)


async def _forward_jump_handles(chain: str) -> list[int]:
    """Return the handles of rules in the ``forward`` chain that jump to *chain*.

    Reads nft's JSON listing. Returns an empty list when the forward chain cannot
    be listed or the output cannot be parsed, so the caller falls back to
    deleting the chain directly (the pre-existing behavior).
    """
    try:
        raw = await _nft(f"-j list chain inet {TABLE_NAME} forward")
        items = json.loads(raw).get("nftables", [])
    except (RuntimeError, ValueError, AttributeError) as e:
        logger.debug("Could not list jump rules for {}: {}", chain, e)
        return []

    handles: list[int] = []
    for item in items:
        rule = item.get("rule") if isinstance(item, dict) else None
        if not isinstance(rule, dict) or "handle" not in rule:
            continue
        # A rule's statements are listed under "expr"; a jump appears as
        # {"jump": {"target": "<chain>"}}.
        if any(
            isinstance(stmt, dict)
            and isinstance(stmt.get("jump"), dict)
            and stmt["jump"].get("target") == chain
            for stmt in rule.get("expr", [])
        ):
            handles.append(rule["handle"])
    return handles
async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv6: str | None) -> None:
    """Send traffic from a device's VPN addresses through its owner's chain.

    For each address that is given, one rule matching that source address and
    jumping to the user's chain is appended to ``forward``. When both addresses
    are empty, nft is not invoked at all.

    NOTE(review): no duplicate check is done; calling this twice for the same
    device appends a second set of jump rules.
    """
    target = _user_chain_name(user_id)
    candidates = (("ip", device_ipv4), ("ip6", device_ipv6))
    jump_rules = [
        f"add rule inet {TABLE_NAME} forward {family} saddr {address} jump {target}"
        for family, address in candidates
        if address
    ]
    if not jump_rules:
        return
    await _nft_batch(jump_rules)
    logger.debug("Jump rules added for device {}/{} -> {}", device_ipv4, device_ipv6, target)
async def apply_rule(user_id: str, destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> None:
    """Append a single filter rule to the chain belonging to *user_id*.

    The user's chain is not created here; it is expected to exist already.

    Args:
        user_id: Owner of the chain the rule is appended to.
        destination: Address or network matched as the packet destination.
        action: Verdict placed at the end of the rule.
        port_type: Protocol keyword for a port match; only used with ``port_range``.
        port_range: Destination port(s); only used with ``port_type``.
    """
    target_chain = _user_chain_name(user_id)
    expression = _build_rule_expr(destination, action, port_type, port_range)
    command = f"add rule inet {TABLE_NAME} {target_chain} {expression}"
    await _nft_batch([command])
    logger.debug("Rule applied in {}: {} -> {}", target_chain, destination, action)
async def rebuild_all_rules(users_devices_rules: list[dict]) -> None:
    """Rebuild every per-user chain and all forward-chain jumps from database state.

    Everything is sent to nft as one batch built in two passes over the input.
    First, each user's chain is created if needed, emptied and refilled with that
    user's rules. Then the ``forward`` chain is flushed and one jump rule is added
    per device address. Chains of users missing from the input are left in place,
    but their jumps disappear with the ``forward`` flush.

    Args:
        users_devices_rules: One dict per user with keys ``user_id``,
            ``devices`` (optional list of dicts with optional ``ipv4``/``ipv6``) and
            ``rules`` (optional list of dicts with ``destination``, ``action`` and
            optional ``port_type``/``port_range``).
    """
    batch: list[str] = []

    # Pass 1: user chains and their rules.
    for entry in users_devices_rules:
        chain = _user_chain_name(entry["user_id"])
        # "add" leaves an existing chain in place, so "flush" follows to start empty.
        batch += [
            f"add chain inet {TABLE_NAME} {chain}",
            f"flush chain inet {TABLE_NAME} {chain}",
        ]
        batch.extend(
            f"add rule inet {TABLE_NAME} {chain} "
            + _build_rule_expr(
                rule["destination"], rule["action"],
                rule.get("port_type"), rule.get("port_range"),
            )
            for rule in entry.get("rules", [])
        )

    # Pass 2: replace every jump in the forward chain.
    batch.append(f"flush chain inet {TABLE_NAME} forward")
    for entry in users_devices_rules:
        chain = _user_chain_name(entry["user_id"])
        for device in entry.get("devices", []):
            for family, key in (("ip", "ipv4"), ("ip6", "ipv6")):
                address = device.get(key)
                if address:
                    batch.append(f"add rule inet {TABLE_NAME} forward {family} saddr {address} jump {chain}")

    # The forward flush is always queued, so the batch is never empty.
    await _nft_batch(batch)
    logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules))
def _user_chain_name(user_id: str) -> str:
"""Generate a deterministic chain name from a user ID."""
# Use first 12 chars of UUID (without hyphens) to keep names short
short = user_id.replace("-", "")[:12]
return f"user_{short}"
def _build_rule_expr(destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> str:
"""Build an nftables rule expression string."""
# Determine IP version from destination
if ":" in destination:
addr_match = f"ip6 daddr {destination}"
else:
addr_match = f"ip daddr {destination}"
parts = [addr_match]
if port_type and port_range:
parts.append(f"{port_type} dport {port_range}")
parts.append(action)
return " ".join(parts)