feat: initial WireGUI implementation — full VPN management platform
Some checks failed
CI / test (push) Failing after 26s
CI / release (push) Has been skipped
CI / docker (push) Has been skipped

Complete Python/NiceGUI rewrite of the Wirezone (Elixir/Phoenix) VPN
management platform. All 10 implementation phases delivered.

Core stack:
- NiceGUI reactive UI with SQLModel ORM on PostgreSQL (asyncpg)
- Alembic migrations, Valkey/Redis cache, pydantic-settings config
- WireGuard management via subprocess (wg/ip/nft CLIs)
- 164 tests passing, 35% code coverage

Features:
- User/device/rule CRUD with admin and unprivileged roles
- Full device config form with per-device WG overrides
- WireGuard client config generation with QR codes
- REST API (v0) with Bearer token auth for all resources
- TOTP MFA with QR registration and challenge flow
- OIDC SSO with authlib (provider registry, auto-create users)
- Magic link passwordless sign-in via email
- SAML SP-initiated SSO with IdP metadata parsing
- WebAuthn/FIDO2 security key registration
- nftables firewall with per-user chains and masquerade
- Background tasks: WG stats polling, VPN session expiry,
  OIDC token refresh, WAN connectivity checks
- Startup reconciliation (DB ↔ WireGuard state sync)
- In-memory notification system with header badge
- Admin UI: users, devices, rules, settings (3 tabs), diagnostics
- Loguru logging with optional timestamped file output

Deployment:
- Multi-stage Dockerfile (python:3.13-slim)
- Docker Compose prod stack (bridge networking, NET_ADMIN, nftables)
- Forgejo CI: tests → semantic versioning → Docker registry push
- Health endpoint at /api/health
This commit is contained in:
Stefano Bertelli 2026-03-30 16:53:46 -05:00
commit 0546b44507
109 changed files with 11793 additions and 0 deletions

13
.dockerignore Normal file
View file

@ -0,0 +1,13 @@
.venv/
__pycache__/
*.pyc
.env
.nicegui/
logs/
.git/
.idea/
.pytest_cache/
tests/
.forgejo/
*.md
compose*.yml

View file

@ -0,0 +1,211 @@
name: CI
on:
push:
branches:
- main
pull_request:
jobs:
test:
runs-on: docker
container:
image: python:3.13-slim
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: wiregui
POSTGRES_PASSWORD: wiregui
POSTGRES_DB: wiregui
options: >-
--health-cmd "pg_isready -U wiregui"
--health-interval 5s
--health-timeout 5s
--health-retries 5
env:
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
steps:
- uses: actions/checkout@v4
- name: Install system dependencies
run: |
apt-get update && apt-get install -y --no-install-recommends \
wireguard-tools pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl
- name: Install uv
run: pip install uv
- name: Install dependencies
run: uv sync
- name: Run tests
run: uv run pytest -v --tb=short
release:
needs: test
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
runs-on: docker
outputs:
new_tag: ${{ steps.version.outputs.new_tag }}
new_version: ${{ steps.version.outputs.new_version }}
skip: ${{ steps.version.outputs.skip }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version bump
id: version
run: |
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
echo "latest_tag=${LATEST_TAG}" >> "$GITHUB_OUTPUT"
CURRENT="${LATEST_TAG#v}"
IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT"
COMMITS=$(git log "${LATEST_TAG}..HEAD" --pretty=format:"%s" 2>/dev/null || git log --pretty=format:"%s")
BUMP="none"
while IFS= read -r msg; do
case "$msg" in
*"BREAKING CHANGE"*|*"!:"*)
BUMP="major"
break
;;
feat:*|feat\(*)
[ "$BUMP" != "major" ] && BUMP="minor"
;;
fix:*|fix\(*|perf:*|perf\(*|refactor:*|refactor\(*)
[ "$BUMP" = "none" ] && BUMP="patch"
;;
esac
done <<< "$COMMITS"
if [ "$BUMP" = "none" ]; then
echo "No version-relevant commits since ${LATEST_TAG}, skipping release"
echo "skip=true" >> "$GITHUB_OUTPUT"
exit 0
fi
case "$BUMP" in
major) MAJOR=$((MAJOR + 1)); MINOR=0; PATCH=0 ;;
minor) MINOR=$((MINOR + 1)); PATCH=0 ;;
patch) PATCH=$((PATCH + 1)) ;;
esac
NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}"
echo "new_version=${NEW_VERSION}" >> "$GITHUB_OUTPUT"
echo "new_tag=v${NEW_VERSION}" >> "$GITHUB_OUTPUT"
echo "bump=${BUMP}" >> "$GITHUB_OUTPUT"
echo "skip=false" >> "$GITHUB_OUTPUT"
echo "Version bump: ${BUMP} -> v${NEW_VERSION}"
- name: Generate changelog
id: changelog
if: steps.version.outputs.skip != 'true'
run: |
LATEST_TAG="${{ steps.version.outputs.latest_tag }}"
NEW_TAG="${{ steps.version.outputs.new_tag }}"
BODY="## ${NEW_TAG}"$'\n\n'
for type_label in "feat:Features" "fix:Bug Fixes" "refactor:Refactoring" "perf:Performance" "docs:Documentation" "chore:Maintenance"; do
prefix="${type_label%%:*}"
label="${type_label#*:}"
MATCHES=$(git log "${LATEST_TAG}..HEAD" --pretty=format:"%s" 2>/dev/null | grep -E "^${prefix}[:(]" || true)
if [ -n "$MATCHES" ]; then
BODY="${BODY}### ${label}"$'\n\n'
while IFS= read -r line; do
CLEAN=$(echo "$line" | sed -E "s/^${prefix}(\([^)]*\))?:\s*//")
BODY="${BODY}- ${CLEAN}"$'\n'
done <<< "$MATCHES"
BODY="${BODY}"$'\n'
fi
done
echo "${BODY}" > /tmp/changelog.md
echo "Generated changelog for ${NEW_TAG}"
- name: Create tag and release
if: steps.version.outputs.skip != 'true'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
NEW_TAG="${{ steps.version.outputs.new_tag }}"
git config user.name "Forgejo Actions"
git config user.email "noreply@forge.provvedo.com"
git tag -a "${NEW_TAG}" -m "Release ${NEW_TAG}"
git push origin "${NEW_TAG}"
FORGEJO_URL="${GITHUB_SERVER_URL}"
REPO="${GITHUB_REPOSITORY}"
python3 -c "
import json, urllib.request, os
body = open('/tmp/changelog.md').read()
tag = '${NEW_TAG}'
data = json.dumps({
'tag_name': tag,
'name': tag,
'body': body,
'draft': False,
'prerelease': False
}).encode()
req = urllib.request.Request(
'${FORGEJO_URL}/api/v1/repos/${REPO}/releases',
data=data,
headers={
'Authorization': 'token ' + os.environ['GITHUB_TOKEN'],
'Content-Type': 'application/json'
},
method='POST'
)
resp = urllib.request.urlopen(req)
print(f'Created release {tag} (HTTP {resp.status})')
"
docker:
needs: release
if: needs.release.outputs.skip != 'true'
runs-on: docker
container:
image: catthehacker/ubuntu:act-latest
options: --privileged
steps:
- uses: actions/checkout@v4
- name: Build and push image
shell: bash
env:
REGISTRY_TOKEN: ${{ secrets.REGISTRY_TOKEN }}
run: |
VERSION="${{ needs.release.outputs.new_version }}"
TAG="${{ needs.release.outputs.new_tag }}"
REGISTRY=$(echo "${{ github.server_url }}" | sed 's|https://||; s|http://||')
IMAGE="${REGISTRY}/${{ github.repository_owner }}/wiregui"
MAJOR=$(echo "$VERSION" | cut -d. -f1)
MINOR=$(echo "$VERSION" | cut -d. -f2)
echo "Building ${IMAGE}:${TAG}"
# Log in to Forgejo container registry
echo "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" \
-u "${{ github.repository_owner }}" --password-stdin
# Build the image
docker build --network host \
--build-arg "VERSION=${VERSION}" \
-t "${IMAGE}:${TAG}" \
-t "${IMAGE}:${MAJOR}.${MINOR}" \
-t "${IMAGE}:latest" \
.
# Push all tags
docker push "${IMAGE}:${TAG}"
docker push "${IMAGE}:${MAJOR}.${MINOR}"
docker push "${IMAGE}:latest"
echo "Pushed ${IMAGE}:${TAG}, ${IMAGE}:${MAJOR}.${MINOR}, ${IMAGE}:latest"

6
.gitignore vendored Normal file
View file

@ -0,0 +1,6 @@
.venv/
__pycache__/
*.pyc
.env
.nicegui/
logs/

1
.python-version Normal file
View file

@ -0,0 +1 @@
3.13

124
CLAUDE.md Normal file
View file

@ -0,0 +1,124 @@
# WireGUI
## Project Overview
WireGUI is a Python rewrite of the Wirezone VPN management platform (Elixir/Phoenix).
Original source: `/home/stefanob/PycharmProjects/personal/wirezone`
## Tech Stack
- **UI Framework**: NiceGUI (reactive server-side UI over WebSocket)
- **ORM/Models**: SQLModel (SQLAlchemy + Pydantic)
- **Database**: PostgreSQL (via asyncpg)
- **Cache/Sessions**: Valkey (Redis-compatible)
- **Migrations**: Alembic
- **REST API**: FastAPI (built into NiceGUI)
- **Auth**: authlib (OIDC), python-jose (JWT), pyotp (TOTP), webauthn, bcrypt
- **VPN**: subprocess calls to `wg` and `ip` commands
- **Firewall**: python-nftables or subprocess `nft`
- **Python**: 3.13
- **Package Manager**: uv
## Development Setup
```bash
uv sync # Install dependencies
docker compose up -d # Start PostgreSQL and Valkey
alembic upgrade head # Run migrations
uv run python -m wiregui.main # Start the application
```
## Project Structure
```
wiregui/
├── main.py # NiceGUI entrypoint, mounts FastAPI, starts background tasks
├── config.py # pydantic-settings: Settings class
├── db.py # async SQLAlchemy engine + sessionmaker
├── redis.py # Valkey connection pool
├── models/ # SQLModel table definitions
│ ├── user.py
│ ├── device.py
│ ├── rule.py
│ ├── mfa_method.py
│ ├── oidc_connection.py
│ ├── api_token.py
│ ├── connectivity_check.py
│ └── configuration.py
├── schemas/ # Pydantic request/response schemas (non-table)
├── auth/ # Authentication modules
│ ├── passwords.py # bcrypt hashing
│ ├── jwt.py # JWT create/verify
│ ├── session.py # NiceGUI session middleware
│ ├── oidc.py # authlib OIDC
│ ├── saml.py # python3-saml
│ ├── mfa.py # TOTP + WebAuthn
│ └── api_token.py # API token auth
├── api/v0/ # REST API routers
│ ├── users.py
│ ├── devices.py
│ ├── rules.py
│ └── configuration.py
├── pages/ # NiceGUI page definitions
│ ├── layout.py # shared sidebar/header
│ ├── login.py
│ ├── devices.py # user device CRUD
│ ├── account.py # user account/MFA
│ ├── mfa_challenge.py
│ └── admin/ # admin pages
│ ├── users.py
│ ├── devices.py
│ ├── rules.py
│ ├── settings.py
│ └── diagnostics.py
├── services/ # Core services
│ ├── wireguard.py # WG interface management
│ ├── firewall.py # nftables rule management
│ ├── events.py # DB → WG/firewall bridge
│ ├── notifications.py # in-memory notification queue
│ └── email.py # aiosmtplib
├── tasks/ # Background tasks
│ ├── vpn_session.py # expire VPN sessions
│ ├── stats.py # poll WG stats
│ ├── connectivity.py # WAN connectivity checks
│ └── oidc_refresh.py # refresh OIDC tokens
└── utils/ # Utilities
├── crypto.py # keypair gen, Fernet encrypt/decrypt
├── network.py # IP allocation, CIDR validation
└── validators.py # shared validators
alembic/
├── env.py
├── script.py.mako
└── versions/
```
## Commands
- `uv sync` — install/update dependencies
- `uv run python -m wiregui.main` — run the app
- `alembic revision --autogenerate -m "description"` — create migration
- `alembic upgrade head` — apply all migrations
- `alembic downgrade -1` — rollback last migration
- `docker compose up -d` — start local Postgres + Valkey
- `docker compose down` — stop local services
- `pytest` — run tests
## Conventions
- Use SQLModel for all database models (combines SQLAlchemy table + Pydantic validation)
- Use async database sessions with asyncpg
- Place all NiceGUI pages in `wiregui/pages/`
- Place all SQLModel table models in `wiregui/models/`
- Place Pydantic request/response schemas in `wiregui/schemas/`
- Use Alembic autogenerate for migrations
- Background tasks use asyncio (create_task + while/sleep pattern)
- WireGuard/nftables managed via subprocess (asyncio.create_subprocess_exec)
- DB is source of truth; WG/firewall state is reconciled on startup
## Logging — MANDATORY
- **Use loguru for ALL logging and messages. No `print()` statements allowed anywhere in this project.**
- Import: `from loguru import logger`
- Use `logger.info()`, `logger.warning()`, `logger.error()`, `logger.debug()`, etc.
- Loguru is configured in `wiregui/logging.py` via `setup_logging()`
- When `WG_LOG_TO_FILE=true` (default), timestamped log files are written to `logs/` in the project root
- The `logs/` directory is gitignored
## Testing
- Tests live in `tests/` mirroring the `wiregui/` structure
- Run with `uv run pytest`
- Use `pytest-asyncio` for async tests
- Test database: uses same Postgres instance, separate `wiregui_test` database

55
Dockerfile Normal file
View file

@ -0,0 +1,55 @@
FROM python:3.13-slim AS builder
WORKDIR /app
# Install uv for fast dependency resolution
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
# Install system deps needed for building (wireguard-tools for wg CLI)
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc libpq-dev wireguard-tools nftables iproute2 \
&& rm -rf /var/lib/apt/lists/*
# Copy dependency files first for layer caching
COPY pyproject.toml uv.lock* ./
# Install dependencies (production only, no dev group)
RUN uv sync --no-dev --frozen 2>/dev/null || uv sync --no-dev
# Copy application code
COPY wiregui/ wiregui/
COPY alembic/ alembic/
COPY alembic.ini ./
FROM python:3.13-slim AS runner
WORKDIR /app
# Runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
wireguard-tools nftables iproute2 libpq5 \
&& rm -rf /var/lib/apt/lists/*
# Copy uv and virtualenv from builder
COPY --from=builder /usr/local/bin/uv /usr/local/bin/uv
COPY --from=builder /app/.venv /app/.venv
COPY --from=builder /app/wiregui /app/wiregui
COPY --from=builder /app/alembic /app/alembic
COPY --from=builder /app/alembic.ini /app/alembic.ini
COPY --from=builder /app/pyproject.toml /app/pyproject.toml
# Ensure the venv is on PATH
ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONUNBUFFERED=1
# Create logs directory
RUN mkdir -p /app/logs
ARG VERSION=0.0.0-dev
ENV WG_VERSION=$VERSION
EXPOSE 13000
EXPOSE 51820/udp
# Run migrations then start the app
CMD ["sh", "-c", "alembic upgrade head && python -m wiregui.main"]

0
README.md Normal file
View file

196
TODO.md Normal file
View file

@ -0,0 +1,196 @@
# WireGUI Implementation TODO
Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI.
Source: `/home/stefanob/PycharmProjects/personal/wirezone`
**Test count: 164 passing | Coverage: 35%**
---
## Phase 1: Foundation — Models, DB, Config ✅
- [x] `pyproject.toml` with dependencies, `uv sync`
- [x] Package directory structure
- [x] `wiregui/config.py` — pydantic-settings (DB, Redis, WG, auth, SMTP, logging)
- [x] `wiregui/db.py` — async engine, sessionmaker, `init_db()`
- [x] `wiregui/redis.py` — Valkey connection pool
- [x] All 8 SQLModel models (User, Device, Rule, MFAMethod, OIDCConnection, ApiToken, ConnectivityCheck, Configuration)
- [x] Alembic init + initial migration + `alembic upgrade head`
- [x] `wiregui/main.py` — app entrypoint
- [x] `compose.yml` — PostgreSQL 17 + Valkey 8
- [x] `wiregui/utils/time.py``utcnow()` helper for naive UTC timestamps
---
## Phase 2: Auth System — Login + Sessions ✅
- [x] `wiregui/auth/passwords.py` — bcrypt hash/verify (direct bcrypt, not passlib)
- [x] `wiregui/auth/jwt.py` — create/decode JWT via python-jose
- [x] `wiregui/auth/session.py``authenticate_user()` email/password verification
- [x] `wiregui/auth/middleware.py` — HTTP-level auth middleware (available for REST API)
- [x] `wiregui/auth/seed.py` — auto-create admin on first startup
- [x] `wiregui/pages/login.py` — login page with email/password form
- [x] `wiregui/pages/home.py` — authenticated home (redirects to /devices)
- [x] Auth guards via `app.storage.user` on each page
- [x] Logout clears storage and redirects
---
## Phase 3: Device UI — User-Facing CRUD ✅
- [x] `wiregui/pages/layout.py` — shared sidebar + header
- [x] `wiregui/utils/network.py` — IPv4/IPv6 allocation (random offset + scan)
- [x] `wiregui/utils/crypto.py` — WG keypair + PSK generation via `wg` CLI
- [x] `wiregui/utils/wg_conf.py` — WG client `.conf` builder
- [x] `wiregui/pages/devices.py``/devices` list + create dialog + delete
- [x] `/devices/{device_id}` — detail page with stats and config flags
- [x] QR code generation + `.conf` download
- [x] Full device create/edit form with all wirezone options (description, per-device config overrides, use_default_* toggles with bound inputs, better layout)
---
## Phase 4: WireGuard Integration ✅
- [x] `wiregui/services/wireguard.py` — async subprocess: ensure_interface, add/remove_peer, get_peers, set_private_key, set_listen_port
- [x] `wiregui/services/events.py` — event bridge: device CRUD → WG + firewall
- [x] Device create/delete in UI fires WG events
- [x] `wiregui/tasks/__init__.py` — background task registry + cancel_all
- [x] `wiregui/tasks/stats.py` — poll WG stats every 60s, update DB
- [x] `wiregui/tasks/reconcile.py` — startup reconciliation (diff DB vs WG, add/remove)
- [x] `config.py``wg_enabled` flag (default False for dev)
- [x] Startup: ensure_interface → reconcile → stats_loop (when wg_enabled)
---
## Phase 5: Firewall (nftables) ✅
- [x] `wiregui/services/firewall.py` — nft CLI: setup_base_tables, masquerade, per-user chains, jump rules, apply_rule, rebuild_all_rules
- [x] IPv4/IPv6 aware, TCP/UDP port range support
- [x] `wiregui/pages/admin/rules.py``/admin/rules` CRUD (action, CIDR, protocol, port, user)
- [x] Events: on_rule_created/deleted, on_device_created adds jump rules
- [x] Startup: setup_base_tables + setup_masquerade (when wg_enabled)
- [x] Edit rule — edit dialog in admin rules page with all fields
- [x] Full user chain rebuild on rule update/delete via `_rebuild_user_chain()` in events.py
---
## Phase 6: REST API (v0) ✅
- [x] `wiregui/auth/api_token.py` — token generation (random → sha256), Bearer resolution with expiry + disabled user checks
- [x] `wiregui/api/deps.py` — get_db, get_current_api_user, require_admin
- [x] `wiregui/schemas/` — Pydantic schemas: UserRead/Create/Update, DeviceRead/Create/Update, RuleRead/Create/Update, ConfigurationRead/Update
- [x] `wiregui/api/v0/users.py` — full CRUD (admin only)
- [x] `wiregui/api/v0/devices.py` — full CRUD (owner or admin, triggers WG/firewall events)
- [x] `wiregui/api/v0/rules.py` — full CRUD (admin only, triggers firewall events)
- [x] `wiregui/api/v0/configuration.py` — GET/PUT (admin only, auto-creates singleton)
- [x] Mounted on NiceGUI app at `/api/v0`
---
## Phase 7: Admin UI ✅
- [x] `/admin/users` — table (email, role, devices, status, last sign-in, method, created), create (email/password/role), edit (email/role/password/disabled), delete with cascading cleanup (devices → WG events, rules)
- [x] `/admin/devices` — all devices with user filter, full create form (owner, name, description, all use_default_* toggles with bound override inputs), full edit form, delete with WG events, config + QR on creation
- [x] `/admin/settings` — 3 tabs:
- Client Defaults (endpoint, DNS, allowed IPs, MTU, keepalive)
- Security (VPN session duration, local auth, unpriv device mgmt/config, OIDC auto-disable)
- Authentication (OIDC provider CRUD with table + dialog; SAML placeholder for Phase 8)
- [x] `/admin/diagnostics` — WG interface status, active peers, connectivity checks, system notifications with clear/clear-all
- [x] `wiregui/services/notifications.py` — in-memory deque (capped at 100), add/clear/count/current
- [x] Header notification bell badge (admin only, links to diagnostics)
- [ ] **TODO:** SAML provider management in Authentication tab
---
## Phase 8: Advanced Auth (MFA, OIDC, Magic Links, SAML) ✅
- [x] TOTP MFA (`wiregui/auth/mfa.py`) — secret generation, URI/QR, verification with clock drift tolerance
- [x] MFA challenge page (`/mfa`) — 6-digit code entry, multi-method support, last-used tracking
- [x] Login page updated: checks for MFA methods after password auth, redirects to `/mfa` if present
- [x] OIDC (`wiregui/auth/oidc.py`) — provider registry from Configuration, authlib Starlette integration
- [x] OIDC routes (`/auth/oidc/{provider}` + `/auth/oidc/{provider}/callback`) — auth code flow, user lookup/auto-create, refresh token storage in OIDCConnection
- [x] Login page shows OIDC provider buttons dynamically from config
- [x] OIDC refresh task (`wiregui/tasks/oidc_refresh.py`) — every 10min, refreshes all stored tokens, creates notifications on failure, respects `disable_vpn_on_oidc_error`
- [x] Magic links (`/auth/magic-link` + `/auth/magic/{user_id}/{token}`) — request page, signed JWT with 15min expiry, email via aiosmtplib
- [x] Email service (`wiregui/services/email.py`) — aiosmtplib send, magic link template
- [x] `/account` page — 3 tabs: Profile (details + password change), Two-Factor Auth (TOTP registration with QR + verification, list/delete methods), API Tokens (create with configurable expiry, list, delete)
- [x] OIDC providers registered on startup from Configuration
- [x] WebAuthn MFA (`wiregui/auth/webauthn.py`) — registration/authentication options generation, response verification, credential storage
- [x] SAML (`wiregui/auth/saml.py` + `wiregui/pages/auth_saml.py`) — SP-initiated SSO, metadata endpoint, ACS callback, IdP metadata parsing, attribute mapping
- [x] WebAuthn browser-side JS integration in account page — `ui.run_javascript()` calls `navigator.credentials.create()`, serializes response, server verifies and stores credential
- [x] SAML provider management UI in admin settings Authentication tab — table + add/delete dialog (config ID, label, XML metadata, sign requests/metadata/assertions/envelopes toggles, auto-create users)
---
## Phase 9: Background Tasks & VPN Session Management
- [x] Task scheduler (`wiregui/tasks/__init__.py`) — register/cancel
- [x] Stats polling task (Phase 4)
- [x] OIDC refresh task (Phase 8)
- [x] VPN session expiry task (`wiregui/tasks/vpn_session.py`) — every 60s, finds expired sessions based on `vpn_session_duration` + `last_signed_in_at`, removes WG peers, creates notifications
- [x] Connectivity check poller (`wiregui/tasks/connectivity.py`) — fetches URL, stores result in DB, notification on failure
- [x] Live stats push — `ui.timer(30, ...)` on `/devices` (table refresh), `/devices/{id}` (RX/TX/handshake/remote IP labels), `/admin/devices` (table refresh)
---
## Phase 10: Polish, Testing & Deployment
### Testing (partially done)
- [x] pytest + pytest-asyncio setup, conftest with test DB
- [x] test_models.py (10 tests), test_auth.py (8 tests), test_utils.py (6 tests), test_services.py (6 tests), test_firewall.py (7 tests)
- [x] test_api.py (6 tests) — token generation, resolution, expiry, disabled user
- [x] test_notifications.py (9 tests) — add, ordering, count, clear, max cap, to_dict
- [x] test_admin.py (13 tests) — user CRUD, cascading deletes, config CRUD, OIDC providers, device overrides
- [x] test_mfa.py (11 tests) — TOTP secret gen, URI, code verification (valid/invalid/wrong secret/empty), QR SVG, DB integration, multi-method
- [x] test_magic_link.py (4 tests) — token creation/expiry/user mismatch, disabled user rejection
- [x] test_account.py (8 tests) — password change flow, API token CRUD, OIDC connection CRUD, refresh token update
- [x] test_integration_mfa.py (7 tests) — full TOTP registration flow, MFA blocks login, wrong code, multi-method, last-used tracking, delete allows bypass, disabled user
- [x] test_integration_oidc.py (10 tests) — provider config loading, connection create/update, auto-create user, disabled user, refresh token, multi-provider
- [x] test_tasks.py (6 tests) — VPN session expiry (expired/unlimited/no-config/disabled user), connectivity check (success/failure with notification)
- [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking)
### Coverage gaps (35% overall — run `uv run pytest --cov=wiregui --cov-report=term-missing --cov-branch`)
**100% covered:** models, schemas, config, auth/passwords, auth/jwt, auth/mfa, auth/api_token, utils/crypto, utils/time, services/notifications
**API routes (32-84% — partially covered via httpx TestClient):**
- [x] `wiregui/api/v0/users.py` (84%) — list/get/create/update/delete
- [x] `wiregui/api/v0/rules.py` (71%) — CRUD
- [x] `wiregui/api/v0/devices.py` (67%) — CRUD, permissions
- [x] `wiregui/api/v0/configuration.py` (61%) — get/update, auto-create
- [ ] `wiregui/api/deps.py` (32%) — test get_current_api_user with real Bearer header parsing, require_admin rejection
**Services (62-89% covered):**
- [x] `wiregui/services/wireguard.py` (62%) — add/remove/get peers mocked
- [x] `wiregui/services/firewall.py` (73%) — base tables, chains, rules, rebuild mocked
- [x] `wiregui/services/events.py` (80%) — device + rule events, rebuild chain
- [x] `wiregui/services/email.py` (89%) — send_email, magic link, no-smtp fallback
- [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port
- [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6
**Tasks (40-84% covered):**
- [x] `wiregui/tasks/stats.py` (77%) — update from peers, no-op, unmatched peer
- [x] `wiregui/tasks/reconcile.py` (84%) — add missing, remove orphaned, in-sync
- [x] `wiregui/tasks/oidc_refresh.py` (40%) — no connections, skip unknown provider
- [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error
**Auth modules (85-92% covered):**
- [x] `wiregui/auth/oidc.py` (87%) — register providers, get_client, load from config
- [x] `wiregui/auth/webauthn.py` (85%) — registration/authentication options
- [x] `wiregui/auth/session.py` (90%) — no-password, disabled, nonexistent user
- [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing
- [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data
**Pages (0% — requires E2E testing):**
- [ ] Consider Playwright or NiceGUI's testing utilities for E2E page tests
### Logging (done)
- [x] Loguru configured (wiregui/logging.py), no print statements
- [x] File logging to `logs/` when `WG_LOG_TO_FILE=true`
### Deployment
- [ ] Dockerfile (multi-stage)
- [ ] compose.prod.yml (app + postgres + valkey + caddy)
- [ ] Health endpoint `GET /api/health`
- [ ] First-run CLI setup command
- [ ] README.md

36
alembic.ini Normal file
View file

@ -0,0 +1,36 @@
[alembic]
script_location = alembic
prepend_sys_path = .
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

47
alembic/env.py Normal file
View file

@ -0,0 +1,47 @@
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from wiregui.config import get_settings
from wiregui.models import * # noqa: F401, F403 — ensure all models are registered
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = SQLModel.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode — emit SQL to script output."""
context.configure(
url=get_settings().database_url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode — connect to the database."""
engine = create_async_engine(get_settings().database_url)
async with engine.connect() as connection:
await connection.run_sync(do_run_migrations)
await engine.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())

27
alembic/script.py.mako Normal file
View file

@ -0,0 +1,27 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View file

@ -0,0 +1,33 @@
"""add server keypair to configuration
Revision ID: 0741bc76e748
Revises: 647a4418cc8c
Create Date: 2026-03-30 15:37:19.276524
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '0741bc76e748'
down_revision: Union[str, None] = '647a4418cc8c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('configurations', sa.Column('server_private_key', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.add_column('configurations', sa.Column('server_public_key', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('configurations', 'server_public_key')
op.drop_column('configurations', 'server_private_key')
# ### end Alembic commands ###

View file

@ -0,0 +1,171 @@
"""initial schema
Revision ID: 647a4418cc8c
Revises:
Create Date: 2026-03-30 13:18:58.766259
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = '647a4418cc8c'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('configurations',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('allow_unprivileged_device_management', sa.Boolean(), nullable=False),
sa.Column('allow_unprivileged_device_configuration', sa.Boolean(), nullable=False),
sa.Column('local_auth_enabled', sa.Boolean(), nullable=False),
sa.Column('disable_vpn_on_oidc_error', sa.Boolean(), nullable=False),
sa.Column('default_client_persistent_keepalive', sa.Integer(), nullable=False),
sa.Column('default_client_mtu', sa.Integer(), nullable=False),
sa.Column('default_client_endpoint', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('default_client_dns', sa.JSON(), nullable=True),
sa.Column('default_client_allowed_ips', sa.JSON(), nullable=True),
sa.Column('vpn_session_duration', sa.Integer(), nullable=False),
sa.Column('logo_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('logo_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('openid_connect_providers', sa.JSON(), nullable=True),
sa.Column('saml_identity_providers', sa.JSON(), nullable=True),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('connectivity_checks',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('url', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('response_code', sa.Integer(), nullable=True),
sa.Column('response_headers', sa.JSON(), nullable=True),
sa.Column('response_body', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('users',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('password_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('last_signed_in_at', sa.DateTime(), nullable=True),
sa.Column('last_signed_in_method', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('sign_in_token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('sign_in_token_created_at', sa.DateTime(), nullable=True),
sa.Column('disabled_at', sa.DateTime(), nullable=True),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_table('api_tokens',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_api_tokens_token_hash'), 'api_tokens', ['token_hash'], unique=True)
op.create_index(op.f('ix_api_tokens_user_id'), 'api_tokens', ['user_id'], unique=False)
op.create_table('devices',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('public_key', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('preshared_key', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('use_default_allowed_ips', sa.Boolean(), nullable=False),
sa.Column('use_default_dns', sa.Boolean(), nullable=False),
sa.Column('use_default_endpoint', sa.Boolean(), nullable=False),
sa.Column('use_default_mtu', sa.Boolean(), nullable=False),
sa.Column('use_default_persistent_keepalive', sa.Boolean(), nullable=False),
sa.Column('endpoint', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('mtu', sa.Integer(), nullable=True),
sa.Column('persistent_keepalive', sa.Integer(), nullable=True),
sa.Column('allowed_ips', sa.JSON(), nullable=True),
sa.Column('dns', sa.JSON(), nullable=True),
sa.Column('ipv4', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('ipv6', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('remote_ip', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('rx_bytes', sa.Integer(), nullable=True),
sa.Column('tx_bytes', sa.Integer(), nullable=True),
sa.Column('latest_handshake', sa.DateTime(), nullable=True),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('ipv4'),
sa.UniqueConstraint('ipv6')
)
op.create_index(op.f('ix_devices_public_key'), 'devices', ['public_key'], unique=True)
op.create_index(op.f('ix_devices_user_id'), 'devices', ['user_id'], unique=False)
op.create_table('mfa_methods',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('payload', sa.JSON(), nullable=True),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_mfa_methods_user_id'), 'mfa_methods', ['user_id'], unique=False)
op.create_table('oidc_connections',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('refresh_token', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('refresh_response', sa.JSON(), nullable=True),
sa.Column('refreshed_at', sa.DateTime(), nullable=True),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_oidc_connections_user_id'), 'oidc_connections', ['user_id'], unique=False)
op.create_table('rules',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('action', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('destination', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('port_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('port_range', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('user_id', sa.Uuid(), nullable=True),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_rules_user_id'), 'rules', ['user_id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_rules_user_id'), table_name='rules')
op.drop_table('rules')
op.drop_index(op.f('ix_oidc_connections_user_id'), table_name='oidc_connections')
op.drop_table('oidc_connections')
op.drop_index(op.f('ix_mfa_methods_user_id'), table_name='mfa_methods')
op.drop_table('mfa_methods')
op.drop_index(op.f('ix_devices_user_id'), table_name='devices')
op.drop_index(op.f('ix_devices_public_key'), table_name='devices')
op.drop_table('devices')
op.drop_index(op.f('ix_api_tokens_user_id'), table_name='api_tokens')
op.drop_index(op.f('ix_api_tokens_token_hash'), table_name='api_tokens')
op.drop_table('api_tokens')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
op.drop_table('connectivity_checks')
op.drop_table('configurations')
# ### end Alembic commands ###

63
compose.prod.yml Normal file
View file

@ -0,0 +1,63 @@
services:
wiregui:
build:
context: .
dockerfile: Dockerfile
image: wiregui:latest
restart: unless-stopped
ports:
- "13000:13000"
- "51821:51821/udp"
cap_add:
- NET_ADMIN
- SYS_MODULE
sysctls:
- net.ipv4.ip_forward=1
- net.ipv6.conf.all.forwarding=1
- net.ipv6.conf.all.disable_ipv6=0
environment:
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
WG_REDIS_URL: redis://valkey:6379/0
WG_SECRET_KEY: ${WG_SECRET_KEY:-change-me-in-production}
WG_WG_ENABLED: "true"
WG_WG_ENDPOINT_HOST: ${WG_ENDPOINT_HOST:-vpn.example.com}
WG_WG_ENDPOINT_PORT: "51821"
WG_HOST: "0.0.0.0"
WG_PORT: "13000"
WG_EXTERNAL_URL: ${WG_EXTERNAL_URL:-http://localhost:13000}
WG_ADMIN_EMAIL: ${WG_ADMIN_EMAIL:-admin@localhost}
WG_ADMIN_PASSWORD: ${WG_ADMIN_PASSWORD:-}
WG_LOG_TO_FILE: "true"
volumes:
- wiregui_logs:/app/logs
depends_on:
postgres:
condition: service_healthy
valkey:
condition: service_started
postgres:
image: postgres:17
restart: unless-stopped
environment:
POSTGRES_USER: wiregui
POSTGRES_PASSWORD: wiregui
POSTGRES_DB: wiregui
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U wiregui"]
interval: 5s
timeout: 5s
retries: 5
valkey:
image: valkey/valkey:8
restart: unless-stopped
volumes:
- valkey_data:/data
volumes:
postgres_data:
valkey_data:
wiregui_logs:

22
compose.yml Normal file
View file

@ -0,0 +1,22 @@
services:
postgres:
image: postgres:17
environment:
POSTGRES_USER: wiregui
POSTGRES_PASSWORD: wiregui
POSTGRES_DB: wiregui
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
valkey:
image: valkey/valkey:8
ports:
- "6379:6379"
volumes:
- valkey_data:/data
volumes:
postgres_data:
valkey_data:

49
pyproject.toml Normal file
View file

@ -0,0 +1,49 @@
[project]
name = "wiregui"
version = "0.1.0"
description = "WireGuard VPN management platform — Python/NiceGUI rewrite of Wirezone"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
# UI
"nicegui>=2.12",
# ORM & Database
"sqlmodel>=0.0.22",
"asyncpg>=0.30",
"alembic>=1.14",
# Configuration
"pydantic-settings>=2.7",
# Cache
"redis>=5.2",
# Encryption
"cryptography>=44",
# Auth
"bcrypt>=4.0",
"python-jose[cryptography]>=3.3",
"authlib>=1.4",
"pyotp>=2.9",
"webauthn>=2.2",
"python3-saml>=1.16",
# HTTP client
"httpx>=0.28",
# Email
"aiosmtplib>=3.0",
# QR codes
"qrcode[pil]>=8.0",
# Logging
"loguru>=0.7.3",
]
[dependency-groups]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.24",
"pytest-cov>=7.1.0",
"respx>=0.22.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
asyncio_default_test_loop_scope = "session"
testpaths = ["tests"]

0
tests/__init__.py Normal file
View file

65
tests/conftest.py Normal file
View file

@ -0,0 +1,65 @@
"""Shared test fixtures — async DB session using a test database."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel
from wiregui.config import get_settings
# All models must be imported so SQLModel.metadata knows about them
from wiregui.models import * # noqa: F401, F403
def _test_database_url() -> str:
url = get_settings().database_url
base, _dbname = url.rsplit("/", 1)
return f"{base}/wiregui_test"
TEST_DATABASE_URL = _test_database_url()
# Module-level engine creation (runs once via autouse session fixture)
_engine = None
def _ensure_test_db_sync():
"""Ensure wiregui_test database exists (called once)."""
import asyncio
async def _create():
base_url = get_settings().database_url.rsplit("/", 1)[0] + "/postgres"
admin_engine = create_async_engine(base_url, isolation_level="AUTOCOMMIT")
async with admin_engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = 'wiregui_test'")
)
if result.scalar() is None:
await conn.execute(text("CREATE DATABASE wiregui_test"))
await admin_engine.dispose()
asyncio.run(_create())
# Create test DB once at import time
_ensure_test_db_sync()
@pytest_asyncio.fixture
async def session() -> AsyncGenerator[AsyncSession]:
"""Fresh engine + session per test, with table setup/teardown."""
engine = create_async_engine(TEST_DATABASE_URL)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
factory = async_sessionmaker(engine, expire_on_commit=False)
async with factory() as sess:
yield sess
await sess.rollback()
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
await engine.dispose()

161
tests/test_account.py Normal file
View file

@ -0,0 +1,161 @@
"""Tests for account functionality — password changes, API tokens, OIDC connections."""
import hashlib
from datetime import timedelta
from sqlmodel import func, select
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.models.api_token import ApiToken
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- Password change ---
async def test_password_change_flow(session):
"""Simulate the password change flow: verify old, set new."""
user = User(email="pw-change@example.com", password_hash=hash_password("old-password"))
session.add(user)
await session.flush()
# Verify old password
assert verify_password("old-password", user.password_hash) is True
# Change password
user.password_hash = hash_password("new-password")
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert verify_password("new-password", fetched.password_hash) is True
assert verify_password("old-password", fetched.password_hash) is False
async def test_password_change_wrong_current(session):
"""Wrong current password should not allow change."""
user = User(email="pw-wrong@example.com", password_hash=hash_password("correct"))
session.add(user)
await session.flush()
# Simulate check
assert verify_password("wrong", user.password_hash) is False
# --- API token management ---
async def test_create_multiple_tokens(session):
user = User(email="multi-token@example.com")
session.add(user)
await session.flush()
for _ in range(3):
_, token_hash = generate_api_token()
session.add(ApiToken(token_hash=token_hash, user_id=user.id))
await session.flush()
count = (await session.execute(
select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id)
)).scalar()
assert count == 3
async def test_token_with_expiry(session):
user = User(email="expiry-token@example.com")
session.add(user)
await session.flush()
_, token_hash = generate_api_token()
expires = utcnow() + timedelta(days=30)
token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id)
session.add(token)
await session.flush()
fetched = await session.get(ApiToken, token.id)
assert fetched.expires_at is not None
assert fetched.expires_at > utcnow()
async def test_delete_token(session):
user = User(email="del-token@example.com")
session.add(user)
await session.flush()
_, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
await session.delete(token)
await session.flush()
assert await session.get(ApiToken, token.id) is None
# --- OIDC connections ---
async def test_oidc_connection_create(session):
user = User(email="oidc-conn@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="google",
refresh_token="refresh-tok-123",
refresh_response={"access_token": "at", "token_type": "Bearer"},
refreshed_at=utcnow(),
user_id=user.id,
)
session.add(conn)
await session.flush()
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.provider == "google"
assert fetched.refresh_token == "refresh-tok-123"
assert fetched.refresh_response["access_token"] == "at"
async def test_multiple_oidc_providers(session):
user = User(email="multi-oidc@example.com")
session.add(user)
await session.flush()
for provider in ["google", "okta", "azure"]:
conn = OIDCConnection(provider=provider, user_id=user.id)
session.add(conn)
await session.flush()
count = (await session.execute(
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar()
assert count == 3
async def test_oidc_connection_update_refresh_token(session):
user = User(email="oidc-refresh@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="google",
refresh_token="old-token",
user_id=user.id,
)
session.add(conn)
await session.flush()
conn.refresh_token = "new-token"
conn.refreshed_at = utcnow()
session.add(conn)
await session.flush()
fetched = await session.get(OIDCConnection, conn.id)
assert fetched.refresh_token == "new-token"
assert fetched.refreshed_at is not None

283
tests/test_admin.py Normal file
View file

@ -0,0 +1,283 @@
"""Tests for admin functionality — user management, configuration, cascading deletes."""
import pytest
from sqlmodel import func, select
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- User CRUD ---
async def test_create_user_with_role(session):
user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin")
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.role == "admin"
assert verify_password("secret", fetched.password_hash)
async def test_update_user_email(session):
user = User(email="old@test.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
user.email = "new@test.com"
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.email == "new@test.com"
async def test_disable_user(session):
user = User(email="active@test.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
assert user.disabled_at is None
user.disabled_at = utcnow()
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.disabled_at is not None
async def test_promote_demote_user(session):
user = User(email="user@test.com", role="unprivileged")
session.add(user)
await session.flush()
assert user.role == "unprivileged"
user.role = "admin"
session.add(user)
await session.flush()
fetched = await session.get(User, user.id)
assert fetched.role == "admin"
user.role = "unprivileged"
session.add(user)
await session.flush()
assert (await session.get(User, user.id)).role == "unprivileged"
# --- Cascading delete (manual, as we do it in the admin page) ---
async def test_delete_user_cascades_devices(session):
user = User(email="cascade@test.com")
session.add(user)
await session.flush()
d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id)
d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id)
session.add_all([d1, d2])
await session.flush()
# Manually delete devices then user (matching admin page behavior)
devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all()
for d in devices:
await session.delete(d)
await session.delete(user)
await session.flush()
assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0
assert await session.get(User, user.id) is None
async def test_delete_user_cascades_rules(session):
user = User(email="rule-cascade@test.com")
session.add(user)
await session.flush()
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
session.add(rule)
await session.flush()
# Delete rules then user
rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all()
for r in rules:
await session.delete(r)
await session.delete(user)
await session.flush()
assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0
# --- Configuration singleton ---
async def test_configuration_create_and_update(session):
config = Configuration()
session.add(config)
await session.flush()
assert config.default_client_mtu == 1280
assert config.local_auth_enabled is True
config.default_client_mtu = 1400
config.local_auth_enabled = False
config.vpn_session_duration = 3600
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.default_client_mtu == 1400
assert fetched.local_auth_enabled is False
assert fetched.vpn_session_duration == 3600
async def test_configuration_oidc_providers(session):
config = Configuration()
session.add(config)
await session.flush()
assert config.openid_connect_providers == []
providers = [
{
"id": "google",
"label": "Sign in with Google",
"scope": "openid email profile",
"response_type": "code",
"client_id": "google-client-id",
"client_secret": "google-secret",
"discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration",
"auto_create_users": True,
},
{
"id": "okta",
"label": "Okta SSO",
"scope": "openid email profile",
"response_type": "code",
"client_id": "okta-client-id",
"client_secret": "okta-secret",
"discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration",
"auto_create_users": False,
},
]
config.openid_connect_providers = providers
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert len(fetched.openid_connect_providers) == 2
assert fetched.openid_connect_providers[0]["id"] == "google"
assert fetched.openid_connect_providers[1]["auto_create_users"] is False
async def test_configuration_update_client_defaults(session):
config = Configuration()
session.add(config)
await session.flush()
config.default_client_endpoint = "vpn.example.com"
config.default_client_dns = ["8.8.8.8", "8.8.4.4"]
config.default_client_allowed_ips = ["10.0.0.0/8"]
config.default_client_persistent_keepalive = 30
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.default_client_endpoint == "vpn.example.com"
assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"]
assert fetched.default_client_allowed_ips == ["10.0.0.0/8"]
assert fetched.default_client_persistent_keepalive == 30
async def test_configuration_security_toggles(session):
config = Configuration()
session.add(config)
await session.flush()
config.allow_unprivileged_device_management = False
config.allow_unprivileged_device_configuration = False
config.disable_vpn_on_oidc_error = True
session.add(config)
await session.flush()
fetched = await session.get(Configuration, config.id)
assert fetched.allow_unprivileged_device_management is False
assert fetched.allow_unprivileged_device_configuration is False
assert fetched.disable_vpn_on_oidc_error is True
# --- Device config overrides ---
async def test_device_with_custom_config(session):
user = User(email="config-user@test.com")
session.add(user)
await session.flush()
device = Device(
name="custom-config",
public_key="pk-custom-config",
user_id=user.id,
use_default_dns=False,
use_default_endpoint=False,
use_default_mtu=False,
use_default_persistent_keepalive=False,
use_default_allowed_ips=False,
dns=["8.8.8.8"],
endpoint="custom-vpn.example.com",
mtu=1400,
persistent_keepalive=15,
allowed_ips=["10.0.0.0/8", "172.16.0.0/12"],
)
session.add(device)
await session.flush()
fetched = await session.get(Device, device.id)
assert fetched.use_default_dns is False
assert fetched.dns == ["8.8.8.8"]
assert fetched.endpoint == "custom-vpn.example.com"
assert fetched.mtu == 1400
assert fetched.persistent_keepalive == 15
assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"]
async def test_device_default_flags_are_true(session):
user = User(email="defaults@test.com")
session.add(user)
await session.flush()
device = Device(name="defaults", public_key="pk-defaults", user_id=user.id)
session.add(device)
await session.flush()
fetched = await session.get(Device, device.id)
assert fetched.use_default_allowed_ips is True
assert fetched.use_default_dns is True
assert fetched.use_default_endpoint is True
assert fetched.use_default_mtu is True
assert fetched.use_default_persistent_keepalive is True
# --- User device count ---
async def test_user_device_count_query(session):
user = User(email="count-user@test.com")
session.add(user)
await session.flush()
for i in range(3):
session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id))
await session.flush()
count = (await session.execute(
select(func.count()).select_from(Device).where(Device.user_id == user.id)
)).scalar()
assert count == 3

86
tests/test_api.py Normal file
View file

@ -0,0 +1,86 @@
"""Tests for REST API endpoints and token auth."""
import hashlib
from wiregui.auth.api_token import generate_api_token, resolve_bearer_token
from wiregui.auth.passwords import hash_password
from wiregui.models.api_token import ApiToken
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- Token generation ---
def test_generate_api_token():
plaintext, token_hash = generate_api_token()
assert len(plaintext) > 20
assert token_hash == hashlib.sha256(plaintext.encode()).hexdigest()
def test_generate_api_token_unique():
t1, h1 = generate_api_token()
t2, h2 = generate_api_token()
assert t1 != t2
assert h1 != h2
# --- Token resolution ---
async def test_resolve_valid_token(session):
user = User(email="api-user@example.com", password_hash=hash_password("x"), role="admin")
session.add(user)
await session.flush()
plaintext, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is not None
assert resolved.id == user.id
async def test_resolve_invalid_token(session):
resolved = await resolve_bearer_token(session, "bogus-token")
assert resolved is None
async def test_resolve_expired_token(session):
from datetime import timedelta
user = User(email="expired-api@example.com", password_hash=hash_password("x"))
session.add(user)
await session.flush()
plaintext, token_hash = generate_api_token()
token = ApiToken(
token_hash=token_hash,
user_id=user.id,
expires_at=utcnow() - timedelta(hours=1),
)
session.add(token)
await session.flush()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None
async def test_resolve_token_disabled_user(session):
user = User(
email="disabled-api@example.com",
password_hash=hash_password("x"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
plaintext, token_hash = generate_api_token()
token = ApiToken(token_hash=token_hash, user_id=user.id)
session.add(token)
await session.flush()
resolved = await resolve_bearer_token(session, plaintext)
assert resolved is None

325
tests/test_api_routes.py Normal file
View file

@ -0,0 +1,325 @@
"""Tests for REST API routes via httpx AsyncClient against the FastAPI app."""
import hashlib
from uuid import UUID, uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlmodel import select
from wiregui.api.deps import get_current_api_user, get_db, require_admin
from wiregui.api.v0 import router as api_router
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.passwords import hash_password
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.models.user import User
def _build_app(session, admin_user=None, regular_user=None):
"""Build a test FastAPI app with overridden dependencies."""
test_app = FastAPI()
test_app.include_router(api_router, prefix="/api")
async def override_get_db():
yield session
test_app.dependency_overrides[get_db] = override_get_db
if admin_user:
test_app.dependency_overrides[get_current_api_user] = lambda: admin_user
test_app.dependency_overrides[require_admin] = lambda: admin_user
return test_app
async def _make_admin(session) -> User:
user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin")
session.add(user)
await session.flush()
return user
async def _make_user(session, email="api-user@test.com") -> User:
user = User(email=email, password_hash=hash_password("pw"), role="unprivileged")
session.add(user)
await session.flush()
return user
# ========== Users API ==========
async def test_list_users(session):
admin = await _make_admin(session)
await _make_user(session, "user1@test.com")
await _make_user(session, "user2@test.com")
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/users/")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 3 # admin + 2 users
async def test_get_user(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/users/{admin.id}")
assert resp.status_code == 200
assert resp.json()["email"] == "api-admin@test.com"
async def test_get_user_not_found(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/users/{uuid4()}")
assert resp.status_code == 404
async def test_create_user(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/api/v0/users/", json={
"email": "new-api-user@test.com",
"password": "secret123",
"role": "unprivileged",
})
assert resp.status_code == 201
data = resp.json()
assert data["email"] == "new-api-user@test.com"
assert data["role"] == "unprivileged"
assert "id" in data
async def test_update_user(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/users/{user.id}", json={
"role": "admin",
})
assert resp.status_code == 200
assert resp.json()["role"] == "admin"
async def test_update_user_password(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/users/{user.id}", json={
"password": "new-password-123",
})
assert resp.status_code == 200
from wiregui.auth.passwords import verify_password
refreshed = await session.get(User, user.id)
assert verify_password("new-password-123", refreshed.password_hash)
async def test_delete_user(session):
admin = await _make_admin(session)
user = await _make_user(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/users/{user.id}")
assert resp.status_code == 204
assert await session.get(User, user.id) is None
# ========== Devices API ==========
async def test_list_devices_admin_sees_all(session):
admin = await _make_admin(session)
user = await _make_user(session)
session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id))
session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id))
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/devices/")
assert resp.status_code == 200
assert len(resp.json()) >= 2
async def test_list_devices_user_sees_own(session):
admin = await _make_admin(session)
user = await _make_user(session, "own-devices@test.com")
session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id))
session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id))
await session.flush()
# Override to be the regular user
test_app = _build_app(session)
test_app.dependency_overrides[get_current_api_user] = lambda: user
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
resp = await client.get("/api/v0/devices/")
assert resp.status_code == 200
names = [d["name"] for d in resp.json()]
assert "mine" in names
assert "not-mine" not in names
async def test_get_device(session):
admin = await _make_admin(session)
device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5")
session.add(device)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/devices/{device.id}")
assert resp.status_code == 200
assert resp.json()["name"] == "detail"
assert resp.json()["ipv4"] == "10.0.0.5"
async def test_get_device_forbidden_for_other_user(session):
admin = await _make_admin(session)
user = await _make_user(session, "other-dev@test.com")
device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id)
session.add(device)
await session.flush()
test_app = _build_app(session)
test_app.dependency_overrides[get_current_api_user] = lambda: user
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
resp = await client.get(f"/api/v0/devices/{device.id}")
assert resp.status_code == 403
async def test_update_device(session):
admin = await _make_admin(session)
device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id)
session.add(device)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"})
assert resp.status_code == 200
assert resp.json()["name"] == "new-name"
async def test_delete_device(session):
admin = await _make_admin(session)
device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id)
session.add(device)
await session.flush()
did = device.id
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/devices/{did}")
assert resp.status_code == 204
assert await session.get(Device, did) is None
# ========== Rules API ==========
async def test_list_rules(session):
admin = await _make_admin(session)
session.add(Rule(action="accept", destination="10.0.0.0/8"))
session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id))
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/rules/")
assert resp.status_code == 200
assert len(resp.json()) >= 2
async def test_create_rule(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/api/v0/rules/", json={
"action": "accept",
"destination": "172.16.0.0/12",
"port_type": "tcp",
"port_range": "443",
})
assert resp.status_code == 201
data = resp.json()
assert data["action"] == "accept"
assert data["destination"] == "172.16.0.0/12"
assert data["port_type"] == "tcp"
assert data["port_range"] == "443"
async def test_update_rule(session):
admin = await _make_admin(session)
rule = Rule(action="accept", destination="10.0.0.0/8")
session.add(rule)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"})
assert resp.status_code == 200
assert resp.json()["action"] == "drop"
async def test_delete_rule(session):
admin = await _make_admin(session)
rule = Rule(action="drop", destination="0.0.0.0/0")
session.add(rule)
await session.flush()
rid = rule.id
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.delete(f"/api/v0/rules/{rid}")
assert resp.status_code == 204
assert await session.get(Rule, rid) is None
# ========== Configuration API ==========
async def test_get_configuration_auto_creates(session):
admin = await _make_admin(session)
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.get("/api/v0/configuration/")
assert resp.status_code == 200
data = resp.json()
assert data["default_client_mtu"] == 1280
assert data["local_auth_enabled"] is True
async def test_update_configuration(session):
admin = await _make_admin(session)
# Pre-create config
config = Configuration()
session.add(config)
await session.flush()
app = _build_app(session, admin_user=admin)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.put("/api/v0/configuration/", json={
"default_client_mtu": 1400,
"vpn_session_duration": 3600,
"default_client_dns": ["8.8.8.8"],
})
assert resp.status_code == 200
data = resp.json()
assert data["default_client_mtu"] == 1400
assert data["vpn_session_duration"] == 3600
assert data["default_client_dns"] == ["8.8.8.8"]

98
tests/test_auth.py Normal file
View file

@ -0,0 +1,98 @@
"""Tests for authentication modules."""
from sqlmodel import select
from wiregui.auth.jwt import create_access_token, decode_access_token
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.auth.seed import seed_admin
from wiregui.models.user import User
# --- Password hashing ---
def test_hash_and_verify():
hashed = hash_password("my-secret")
assert verify_password("my-secret", hashed) is True
def test_verify_wrong_password():
hashed = hash_password("correct")
assert verify_password("wrong", hashed) is False
def test_hash_is_not_plaintext():
hashed = hash_password("plaintext")
assert hashed != "plaintext"
assert hashed.startswith("$2b$")
# --- JWT ---
def test_create_and_decode_token():
token = create_access_token(user_id="user-123", role="admin")
payload = decode_access_token(token)
assert payload is not None
assert payload["sub"] == "user-123"
assert payload["role"] == "admin"
assert "exp" in payload
def test_decode_invalid_token():
assert decode_access_token("garbage.token.value") is None
def test_decode_tampered_token():
token = create_access_token(user_id="user-123", role="admin")
tampered = token[:-4] + "XXXX"
assert decode_access_token(tampered) is None
# --- Admin seed ---
async def test_seed_admin_creates_user(session, monkeypatch):
"""seed_admin should create an admin when no users exist."""
# Patch async_session to use our test session
from unittest.mock import AsyncMock
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {
"admin_email": "seed-test@example.com",
"admin_password": "seed-pass-123",
})())
await seed_admin()
result = await session.execute(select(User).where(User.email == "seed-test@example.com"))
admin = result.scalar_one()
assert admin.role == "admin"
assert verify_password("seed-pass-123", admin.password_hash)
async def test_seed_admin_skips_when_users_exist(session, monkeypatch):
"""seed_admin should not create a second admin if users already exist."""
from contextlib import asynccontextmanager
existing = User(email="existing@example.com", role="unprivileged")
session.add(existing)
await session.flush()
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
await seed_admin()
result = await session.execute(select(User))
users = result.scalars().all()
assert len(users) == 1
assert users[0].email == "existing@example.com"

226
tests/test_auth_extended.py Normal file
View file

@ -0,0 +1,226 @@
"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from wiregui.auth.passwords import hash_password
from wiregui.auth.session import authenticate_user
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# ========== Session / authenticate_user edge cases ==========
async def test_authenticate_user_no_password_hash(session, monkeypatch):
"""Users without a password (OIDC-only) should not authenticate via password."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="no-pw@test.com", password_hash=None)
session.add(user)
await session.flush()
result = await authenticate_user("no-pw@test.com", "anything")
assert result is None
async def test_authenticate_user_disabled(session, monkeypatch):
"""Disabled users should not authenticate."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow())
session.add(user)
await session.flush()
result = await authenticate_user("disabled-auth@test.com", "pw")
assert result is None
async def test_authenticate_user_nonexistent(session, monkeypatch):
"""Nonexistent email should return None."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
result = await authenticate_user("ghost@nowhere.com", "pw")
assert result is None
# ========== OIDC provider registration ==========
async def test_register_providers_from_config(session, monkeypatch):
"""register_providers should register configured OIDC providers with authlib."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
from wiregui.models.configuration import Configuration
config = Configuration(openid_connect_providers=[
{
"id": "test-reg",
"label": "Test",
"scope": "openid email",
"client_id": "cid",
"client_secret": "cs",
"discovery_document_uri": "https://idp.test/.well-known/openid-configuration",
}
])
session.add(config)
await session.flush()
with patch("wiregui.auth.oidc.oauth") as mock_oauth:
from wiregui.auth.oidc import register_providers
await register_providers()
mock_oauth.register.assert_called_once()
call_kwargs = mock_oauth.register.call_args[1]
assert call_kwargs["name"] == "test-reg"
assert call_kwargs["client_id"] == "cid"
async def test_get_client_unknown_provider():
"""get_client should raise for unregistered providers."""
import pytest
from wiregui.auth.oidc import get_client
with pytest.raises(ValueError, match="not registered"):
get_client("nonexistent-provider-xyz")
# ========== WebAuthn options ==========
def test_webauthn_registration_options(monkeypatch):
"""create_registration_options should return valid options and challenge."""
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
"external_url": "https://vpn.example.com",
})())
from wiregui.auth.webauthn import create_registration_options
user_id = uuid4()
result = create_registration_options(user_id, "user@example.com")
assert "options_json" in result
assert "challenge" in result
assert len(result["challenge"]) > 10
assert "user@example.com" in result["options_json"]
def test_webauthn_registration_options_with_excludes(monkeypatch):
"""Existing credentials should be excluded from registration options."""
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
"external_url": "https://vpn.example.com",
})())
from wiregui.auth.webauthn import create_registration_options
existing = [{"credential_id": "AQIDBA"}] # base64url of bytes [1,2,3,4]
result = create_registration_options(uuid4(), "user@example.com", existing)
assert "options_json" in result
def test_webauthn_authentication_options(monkeypatch):
"""create_authentication_options should accept credential descriptors."""
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
"external_url": "https://vpn.example.com",
})())
from wiregui.auth.webauthn import create_authentication_options
credentials = [{"credential_id": "AQIDBA"}]
result = create_authentication_options(credentials)
assert "options_json" in result
assert "challenge" in result
# ========== Events — rule update/delete with rebuild ==========
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings):
"""on_rule_updated should rebuild the user's firewall chain."""
mock_settings.return_value.wg_enabled = True
mock_fw.rebuild_all_rules = AsyncMock()
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_updated
# Need to mock the DB call inside _rebuild_user_chain
with patch("wiregui.services.events.async_session") as mock_session_factory:
mock_session = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
# Mock the select results
mock_rules_result = MagicMock()
mock_rules_result.scalars.return_value.all.return_value = []
mock_devices_result = MagicMock()
mock_devices_result.scalars.return_value.all.return_value = []
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
mock_session_factory.return_value = mock_session
rule = Rule(action="accept", destination="10.0.0.0/8", user_id="a1b2c3d4-0000-0000-0000-000000000000")
await on_rule_updated(rule)
mock_fw.rebuild_all_rules.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
async def test_on_rule_deleted_triggers_rebuild(mock_fw, mock_settings):
"""on_rule_deleted should rebuild the user's firewall chain."""
mock_settings.return_value.wg_enabled = True
mock_fw.rebuild_all_rules = AsyncMock()
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_deleted
with patch("wiregui.services.events.async_session") as mock_session_factory:
mock_session = AsyncMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
mock_rules_result = MagicMock()
mock_rules_result.scalars.return_value.all.return_value = []
mock_devices_result = MagicMock()
mock_devices_result.scalars.return_value.all.return_value = []
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
mock_session_factory.return_value = mock_session
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
await on_rule_deleted(rule)
mock_fw.rebuild_all_rules.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
async def test_on_rule_deleted_skips_when_disabled(mock_settings):
"""Rule events should be no-ops when WG is disabled."""
mock_settings.return_value.wg_enabled = False
from wiregui.models.rule import Rule
from wiregui.services.events import on_rule_deleted, on_rule_updated
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
await on_rule_updated(rule) # Should not raise
await on_rule_deleted(rule) # Should not raise

40
tests/test_firewall.py Normal file
View file

@ -0,0 +1,40 @@
"""Tests for firewall service — rule expression building and chain naming."""
from wiregui.services.firewall import _build_rule_expr, _user_chain_name
def test_user_chain_name():
uid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
name = _user_chain_name(uid)
assert name == "user_a1b2c3d4e5f6"
assert len(name) <= 30
def test_user_chain_name_deterministic():
uid = "12345678-1234-1234-1234-123456789abc"
assert _user_chain_name(uid) == _user_chain_name(uid)
def test_build_rule_expr_ipv4_accept():
expr = _build_rule_expr("10.0.0.0/8", "accept")
assert expr == "ip daddr 10.0.0.0/8 accept"
def test_build_rule_expr_ipv6_drop():
expr = _build_rule_expr("fd00::/64", "drop")
assert expr == "ip6 daddr fd00::/64 drop"
def test_build_rule_expr_with_port():
expr = _build_rule_expr("192.168.0.0/16", "accept", port_type="tcp", port_range="80-443")
assert expr == "ip daddr 192.168.0.0/16 tcp dport 80-443 accept"
def test_build_rule_expr_single_port():
expr = _build_rule_expr("10.0.0.1/32", "drop", port_type="udp", port_range="53")
assert expr == "ip daddr 10.0.0.1/32 udp dport 53 drop"
def test_build_rule_expr_no_port():
expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None)
assert expr == "ip daddr 0.0.0.0/0 accept"

View file

@ -0,0 +1,239 @@
"""Integration tests for MFA — full registration and authentication flows through the database."""
import pyotp
from sqlmodel import func, select
from wiregui.auth.mfa import generate_totp_secret, verify_totp_code
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.auth.session import authenticate_user
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
from wiregui.utils.time import utcnow
async def test_full_totp_registration_flow(session, monkeypatch):
"""End-to-end: create user → generate secret → verify code → store method → re-verify from DB."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
# Create user with password
user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123"))
session.add(user)
await session.flush()
# Step 1: Generate TOTP secret (happens in account page)
secret = generate_totp_secret()
# Step 2: User scans QR, enters code from their authenticator
totp = pyotp.TOTP(secret)
code = totp.now()
# Step 3: Verify the code is correct before saving
assert verify_totp_code(secret, code) is True
# Step 4: Save the MFA method to DB
method = MFAMethod(
name="My Authenticator",
type="totp",
payload={"secret": secret},
user_id=user.id,
)
session.add(method)
await session.flush()
# Step 5: Simulate future login — load method from DB and verify a fresh code
fetched_methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalars().all()
assert len(fetched_methods) == 1
stored_secret = fetched_methods[0].payload["secret"]
fresh_code = pyotp.TOTP(stored_secret).now()
assert verify_totp_code(stored_secret, fresh_code) is True
async def test_mfa_blocks_login_without_code(session, monkeypatch):
"""User with MFA should not be fully authenticated without completing MFA challenge."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
# Create user with MFA
user = User(email="mfa-block@example.com", password_hash=hash_password("password1"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# Password auth succeeds
authed_user = await authenticate_user("mfa-block@example.com", "password1")
assert authed_user is not None
# But MFA methods exist — login page would redirect to /mfa instead of completing login
mfa_methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == authed_user.id)
)).scalars().all()
assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa
async def test_mfa_wrong_code_rejected(session):
"""Wrong TOTP code should be rejected even if method is valid."""
user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# Load from DB and try wrong code
fetched = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar_one()
assert verify_totp_code(fetched.payload["secret"], "000000") is False
assert verify_totp_code(fetched.payload["secret"], "123456") is False
async def test_mfa_multiple_methods_any_valid_code_works(session):
"""If user has multiple TOTP methods, a valid code from any should work."""
user = User(email="mfa-multi@example.com")
session.add(user)
await session.flush()
secret1 = generate_totp_secret()
secret2 = generate_totp_secret()
session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id))
session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id))
await session.flush()
methods = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalars().all()
# Code from method 1 should verify against method 1's secret
code1 = pyotp.TOTP(secret1).now()
verified = False
for m in methods:
if verify_totp_code(m.payload["secret"], code1):
verified = True
break
assert verified is True
# Code from method 2 should also work
code2 = pyotp.TOTP(secret2).now()
verified2 = False
for m in methods:
if verify_totp_code(m.payload["secret"], code2):
verified2 = True
break
assert verified2 is True
async def test_mfa_method_last_used_tracking(session):
"""Verifying MFA should update last_used_at timestamp."""
user = User(email="mfa-tracking@example.com")
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
assert method.last_used_at is None
# Simulate successful verification and update
code = pyotp.TOTP(secret).now()
assert verify_totp_code(secret, code) is True
method.last_used_at = utcnow()
session.add(method)
await session.flush()
fetched = await session.get(MFAMethod, method.id)
assert fetched.last_used_at is not None
async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch):
"""After removing all MFA methods, user should not be redirected to MFA challenge."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(email="mfa-remove@example.com", password_hash=hash_password("pw"))
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id)
session.add(method)
await session.flush()
# MFA exists
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 1
# Delete it
await session.delete(method)
await session.flush()
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 0
# Password auth still works
authed = await authenticate_user("mfa-remove@example.com", "pw")
assert authed is not None
# No MFA methods — login flow would skip MFA challenge
mfa_check = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == authed.id)
)).scalars().all()
assert len(mfa_check) == 0
async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch):
"""Disabled user should be rejected at password stage, never reaching MFA."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
user = User(
email="mfa-disabled@example.com",
password_hash=hash_password("pw"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
secret = generate_totp_secret()
session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id))
await session.flush()
# Password auth rejects disabled user before MFA is ever checked
result = await authenticate_user("mfa-disabled@example.com", "pw")
assert result is None

View file

@ -0,0 +1,309 @@
"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""
import json
import time
from unittest.mock import patch
from uuid import uuid4
import respx
from httpx import Response
from jose import jwt
from sqlmodel import select
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
from wiregui.config import get_settings
from wiregui.models.configuration import Configuration
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
# --- Helper to create a fake OIDC provider config in the DB ---
async def _setup_oidc_config(session) -> Configuration:
"""Insert a Configuration with a test OIDC provider."""
config = Configuration(
openid_connect_providers=[
{
"id": "test-idp",
"label": "Test IdP",
"scope": "openid email profile",
"response_type": "code",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
"auto_create_users": True,
}
],
)
session.add(config)
await session.commit()
return config
def _mock_discovery():
"""Mock OIDC discovery document response."""
return {
"issuer": "https://idp.example.com",
"authorization_endpoint": "https://idp.example.com/authorize",
"token_endpoint": "https://idp.example.com/token",
"userinfo_endpoint": "https://idp.example.com/userinfo",
"jwks_uri": "https://idp.example.com/.well-known/jwks.json",
}
def _mock_token_response(email: str = "oidc-user@example.com"):
"""Mock OIDC token endpoint response with ID token."""
now = int(time.time())
id_token_payload = {
"iss": "https://idp.example.com",
"sub": "oidc-subject-123",
"aud": "test-client-id",
"email": email,
"name": "OIDC User",
"iat": now,
"exp": now + 3600,
"nonce": "test-nonce",
}
# Sign with a simple secret (in real life this would be RSA)
id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
return {
"access_token": "mock-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "mock-refresh-token",
"id_token": id_token,
}
# --- Provider config loading ---
async def test_load_providers_from_config(session, monkeypatch):
"""Providers should be loaded from the Configuration table."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
await _setup_oidc_config(session)
providers = await load_providers()
assert len(providers) == 1
assert providers[0]["id"] == "test-idp"
assert providers[0]["client_id"] == "test-client-id"
async def test_load_providers_empty_when_no_config(session, monkeypatch):
"""Should return empty list when no Configuration exists."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
providers = await load_providers()
assert providers == []
async def test_get_provider_config_by_id(session, monkeypatch):
"""Should find a specific provider by ID."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
await _setup_oidc_config(session)
config = await get_provider_config("test-idp")
assert config is not None
assert config["label"] == "Test IdP"
config_missing = await get_provider_config("nonexistent")
assert config_missing is None
# --- OIDC connection storage ---
async def test_oidc_connection_created_on_login(session):
"""Simulates what the callback route does: create user + OIDC connection."""
user = User(email="oidc-new@example.com", role="unprivileged")
session.add(user)
await session.flush()
token_data = _mock_token_response("oidc-new@example.com")
conn = OIDCConnection(
provider="test-idp",
refresh_token=token_data["refresh_token"],
refresh_response=token_data,
user_id=user.id,
)
session.add(conn)
await session.flush()
# Verify it was stored
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.provider == "test-idp"
assert fetched.refresh_token == "mock-refresh-token"
assert fetched.refresh_response["access_token"] == "mock-access-token"
async def test_oidc_connection_updated_on_re_login(session):
"""Re-login should update the existing OIDC connection, not create a duplicate."""
user = User(email="oidc-relogin@example.com")
session.add(user)
await session.flush()
# First login
conn = OIDCConnection(
provider="test-idp",
refresh_token="old-refresh-token",
user_id=user.id,
)
session.add(conn)
await session.flush()
# Re-login — update existing connection (as the callback route does)
existing = (await session.execute(
select(OIDCConnection).where(
OIDCConnection.user_id == user.id,
OIDCConnection.provider == "test-idp",
)
)).scalar_one()
existing.refresh_token = "new-refresh-token"
from wiregui.utils.time import utcnow
existing.refreshed_at = utcnow()
session.add(existing)
await session.flush()
# Should still be one connection
from sqlmodel import func
count = (await session.execute(
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar()
assert count == 1
fetched = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
)).scalar_one()
assert fetched.refresh_token == "new-refresh-token"
async def test_oidc_auto_create_user(session):
"""When auto_create_users is True, a new user should be created from OIDC email."""
email = "auto-created@example.com"
# Verify user doesn't exist
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
assert existing is None
# Simulate what callback does with auto_create
user = User(email=email, role="unprivileged")
session.add(user)
await session.flush()
from wiregui.utils.time import utcnow
user.last_signed_in_at = utcnow()
user.last_signed_in_method = "oidc:test-idp"
session.add(user)
await session.flush()
created = (await session.execute(select(User).where(User.email == email))).scalar_one()
assert created.role == "unprivileged"
assert created.last_signed_in_method == "oidc:test-idp"
async def test_oidc_disabled_user_rejected(session):
"""Disabled users should not be logged in via OIDC."""
from wiregui.utils.time import utcnow
user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
session.add(user)
await session.flush()
# The callback route checks disabled_at before creating session
assert user.disabled_at is not None # Would redirect to /login
async def test_oidc_user_without_auto_create_rejected(session):
"""When auto_create is False and user doesn't exist, login should fail."""
email = "no-auto-create@example.com"
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
assert existing is None
# The callback route checks auto_create_users from provider config
# With auto_create=False and no existing user, it would redirect to /login
# This verifies the precondition
# --- OIDC refresh token flow ---
async def test_oidc_refresh_stores_new_token(session):
"""Simulates a successful token refresh updating the connection."""
user = User(email="oidc-refresh-test@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(
provider="test-idp",
refresh_token="old-refresh",
user_id=user.id,
)
session.add(conn)
await session.flush()
# Simulate refresh result
new_token = {
"access_token": "new-access",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
conn.refresh_response = new_token
from wiregui.utils.time import utcnow
conn.refreshed_at = utcnow()
session.add(conn)
await session.flush()
fetched = await session.get(OIDCConnection, conn.id)
assert fetched.refresh_token == "new-refresh"
assert fetched.refresh_response["access_token"] == "new-access"
assert fetched.refreshed_at is not None
async def test_oidc_multiple_providers_per_user(session):
"""User can have connections to multiple OIDC providers."""
user = User(email="multi-provider@example.com")
session.add(user)
await session.flush()
for provider in ["google", "okta", "azure-ad"]:
session.add(OIDCConnection(
provider=provider,
refresh_token=f"token-{provider}",
user_id=user.id,
))
await session.flush()
conns = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider)
)).scalars().all()
assert len(conns) == 3
assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]

58
tests/test_magic_link.py Normal file
View file

@ -0,0 +1,58 @@
"""Tests for magic link authentication flow."""
from datetime import timedelta
from wiregui.auth.jwt import create_access_token, decode_access_token
from wiregui.auth.passwords import hash_password
from wiregui.models.user import User
def test_magic_link_token_creation():
"""Magic link token should be a valid JWT with short expiry."""
token = create_access_token(
user_id="user-123",
role="unprivileged",
expires_delta=timedelta(minutes=15),
)
payload = decode_access_token(token)
assert payload is not None
assert payload["sub"] == "user-123"
assert payload["role"] == "unprivileged"
def test_magic_link_token_expired():
"""Expired magic link token should be rejected."""
token = create_access_token(
user_id="user-123",
role="admin",
expires_delta=timedelta(minutes=-1), # Already expired
)
payload = decode_access_token(token)
assert payload is None
def test_magic_link_token_wrong_user():
"""Token should only be valid for the intended user."""
token = create_access_token(user_id="user-A", role="admin")
payload = decode_access_token(token)
assert payload["sub"] == "user-A"
# Caller is responsible for checking sub matches the URL user_id
async def test_magic_link_disabled_user_rejected(session):
"""Disabled users should not be able to use magic links."""
from wiregui.utils.time import utcnow
user = User(
email="disabled-magic@example.com",
password_hash=hash_password("pw"),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
# The token would be valid but the page handler checks disabled_at
token = create_access_token(user_id=str(user.id), role="unprivileged")
payload = decode_access_token(token)
assert payload is not None # Token itself is valid
assert user.disabled_at is not None # But user is disabled — handler would reject

127
tests/test_mfa.py Normal file
View file

@ -0,0 +1,127 @@
"""Tests for TOTP MFA functionality."""
import pyotp
from wiregui.auth.mfa import (
generate_totp_qr_svg,
generate_totp_secret,
get_totp_uri,
verify_totp_code,
)
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.user import User
# --- TOTP secret generation ---
def test_generate_secret():
secret = generate_totp_secret()
assert len(secret) == 32 # base32 encoded
assert secret.isalpha() or any(c.isdigit() for c in secret)
def test_generate_secret_unique():
s1 = generate_totp_secret()
s2 = generate_totp_secret()
assert s1 != s2
# --- TOTP URI ---
def test_get_totp_uri():
uri = get_totp_uri("JBSWY3DPEHPK3PXP", "user@example.com")
assert uri.startswith("otpauth://totp/")
assert "user%40example.com" in uri or "user@example.com" in uri
assert "secret=JBSWY3DPEHPK3PXP" in uri
assert "issuer=WireGUI" in uri
def test_get_totp_uri_custom_issuer():
uri = get_totp_uri("SECRET", "test@test.com", issuer="MyVPN")
assert "issuer=MyVPN" in uri
# --- TOTP verification ---
def test_verify_valid_code():
secret = generate_totp_secret()
totp = pyotp.TOTP(secret)
code = totp.now()
assert verify_totp_code(secret, code) is True
def test_verify_invalid_code():
secret = generate_totp_secret()
assert verify_totp_code(secret, "000000") is False
def test_verify_wrong_secret():
secret1 = generate_totp_secret()
secret2 = generate_totp_secret()
code = pyotp.TOTP(secret1).now()
assert verify_totp_code(secret2, code) is False
def test_verify_empty_code():
secret = generate_totp_secret()
assert verify_totp_code(secret, "") is False
# --- QR code generation ---
def test_generate_qr_svg():
uri = get_totp_uri("SECRET", "test@test.com")
svg = generate_totp_qr_svg(uri)
assert "<svg" in svg
assert "</svg>" in svg
# --- MFA method model integration ---
async def test_create_totp_method(session):
user = User(email="mfa-test@example.com")
session.add(user)
await session.flush()
secret = generate_totp_secret()
method = MFAMethod(
name="My Phone",
type="totp",
payload={"secret": secret},
user_id=user.id,
)
session.add(method)
await session.flush()
from sqlmodel import select
fetched = (await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar_one()
assert fetched.name == "My Phone"
assert fetched.type == "totp"
stored_secret = fetched.payload["secret"]
code = pyotp.TOTP(stored_secret).now()
assert verify_totp_code(stored_secret, code) is True
async def test_user_multiple_mfa_methods(session):
user = User(email="multi-mfa@example.com")
session.add(user)
await session.flush()
m1 = MFAMethod(name="Phone", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
m2 = MFAMethod(name="Backup", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
session.add_all([m1, m2])
await session.flush()
from sqlmodel import select, func
count = (await session.execute(
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
)).scalar()
assert count == 2

168
tests/test_models.py Normal file
View file

@ -0,0 +1,168 @@
"""Tests for SQLModel table definitions."""
import pytest # noqa: F401 — needed for pytest.raises
from sqlmodel import select
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.rule import Rule
from wiregui.models.user import User
async def test_create_user(session):
user = User(email="alice@example.com", role="admin")
session.add(user)
await session.flush()
result = await session.execute(select(User).where(User.email == "alice@example.com"))
fetched = result.scalar_one()
assert fetched.id == user.id
assert fetched.role == "admin"
assert fetched.disabled_at is None
async def test_create_device_with_user(session):
user = User(email="bob@example.com")
session.add(user)
await session.flush()
device = Device(
name="laptop",
public_key="pk-test-device-001",
user_id=user.id,
)
session.add(device)
await session.flush()
result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001"))
fetched = result.scalar_one()
assert fetched.name == "laptop"
assert fetched.user_id == user.id
assert fetched.use_default_dns is True
assert fetched.use_default_allowed_ips is True
assert fetched.rx_bytes is None
async def test_device_unique_public_key(session):
user = User(email="carol@example.com")
session.add(user)
await session.flush()
d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
session.add(d1)
await session.flush()
d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
session.add(d2)
with pytest.raises(Exception): # IntegrityError
await session.flush()
async def test_create_rule(session):
user = User(email="dave@example.com")
session.add(user)
await session.flush()
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
session.add(rule)
await session.flush()
result = await session.execute(select(Rule).where(Rule.user_id == user.id))
fetched = result.scalar_one()
assert fetched.action == "accept"
assert fetched.destination == "10.0.0.0/8"
assert fetched.port_type is None
assert fetched.port_range is None
async def test_create_rule_with_port(session):
rule = Rule(
action="drop",
destination="192.168.0.0/16",
port_type="tcp",
port_range="80-443",
)
session.add(rule)
await session.flush()
fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one()
assert fetched.port_type == "tcp"
assert fetched.port_range == "80-443"
assert fetched.user_id is None # global rule
async def test_create_mfa_method(session):
user = User(email="eve@example.com")
session.add(user)
await session.flush()
mfa = MFAMethod(
name="My Authenticator",
type="totp",
payload={"secret": "JBSWY3DPEHPK3PXP"},
user_id=user.id,
)
session.add(mfa)
await session.flush()
fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one()
assert fetched.type == "totp"
assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP"
async def test_create_oidc_connection(session):
user = User(email="frank@example.com")
session.add(user)
await session.flush()
conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id)
session.add(conn)
await session.flush()
fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one()
assert fetched.provider == "google"
assert fetched.refresh_token == "tok_abc"
async def test_create_api_token(session):
user = User(email="grace@example.com")
session.add(user)
await session.flush()
token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id)
session.add(token)
await session.flush()
fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one()
assert fetched.token_hash == "sha256_fake_hash"
assert fetched.expires_at is None
async def test_create_connectivity_check(session):
check = ConnectivityCheck(url="https://example.com", response_code=200)
session.add(check)
await session.flush()
fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one()
assert fetched.response_code == 200
async def test_configuration_defaults(session):
config = Configuration()
session.add(config)
await session.flush()
fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one()
assert fetched.allow_unprivileged_device_management is True
assert fetched.local_auth_enabled is True
assert fetched.default_client_mtu == 1280
assert fetched.default_client_persistent_keepalive == 25
assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"]
assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"]
assert fetched.vpn_session_duration == 0
assert fetched.openid_connect_providers == []
assert fetched.saml_identity_providers == []

View file

@ -0,0 +1,89 @@
"""Tests for the notification service."""
from wiregui.services import notifications
def setup_function():
"""Clear notifications before each test."""
notifications.clear_all()
def test_add_notification():
n = notifications.add("info", "Test message")
assert n.severity == "info"
assert n.message == "Test message"
assert n.user is None
assert n.id is not None
assert n.timestamp is not None
def test_add_notification_with_user():
n = notifications.add("error", "Something broke", user="admin@example.com")
assert n.user == "admin@example.com"
assert n.severity == "error"
def test_current_returns_newest_first():
notifications.add("info", "First")
notifications.add("warning", "Second")
notifications.add("error", "Third")
current = notifications.current()
assert len(current) == 3
assert current[0].message == "Third"
assert current[1].message == "Second"
assert current[2].message == "First"
def test_count():
assert notifications.count() == 0
notifications.add("info", "One")
notifications.add("info", "Two")
assert notifications.count() == 2
def test_clear_specific():
n1 = notifications.add("info", "Keep this")
n2 = notifications.add("error", "Remove this")
notifications.clear(n2.id)
current = notifications.current()
assert len(current) == 1
assert current[0].id == n1.id
def test_clear_nonexistent_id_is_noop():
notifications.add("info", "Test")
notifications.clear("nonexistent-id")
assert notifications.count() == 1
def test_clear_all():
notifications.add("info", "One")
notifications.add("info", "Two")
notifications.add("info", "Three")
assert notifications.count() == 3
notifications.clear_all()
assert notifications.count() == 0
assert notifications.current() == []
def test_to_dict():
n = notifications.add("warning", "Test dict", user="someone@example.com")
d = n.to_dict()
assert d["severity"] == "warning"
assert d["message"] == "Test dict"
assert d["user"] == "someone@example.com"
assert "id" in d
assert "timestamp" in d
def test_max_notifications():
"""Deque should cap at MAX_NOTIFICATIONS."""
for i in range(notifications.MAX_NOTIFICATIONS + 10):
notifications.add("info", f"Notification {i}")
assert notifications.count() == notifications.MAX_NOTIFICATIONS
# Newest should be the last one added
assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}"

124
tests/test_services.py Normal file
View file

@ -0,0 +1,124 @@
"""Tests for services — WireGuard and events."""
from unittest.mock import AsyncMock, patch
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created
def _make_device(**kwargs) -> Device:
defaults = dict(
name="test",
public_key="pk-test",
preshared_key="psk-test",
ipv4="10.3.2.5",
ipv6="fd00::3:2:5",
user_id="00000000-0000-0000-0000-000000000000",
)
defaults.update(kwargs)
return Device(**defaults)
# --- Events (with WG enabled) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
mock_fw.add_device_jump_rule = AsyncMock()
device = _make_device()
await on_device_created(device)
mock_wg.add_peer.assert_awaited_once_with(
public_key="pk-test",
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"],
preshared_key="psk-test",
)
mock_fw.add_device_jump_rule.assert_awaited_once()
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.remove_peer = AsyncMock()
device = _make_device()
await on_device_deleted(device)
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test")
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
device = _make_device()
await on_device_updated(device)
mock_wg.add_peer.assert_awaited_once()
# --- Events (WG disabled) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_events_skip_when_wg_disabled(mock_wg, mock_settings):
mock_settings.return_value.wg_enabled = False
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
device = _make_device()
await on_device_created(device)
await on_device_deleted(device)
await on_device_updated(device)
mock_wg.add_peer.assert_not_awaited()
mock_wg.remove_peer.assert_not_awaited()
# --- Events (WG error handling) ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock(side_effect=RuntimeError("wg failed"))
mock_fw.add_device_jump_rule = AsyncMock()
device = _make_device()
# Should not raise — error is logged
await on_device_created(device)
# --- Rule events ---
@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
async def test_on_rule_created_calls_apply_rule(mock_fw, mock_settings):
mock_settings.return_value.wg_enabled = True
mock_fw.apply_rule = AsyncMock()
rule = Rule(
action="accept",
destination="10.0.0.0/8",
port_type="tcp",
port_range="80",
user_id="00000000-0000-0000-0000-000000000000",
)
await on_rule_created(rule)
mock_fw.apply_rule.assert_awaited_once_with(
"00000000-0000-0000-0000-000000000000", "10.0.0.0/8", "accept", "tcp", "80",
)

View file

@ -0,0 +1,203 @@
"""Extended service tests — wireguard subprocess mocking, firewall nft mocking, email."""
from unittest.mock import AsyncMock, MagicMock, patch
from wiregui.services.wireguard import PeerInfo, add_peer, get_peers, remove_peer
# ========== WireGuard service (mocked subprocess) ==========
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_peer_without_psk(mock_run):
mock_run.return_value = ""
await add_peer("pubkey123", ["10.0.0.1/32", "fd00::1/128"], iface="wg-test")
mock_run.assert_awaited_once()
args = mock_run.call_args[0][0]
assert "wg" in args
assert "set" in args
assert "pubkey123" in args
assert "10.0.0.1/32,fd00::1/128" in args
@patch("asyncio.create_subprocess_exec")
async def test_add_peer_with_psk(mock_exec):
"""PSK path uses subprocess directly with stdin."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"", b"")
mock_proc.returncode = 0
mock_exec.return_value = mock_proc
await add_peer("pubkey456", ["10.0.0.2/32"], preshared_key="psk-data", iface="wg-test")
mock_exec.assert_awaited_once()
call_args = mock_exec.call_args[0]
assert "preshared-key" in call_args
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_remove_peer(mock_run):
mock_run.return_value = ""
await remove_peer("pubkey789", iface="wg-test")
mock_run.assert_awaited_once()
args = mock_run.call_args[0][0]
assert "remove" in args
assert "pubkey789" in args
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_get_peers_parses_dump(mock_run):
dump_output = (
"privkey\tpubkey\t51820\toff\n"
"peerkey1\t(none)\t1.2.3.4:51820\t10.0.0.1/32\t1700000000\t12345\t67890\t25\n"
"peerkey2\t(none)\t(none)\t10.0.0.2/32,fd00::2/128\t0\t0\t0\t0\n"
)
mock_run.return_value = dump_output
peers = await get_peers(iface="wg-test")
assert len(peers) == 2
assert peers[0].public_key == "peerkey1"
assert peers[0].endpoint == "1.2.3.4:51820"
assert peers[0].rx_bytes == 12345
assert peers[0].tx_bytes == 67890
assert peers[0].latest_handshake is not None
assert peers[1].public_key == "peerkey2"
assert peers[1].endpoint is None
assert peers[1].rx_bytes == 0
assert peers[1].latest_handshake is None
assert len(peers[1].allowed_ips) == 2
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_get_peers_returns_empty_on_error(mock_run):
mock_run.side_effect = RuntimeError("interface not found")
peers = await get_peers(iface="wg-test")
assert peers == []
# ========== Firewall (mocked nft) ==========
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_base_tables(mock_batch):
from wiregui.services.firewall import setup_base_tables
await setup_base_tables()
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("add table" in c for c in cmds)
assert any("forward" in c for c in cmds)
assert any("postrouting" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_user_chain(mock_batch):
from wiregui.services.firewall import add_user_chain
await add_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("user_a1b2c3d40000" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_remove_user_chain(mock_batch):
from wiregui.services.firewall import remove_user_chain
await remove_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("flush" in c for c in cmds)
assert any("delete" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule(mock_batch):
from wiregui.services.firewall import add_device_jump_rule
await add_device_jump_rule("user-id-123", "10.0.0.5", "fd00::5")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("10.0.0.5" in c and "jump" in c for c in cmds)
assert any("fd00::5" in c and "jump" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_apply_rule(mock_batch):
from wiregui.services.firewall import apply_rule
await apply_rule("user-123", "10.0.0.0/8", "accept", "tcp", "80-443")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("10.0.0.0/8" in c and "accept" in c and "tcp dport 80-443" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_rebuild_all_rules(mock_batch):
from wiregui.services.firewall import rebuild_all_rules
await rebuild_all_rules([
{
"user_id": "user-1",
"devices": [{"ipv4": "10.0.0.1", "ipv6": "fd00::1"}],
"rules": [
{"destination": "0.0.0.0/0", "action": "accept", "port_type": None, "port_range": None},
{"destination": "192.168.0.0/16", "action": "drop", "port_type": "tcp", "port_range": "22"},
],
}
])
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("flush chain" in c and "forward" in c for c in cmds)
assert any("0.0.0.0/0" in c and "accept" in c for c in cmds)
assert any("192.168.0.0/16" in c and "drop" in c for c in cmds)
assert any("10.0.0.1" in c and "jump" in c for c in cmds)
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_masquerade(mock_batch):
from wiregui.services.firewall import setup_masquerade
await setup_masquerade(iface="wg0")
mock_batch.assert_awaited_once()
cmds = mock_batch.call_args[0][0]
assert any("masquerade" in c for c in cmds)
# ========== Email service (mocked smtp) ==========
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
async def test_send_email_success(mock_send, monkeypatch):
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
"smtp_host": "smtp.test.com",
"smtp_port": 587,
"smtp_user": "user",
"smtp_password": "pass",
"smtp_from": "test@test.com",
})())
from wiregui.services.email import send_email
result = await send_email("to@test.com", "Subject", "Body")
assert result is True
mock_send.assert_awaited_once()
async def test_send_email_no_smtp_configured(monkeypatch):
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
"smtp_host": None,
})())
from wiregui.services.email import send_email
result = await send_email("to@test.com", "Subject", "Body")
assert result is False
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
async def test_send_magic_link(mock_send, monkeypatch):
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
"smtp_host": "smtp.test.com",
"smtp_port": 587,
"smtp_user": "u",
"smtp_password": "p",
"smtp_from": "noreply@test.com",
})())
from wiregui.services.email import send_magic_link
result = await send_magic_link("user@test.com", "https://app.test/magic/123/token")
assert result is True
mock_send.assert_awaited_once()

231
tests/test_tasks.py Normal file
View file

@ -0,0 +1,231 @@
"""Tests for background tasks — VPN session expiry and connectivity checks."""
from datetime import timedelta
from unittest.mock import AsyncMock, patch
from sqlmodel import select
from wiregui.auth.passwords import hash_password
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.utils.time import utcnow
# --- VPN session expiry ---
async def test_vpn_session_expiry_removes_expired_peers(session, monkeypatch):
"""Users whose session expired should have their WG peers removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
# Create config with 1-hour session duration
config = Configuration(vpn_session_duration=3600)
session.add(config)
await session.flush()
# Create a user who signed in 2 hours ago (expired)
expired_user = User(
email="expired@example.com",
password_hash=hash_password("pw"),
last_signed_in_at=utcnow() - timedelta(hours=2),
)
session.add(expired_user)
await session.flush()
device = Device(name="laptop", public_key="pk-expired", user_id=expired_user.id)
session.add(device)
await session.flush()
# Create a user who signed in 30 min ago (still valid)
active_user = User(
email="active@example.com",
password_hash=hash_password("pw"),
last_signed_in_at=utcnow() - timedelta(minutes=30),
)
session.add(active_user)
await session.flush()
active_device = Device(name="phone", public_key="pk-active", user_id=active_user.id)
session.add(active_device)
await session.flush()
# Mock WireGuard
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
# Only expired user's peer should be removed
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-expired")
async def test_vpn_session_no_expiry_when_duration_zero(session, monkeypatch):
"""When vpn_session_duration is 0 (unlimited), no peers should be removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
config = Configuration(vpn_session_duration=0)
session.add(config)
await session.flush()
user = User(
email="unlimited@example.com",
last_signed_in_at=utcnow() - timedelta(days=365),
)
session.add(user)
await session.flush()
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
mock_wg.remove_peer.assert_not_awaited()
async def test_vpn_session_no_expiry_when_no_config(session, monkeypatch):
"""When no Configuration exists, no peers should be removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
# No Configuration row at all
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
mock_wg.remove_peer.assert_not_awaited()
async def test_vpn_session_skips_disabled_users(session, monkeypatch):
"""Disabled users should be skipped even if their session is expired."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
config = Configuration(vpn_session_duration=3600)
session.add(config)
await session.flush()
user = User(
email="disabled-session@example.com",
last_signed_in_at=utcnow() - timedelta(hours=2),
disabled_at=utcnow(),
)
session.add(user)
await session.flush()
device = Device(name="d", public_key="pk-disabled-session", user_id=user.id)
session.add(device)
await session.flush()
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.vpn_session import _expire_sessions
await _expire_sessions()
mock_wg.remove_peer.assert_not_awaited()
# --- Connectivity checks ---
async def test_connectivity_check_success(session, monkeypatch):
"""Successful connectivity check should store result in DB."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
# Mock httpx to return a successful response
import httpx
class MockResponse:
status_code = 200
headers = {"content-type": "text/plain"}
text = "203.0.113.1"
class MockAsyncClient:
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def get(self, url):
return MockResponse()
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
from wiregui.tasks.connectivity import _check_connectivity
await _check_connectivity()
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
assert result.response_code == 200
assert result.response_body == "203.0.113.1"
async def test_connectivity_check_failure(session, monkeypatch):
"""Failed connectivity check should store error and create notification."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
class MockAsyncClient:
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def get(self, url):
raise ConnectionError("Network unreachable")
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
from wiregui.services import notifications
notifications.clear_all()
from wiregui.tasks.connectivity import _check_connectivity
await _check_connectivity()
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
assert result.response_code is None
assert "Network unreachable" in result.response_body
assert notifications.count() > 0
assert "connectivity" in notifications.current()[0].message.lower()

View file

@ -0,0 +1,229 @@
"""Extended task tests — stats polling, reconciliation, OIDC refresh."""
from datetime import timedelta
from unittest.mock import AsyncMock, patch
from sqlmodel import select
from wiregui.auth.passwords import hash_password
from wiregui.models.configuration import Configuration
from wiregui.models.device import Device
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.services.wireguard import PeerInfo
from wiregui.utils.time import utcnow
# ========== Stats task ==========
async def test_stats_update_from_wg_peers(session, monkeypatch):
"""Stats task should update device records from WireGuard peer data."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
user = User(email="stats-user@test.com")
session.add(user)
await session.flush()
device = Device(name="stats-dev", public_key="pk-stats-test", user_id=user.id)
session.add(device)
await session.flush()
mock_peers = [
PeerInfo(
public_key="pk-stats-test",
endpoint="1.2.3.4:51820",
rx_bytes=123456,
tx_bytes=789012,
latest_handshake=utcnow(),
)
]
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
from wiregui.tasks.stats import _update_stats
await _update_stats()
refreshed = await session.get(Device, device.id)
assert refreshed.rx_bytes == 123456
assert refreshed.tx_bytes == 789012
assert refreshed.remote_ip == "1.2.3.4"
assert refreshed.latest_handshake is not None
async def test_stats_no_peers_is_noop(session, monkeypatch):
"""No WG peers should result in no DB changes."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[])
from wiregui.tasks.stats import _update_stats
await _update_stats() # Should not raise
async def test_stats_unmatched_peer_ignored(session, monkeypatch):
"""Peers not matching any device should be ignored."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
mock_peers = [
PeerInfo(public_key="unknown-peer-key", rx_bytes=100, tx_bytes=200)
]
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
from wiregui.tasks.stats import _update_stats
await _update_stats() # Should not raise
# ========== Reconciliation task ==========
async def test_reconcile_adds_missing_peers(session, monkeypatch):
"""Devices in DB but not in WG should be added."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
user = User(email="reconcile@test.com")
session.add(user)
await session.flush()
device = Device(name="missing", public_key="pk-missing", ipv4="10.0.0.5", user_id=user.id)
session.add(device)
await session.flush()
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[]) # WG has no peers
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.reconcile import reconcile
await reconcile()
mock_wg.add_peer.assert_awaited_once()
call_kwargs = mock_wg.add_peer.call_args[1]
assert call_kwargs["public_key"] == "pk-missing"
assert "10.0.0.5/32" in call_kwargs["allowed_ips"]
mock_wg.remove_peer.assert_not_awaited()
async def test_reconcile_removes_orphaned_peers(session, monkeypatch):
"""Peers in WG but not in DB should be removed."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
# No devices in DB, but WG has a peer
orphan = PeerInfo(public_key="pk-orphan", rx_bytes=0, tx_bytes=0)
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[orphan])
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.reconcile import reconcile
await reconcile()
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-orphan")
mock_wg.add_peer.assert_not_awaited()
async def test_reconcile_in_sync(session, monkeypatch):
"""When DB and WG match, nothing should happen."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
user = User(email="in-sync@test.com")
session.add(user)
await session.flush()
device = Device(name="synced", public_key="pk-synced", user_id=user.id)
session.add(device)
await session.flush()
peer = PeerInfo(public_key="pk-synced", rx_bytes=0, tx_bytes=0)
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
mock_wg.get_peers = AsyncMock(return_value=[peer])
mock_wg.add_peer = AsyncMock()
mock_wg.remove_peer = AsyncMock()
from wiregui.tasks.reconcile import reconcile
await reconcile()
mock_wg.add_peer.assert_not_awaited()
mock_wg.remove_peer.assert_not_awaited()
# ========== OIDC refresh task ==========
async def test_oidc_refresh_no_connections_is_noop(session, monkeypatch):
"""No OIDC connections should result in no refresh attempts."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[]))
from wiregui.tasks.oidc_refresh import _refresh_all
await _refresh_all() # Should not raise
async def test_oidc_refresh_skips_unknown_provider(session, monkeypatch):
"""Connections for unknown providers should be skipped."""
from contextlib import asynccontextmanager
@asynccontextmanager
async def mock_session():
yield session
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[
{"id": "known-provider", "client_id": "cid", "client_secret": "cs", "discovery_document_uri": "https://x"}
]))
user = User(email="oidc-skip@test.com")
session.add(user)
await session.flush()
conn = OIDCConnection(provider="unknown-provider", refresh_token="tok", user_id=user.id)
session.add(conn)
await session.flush()
from wiregui.tasks.oidc_refresh import _refresh_all
await _refresh_all() # Should skip gracefully

120
tests/test_utils.py Normal file
View file

@ -0,0 +1,120 @@
"""Tests for utility modules."""
import subprocess
import pytest
from sqlmodel import select
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
from wiregui.utils.wg_conf import build_client_config
# --- IP allocation ---
async def test_allocate_ipv4_first_device(session):
user = User(email="net-test@example.com")
session.add(user)
await session.flush()
ip = await allocate_ipv4(session, "10.3.2.0/24")
assert ip.startswith("10.3.2.")
# Should not be the network (.0) or gateway (.1)
last_octet = int(ip.split(".")[-1])
assert last_octet >= 2
async def test_allocate_ipv4_skips_used(session):
user = User(email="net-skip@example.com")
session.add(user)
await session.flush()
# Exhaust a tiny /30 network (4 addresses: .0 network, .1 gateway, .2 usable, .3 broadcast)
d1 = Device(name="d1", public_key="pk-net-1", ipv4="10.99.0.2", user_id=user.id)
session.add(d1)
await session.flush()
# Only .2 was usable in a /30 — allocation should fail
with pytest.raises(ValueError, match="No available"):
await allocate_ipv4(session, "10.99.0.0/30")
async def test_allocate_ipv6(session):
user = User(email="net6-test@example.com")
session.add(user)
await session.flush()
ip = await allocate_ipv6(session, "fd00::3:2:0/120")
assert ip.startswith("fd00::3:2:")
# --- WireGuard config builder ---
def test_build_client_config():
device = Device(
name="test-device",
public_key="device-pub-key",
preshared_key="device-psk",
ipv4="10.3.2.5",
ipv6="fd00::3:2:5",
use_default_allowed_ips=True,
use_default_dns=True,
use_default_endpoint=True,
use_default_mtu=True,
use_default_persistent_keepalive=True,
user_id="00000000-0000-0000-0000-000000000000",
)
config = build_client_config(device, "PRIVATE_KEY_HERE", "SERVER_PUB_KEY")
assert "[Interface]" in config
assert "PrivateKey = PRIVATE_KEY_HERE" in config
assert "10.3.2.5/32" in config
assert "fd00::3:2:5/128" in config
assert "[Peer]" in config
assert "PublicKey = SERVER_PUB_KEY" in config
assert "PresharedKey = device-psk" in config
assert "Endpoint = " in config
def test_build_client_config_no_psk():
device = Device(
name="no-psk",
public_key="pub",
preshared_key=None,
ipv4="10.3.2.6",
ipv6=None,
use_default_allowed_ips=True,
use_default_dns=True,
use_default_endpoint=True,
use_default_mtu=True,
use_default_persistent_keepalive=True,
user_id="00000000-0000-0000-0000-000000000000",
)
config = build_client_config(device, "PRIV", "SERVPUB")
assert "PresharedKey" not in config
assert "fd00::" not in config # no ipv6
# --- Crypto (only if wg is installed) ---
def test_generate_keypair():
"""Test keypair generation — requires `wg` CLI to be installed."""
try:
subprocess.run(["wg", "--version"], capture_output=True, check=True)
except FileNotFoundError:
pytest.skip("wg CLI not installed")
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
priv, pub = generate_keypair()
assert len(priv) == 44 # base64-encoded 32 bytes
assert len(pub) == 44
psk = generate_preshared_key()
assert len(psk) == 44

2016
uv.lock generated Normal file

File diff suppressed because it is too large Load diff

0
wiregui/__init__.py Normal file
View file

0
wiregui/api/__init__.py Normal file
View file

38
wiregui/api/deps.py Normal file
View file

@ -0,0 +1,38 @@
"""Shared FastAPI dependencies for the REST API."""
from collections.abc import AsyncGenerator
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from wiregui.auth.api_token import resolve_bearer_token
from wiregui.db import async_session
from wiregui.models.user import User
async def get_db() -> AsyncGenerator[AsyncSession]:
async with async_session() as session:
yield session
async def get_current_api_user(
request: Request,
session: AsyncSession = Depends(get_db),
) -> User:
"""Extract Bearer token from Authorization header and resolve the user."""
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
token = auth[7:]
user = await resolve_bearer_token(session, token)
if user is None:
raise HTTPException(status_code=401, detail="Invalid or expired API token")
return user
async def require_admin(user: User = Depends(get_current_api_user)) -> User:
"""Require the authenticated user to be an admin."""
if user.role != "admin":
raise HTTPException(status_code=403, detail="Admin access required")
return user

View file

@ -0,0 +1,11 @@
"""v0 API router — aggregates all sub-routers."""
from fastapi import APIRouter
from wiregui.api.v0 import configuration, devices, rules, users
router = APIRouter(prefix="/v0")
router.include_router(users.router)
router.include_router(devices.router)
router.include_router(rules.router)
router.include_router(configuration.router)

View file

@ -0,0 +1,46 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from wiregui.api.deps import get_db, require_admin
from wiregui.models.configuration import Configuration
from wiregui.models.user import User
from wiregui.schemas.configuration import ConfigurationRead, ConfigurationUpdate
router = APIRouter(prefix="/configuration", tags=["configuration"])
async def _get_config(session: AsyncSession) -> Configuration:
result = await session.execute(select(Configuration).limit(1))
config = result.scalar_one_or_none()
if not config:
config = Configuration()
session.add(config)
await session.commit()
await session.refresh(config)
return config
@router.get("/", response_model=ConfigurationRead)
async def get_configuration(
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
return await _get_config(session)
@router.put("/", response_model=ConfigurationRead)
async def update_configuration(
body: ConfigurationUpdate,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
config = await _get_config(session)
for key, val in body.model_dump(exclude_unset=True).items():
setattr(config, key, val)
session.add(config)
await session.commit()
await session.refresh(config)
return config

119
wiregui/api/v0/devices.py Normal file
View file

@ -0,0 +1,119 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from wiregui.api.deps import get_current_api_user, get_db
from wiregui.config import get_settings
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.schemas.device import DeviceCreate, DeviceRead, DeviceUpdate
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
router = APIRouter(prefix="/devices", tags=["devices"])
@router.get("/", response_model=list[DeviceRead])
async def list_devices(
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_api_user),
):
if current_user.role == "admin":
result = await session.execute(select(Device).order_by(Device.inserted_at.desc()))
else:
result = await session.execute(
select(Device).where(Device.user_id == current_user.id).order_by(Device.inserted_at.desc())
)
return result.scalars().all()
@router.get("/{device_id}", response_model=DeviceRead)
async def get_device(
device_id: UUID,
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_api_user),
):
device = await session.get(Device, device_id)
if not device:
raise HTTPException(404, "Device not found")
if current_user.role != "admin" and device.user_id != current_user.id:
raise HTTPException(403, "Access denied")
return device
@router.post("/", response_model=DeviceRead, status_code=201)
async def create_device(
body: DeviceCreate,
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_api_user),
):
settings = get_settings()
owner_id = body.user_id if (body.user_id and current_user.role == "admin") else current_user.id
_private_key, public_key = generate_keypair()
psk = generate_preshared_key()
ipv4 = await allocate_ipv4(session, settings.wg_ipv4_network)
ipv6 = await allocate_ipv6(session, settings.wg_ipv6_network)
device = Device(
name=body.name,
description=body.description,
public_key=public_key,
preshared_key=psk,
ipv4=ipv4,
ipv6=ipv6,
user_id=owner_id,
)
session.add(device)
await session.commit()
await session.refresh(device)
logger.info("API: device created {} ({})", device.name, device.ipv4)
await on_device_created(device)
return device
@router.put("/{device_id}", response_model=DeviceRead)
async def update_device(
device_id: UUID,
body: DeviceUpdate,
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_api_user),
):
device = await session.get(Device, device_id)
if not device:
raise HTTPException(404, "Device not found")
if current_user.role != "admin" and device.user_id != current_user.id:
raise HTTPException(403, "Access denied")
for key, val in body.model_dump(exclude_unset=True).items():
setattr(device, key, val)
session.add(device)
await session.commit()
await session.refresh(device)
await on_device_updated(device)
return device
@router.delete("/{device_id}", status_code=204)
async def delete_device(
device_id: UUID,
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_api_user),
):
device = await session.get(Device, device_id)
if not device:
raise HTTPException(404, "Device not found")
if current_user.role != "admin" and device.user_id != current_user.id:
raise HTTPException(403, "Access denied")
await session.delete(device)
await session.commit()
logger.info("API: device deleted {}", device.name)
await on_device_deleted(device)

86
wiregui/api/v0/rules.py Normal file
View file

@ -0,0 +1,86 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from wiregui.api.deps import get_db, require_admin
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.schemas.rule import RuleCreate, RuleRead, RuleUpdate
from wiregui.services.events import on_rule_created, on_rule_deleted
router = APIRouter(prefix="/rules", tags=["rules"])
@router.get("/", response_model=list[RuleRead])
async def list_rules(
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
result = await session.execute(select(Rule).order_by(Rule.inserted_at.desc()))
return result.scalars().all()
@router.get("/{rule_id}", response_model=RuleRead)
async def get_rule(
rule_id: UUID,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
rule = await session.get(Rule, rule_id)
if not rule:
raise HTTPException(404, "Rule not found")
return rule
@router.post("/", response_model=RuleRead, status_code=201)
async def create_rule(
body: RuleCreate,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
rule = Rule(**body.model_dump())
session.add(rule)
await session.commit()
await session.refresh(rule)
logger.info("API: rule created {} -> {}", rule.action, rule.destination)
await on_rule_created(rule)
return rule
@router.put("/{rule_id}", response_model=RuleRead)
async def update_rule(
rule_id: UUID,
body: RuleUpdate,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
rule = await session.get(Rule, rule_id)
if not rule:
raise HTTPException(404, "Rule not found")
for key, val in body.model_dump(exclude_unset=True).items():
setattr(rule, key, val)
session.add(rule)
await session.commit()
await session.refresh(rule)
return rule
@router.delete("/{rule_id}", status_code=204)
async def delete_rule(
rule_id: UUID,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
rule = await session.get(Rule, rule_id)
if not rule:
raise HTTPException(404, "Rule not found")
await session.delete(rule)
await session.commit()
logger.info("API: rule deleted {} {}", rule.action, rule.destination)
await on_rule_deleted(rule)

86
wiregui/api/v0/users.py Normal file
View file

@ -0,0 +1,86 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from wiregui.api.deps import get_db, require_admin
from wiregui.auth.passwords import hash_password
from wiregui.models.user import User
from wiregui.schemas.user import UserCreate, UserRead, UserUpdate
router = APIRouter(prefix="/users", tags=["users"])
@router.get("/", response_model=list[UserRead])
async def list_users(
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
result = await session.execute(select(User).order_by(User.email))
return result.scalars().all()
@router.get("/{user_id}", response_model=UserRead)
async def get_user(
user_id: UUID,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
user = await session.get(User, user_id)
if not user:
raise HTTPException(404, "User not found")
return user
@router.post("/", response_model=UserRead, status_code=201)
async def create_user(
body: UserCreate,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
user = User(
email=body.email,
password_hash=hash_password(body.password),
role=body.role,
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
@router.put("/{user_id}", response_model=UserRead)
async def update_user(
user_id: UUID,
body: UserUpdate,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
user = await session.get(User, user_id)
if not user:
raise HTTPException(404, "User not found")
updates = body.model_dump(exclude_unset=True)
if "password" in updates:
updates["password_hash"] = hash_password(updates.pop("password"))
for key, val in updates.items():
setattr(user, key, val)
session.add(user)
await session.commit()
await session.refresh(user)
return user
@router.delete("/{user_id}", status_code=204)
async def delete_user(
user_id: UUID,
session: AsyncSession = Depends(get_db),
_admin: User = Depends(require_admin),
):
user = await session.get(User, user_id)
if not user:
raise HTTPException(404, "User not found")
await session.delete(user)
await session.commit()

0
wiregui/auth/__init__.py Normal file
View file

42
wiregui/auth/api_token.py Normal file
View file

@ -0,0 +1,42 @@
"""API token authentication — Bearer token via Authorization header."""
import hashlib
import secrets
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from wiregui.models.api_token import ApiToken
from wiregui.models.user import User
from wiregui.utils.time import utcnow
def generate_api_token() -> tuple[str, str]:
"""Generate a new API token. Returns (plaintext_token, token_hash)."""
plaintext = secrets.token_urlsafe(32)
token_hash = hashlib.sha256(plaintext.encode()).hexdigest()
return plaintext, token_hash
async def resolve_bearer_token(session: AsyncSession, token: str) -> User | None:
"""Look up a Bearer token and return the associated user, or None."""
token_hash = hashlib.sha256(token.encode()).hexdigest()
result = await session.execute(
select(ApiToken).where(ApiToken.token_hash == token_hash)
)
api_token = result.scalar_one_or_none()
if api_token is None:
return None
# Check expiry
if api_token.expires_at and api_token.expires_at < utcnow():
logger.debug("API token expired for user_id={}", api_token.user_id)
return None
# Resolve user
user = await session.get(User, api_token.user_id)
if user is None or user.disabled_at is not None:
return None
return user

26
wiregui/auth/jwt.py Normal file
View file

@ -0,0 +1,26 @@
from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from wiregui.config import get_settings
ALGORITHM = "HS256"
DEFAULT_EXPIRE_HOURS = 8
def create_access_token(
user_id: str,
role: str,
expires_delta: timedelta | None = None,
) -> str:
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(hours=DEFAULT_EXPIRE_HOURS))
payload = {"sub": user_id, "role": role, "exp": expire}
return jwt.encode(payload, get_settings().secret_key, algorithm=ALGORITHM)
def decode_access_token(token: str) -> dict | None:
"""Decode and validate a JWT. Returns the payload dict or None if invalid/expired."""
try:
return jwt.decode(token, get_settings().secret_key, algorithms=[ALGORITHM])
except JWTError:
return None

31
wiregui/auth/mfa.py Normal file
View file

@ -0,0 +1,31 @@
"""TOTP Multi-Factor Authentication using pyotp."""
import io
from urllib.parse import quote
import pyotp
import qrcode
import qrcode.image.svg
def generate_totp_secret() -> str:
"""Generate a new random TOTP secret."""
return pyotp.random_base32()
def get_totp_uri(secret: str, email: str, issuer: str = "WireGUI") -> str:
"""Build an otpauth:// URI for QR code scanning."""
return pyotp.TOTP(secret).provisioning_uri(name=email, issuer_name=issuer)
def verify_totp_code(secret: str, code: str) -> bool:
"""Verify a TOTP code against a secret. Allows 1 window of clock drift."""
return pyotp.TOTP(secret).verify(code, valid_window=1)
def generate_totp_qr_svg(uri: str) -> str:
"""Generate an SVG QR code for a TOTP provisioning URI."""
qr = qrcode.make(uri, image_factory=qrcode.image.svg.SvgPathImage)
buf = io.BytesIO()
qr.save(buf)
return buf.getvalue().decode()

View file

@ -0,0 +1,20 @@
"""NiceGUI auth middleware — redirects unauthenticated requests to /login."""
from fastapi import Request
from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
# Paths that don't require authentication
PUBLIC_PREFIXES = ("/login", "/_nicegui", "/api")
class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if any(request.url.path.startswith(p) for p in PUBLIC_PREFIXES):
return await call_next(request)
# NiceGUI stores auth state in the Starlette session (cookie-backed)
if not request.session.get("authenticated"):
return RedirectResponse(url="/login")
return await call_next(request)

59
wiregui/auth/oidc.py Normal file
View file

@ -0,0 +1,59 @@
"""OIDC authentication via authlib — provider registry and authorization code flow."""
from authlib.integrations.starlette_client import OAuth
from loguru import logger
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
# Global OAuth instance — providers are registered dynamically
oauth = OAuth()
async def load_providers() -> list[dict]:
"""Load OIDC provider configs from the Configuration singleton."""
from sqlmodel import select
async with async_session() as session:
result = await session.execute(select(Configuration).limit(1))
config = result.scalar_one_or_none()
if not config:
return []
return config.openid_connect_providers or []
async def register_providers() -> None:
"""Register all configured OIDC providers with authlib. Call on startup."""
providers = await load_providers()
for p in providers:
provider_id = p.get("id")
if not provider_id:
continue
try:
oauth.register(
name=provider_id,
client_id=p.get("client_id"),
client_secret=p.get("client_secret"),
server_metadata_url=p.get("discovery_document_uri"),
client_kwargs={"scope": p.get("scope", "openid email profile")},
)
logger.info("OIDC provider registered: {}", provider_id)
except Exception as e:
logger.error("Failed to register OIDC provider {}: {}", provider_id, e)
def get_client(provider_id: str):
"""Get an authlib OAuth client for a registered provider."""
client = oauth.create_client(provider_id)
if client is None:
raise ValueError(f"OIDC provider '{provider_id}' is not registered")
return client
async def get_provider_config(provider_id: str) -> dict | None:
"""Get the config dict for a specific provider."""
providers = await load_providers()
for p in providers:
if p.get("id") == provider_id:
return p
return None

View file

@ -0,0 +1,9 @@
import bcrypt
def hash_password(plain: str) -> str:
return bcrypt.hashpw(plain.encode(), bcrypt.gensalt()).decode()
def verify_password(plain: str, hashed: str) -> bool:
return bcrypt.checkpw(plain.encode(), hashed.encode())

114
wiregui/auth/saml.py Normal file
View file

@ -0,0 +1,114 @@
"""SAML SP-initiated SSO via python3-saml."""
from loguru import logger
from onelogin.saml2.auth import OneLogin_Saml2_Auth
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
from wiregui.config import get_settings
def _build_saml_settings(provider_config: dict) -> dict:
"""Build python3-saml settings dict from our provider config."""
settings = get_settings()
base_url = settings.external_url
# Parse IdP metadata XML to extract endpoints and certs
idp_data = OneLogin_Saml2_IdPMetadataParser.parse(provider_config.get("metadata", ""))
idp_settings = idp_data.get("idp", {})
return {
"strict": True,
"debug": False,
"sp": {
"entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata",
"assertionConsumerService": {
"url": f"{base_url}/auth/saml/{provider_config['id']}/callback",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
},
"NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
},
"idp": idp_settings,
"security": {
"authnRequestsSigned": provider_config.get("sign_requests", True),
"wantAssertionsSigned": provider_config.get("signed_assertion_in_resp", True),
"signMetadata": provider_config.get("sign_metadata", True),
},
}
def prepare_saml_request(request_data: dict) -> dict:
"""Prepare a dict that python3-saml expects from an HTTP request.
Args:
request_data: dict with keys: http_host, script_name, server_port,
get_data (dict), post_data (dict), https (str "on"/"off")
"""
return {
"http_host": request_data.get("http_host", "localhost"),
"script_name": request_data.get("script_name", ""),
"server_port": request_data.get("server_port", 443),
"get_data": request_data.get("get_data", {}),
"post_data": request_data.get("post_data", {}),
"https": request_data.get("https", "on"),
}
def create_saml_auth(provider_config: dict, request_data: dict) -> OneLogin_Saml2_Auth:
"""Create a python3-saml Auth instance for a provider."""
saml_settings = _build_saml_settings(provider_config)
req = prepare_saml_request(request_data)
return OneLogin_Saml2_Auth(req, saml_settings)
def get_login_url(auth: OneLogin_Saml2_Auth) -> str:
"""Get the SSO redirect URL."""
return auth.login()
def process_response(auth: OneLogin_Saml2_Auth) -> dict | None:
"""Process the SAML response and return user attributes.
Returns dict with 'email' key, or None on failure.
"""
auth.process_response()
errors = auth.get_errors()
if errors:
logger.error("SAML response errors: {}", errors)
return None
if not auth.is_authenticated():
logger.warning("SAML: user not authenticated")
return None
attrs = auth.get_attributes()
name_id = auth.get_nameid()
# Try to extract email from various attribute names
email = (
attrs.get("email", [None])[0]
or attrs.get("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", [None])[0]
or attrs.get("urn:oid:0.9.2342.19200300.100.1.3", [None])[0]
or name_id
)
if not email:
logger.error("SAML: no email found in attributes or NameID")
return None
return {
"email": email,
"name_id": name_id,
"attributes": {k: v for k, v in attrs.items()},
}
def get_metadata(provider_config: dict) -> str:
"""Generate SP metadata XML."""
settings = _build_saml_settings(provider_config)
from onelogin.saml2.settings import OneLogin_Saml2_Settings
saml_settings = OneLogin_Saml2_Settings(settings, sp_validation_only=True)
metadata = saml_settings.get_sp_metadata()
errors = saml_settings.validate_metadata(metadata)
if errors:
logger.error("SP metadata validation errors: {}", errors)
return metadata.decode() if isinstance(metadata, bytes) else metadata

61
wiregui/auth/seed.py Normal file
View file

@ -0,0 +1,61 @@
"""Seed the initial admin user and server keypair on first startup."""
import secrets
from loguru import logger
from sqlmodel import select
from wiregui.auth.passwords import hash_password
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.models.user import User
async def seed_admin() -> None:
"""Create admin user if no users exist in the database."""
async with async_session() as session:
result = await session.execute(select(User).limit(1))
if result.scalar_one_or_none() is not None:
return # users already exist
settings = get_settings()
password = settings.admin_password or secrets.token_urlsafe(16)
admin = User(
email=settings.admin_email,
password_hash=hash_password(password),
role="admin",
)
session.add(admin)
await session.commit()
logger.info("Admin user created: {}", settings.admin_email)
if settings.admin_password is None:
logger.warning("Generated admin password: {}", password)
async def ensure_server_keypair() -> None:
"""Generate and store the server WireGuard keypair in Configuration if missing."""
from wiregui.utils.crypto import generate_keypair
async with async_session() as session:
result = await session.execute(select(Configuration).limit(1))
config = result.scalar_one_or_none()
if config is None:
config = Configuration()
session.add(config)
if config.server_public_key and config.server_private_key:
return # already have keys
try:
private_key, public_key = generate_keypair()
config.server_private_key = private_key
config.server_public_key = public_key
session.add(config)
await session.commit()
logger.info("Server WireGuard keypair generated (pubkey: {}...)", public_key[:20])
except Exception as e:
logger.warning("Could not generate server keypair (wg CLI not available?): {}", e)

22
wiregui/auth/session.py Normal file
View file

@ -0,0 +1,22 @@
"""Authentication helpers for NiceGUI pages."""
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.user import User
async def authenticate_user(email: str, password: str) -> User | None:
"""Verify email/password and return the User if valid, else None."""
from wiregui.auth.passwords import verify_password
async with async_session() as session:
stmt = select(User).where(User.email == email)
user = (await session.execute(stmt)).scalar_one_or_none()
if user is None or user.password_hash is None:
return None
if not verify_password(password, user.password_hash):
return None
if user.disabled_at is not None:
return None
return user

134
wiregui/auth/webauthn.py Normal file
View file

@ -0,0 +1,134 @@
"""WebAuthn (FIDO2) MFA via the webauthn library.
Registration and authentication ceremonies for platform (native) and
cross-platform (portable/security key) authenticators.
"""
import json
from uuid import UUID
from loguru import logger
from webauthn import (
generate_authentication_options,
generate_registration_options,
verify_authentication_response,
verify_registration_response,
)
from webauthn.helpers import bytes_to_base64url, base64url_to_bytes
from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria,
PublicKeyCredentialDescriptor,
ResidentKeyRequirement,
UserVerificationRequirement,
)
from wiregui.config import get_settings
def _rp_id() -> str:
"""Get the Relying Party ID from the external URL hostname."""
from urllib.parse import urlparse
return urlparse(get_settings().external_url).hostname
def _rp_name() -> str:
return "WireGUI"
def _origin() -> str:
return get_settings().external_url
def create_registration_options(user_id: UUID, user_email: str, existing_credentials: list[dict] = None) -> dict:
"""Generate WebAuthn registration options to send to the browser.
Returns a dict with 'options' (serialized for JSON) and 'challenge' (to store in session).
"""
exclude = []
for cred in (existing_credentials or []):
cred_id = cred.get("credential_id")
if cred_id:
exclude.append(PublicKeyCredentialDescriptor(id=base64url_to_bytes(cred_id)))
options = generate_registration_options(
rp_id=_rp_id(),
rp_name=_rp_name(),
user_id=str(user_id).encode(),
user_name=user_email,
user_display_name=user_email,
exclude_credentials=exclude,
authenticator_selection=AuthenticatorSelectionCriteria(
resident_key=ResidentKeyRequirement.PREFERRED,
user_verification=UserVerificationRequirement.PREFERRED,
),
)
# Serialize for JSON transport
from webauthn.helpers import options_to_json
return {
"options_json": options_to_json(options),
"challenge": bytes_to_base64url(options.challenge),
}
def verify_registration(credential_json: str, challenge: str) -> dict:
"""Verify a WebAuthn registration response from the browser.
Returns a dict with credential data to store in MFAMethod.payload.
"""
verification = verify_registration_response(
credential=credential_json,
expected_challenge=base64url_to_bytes(challenge),
expected_rp_id=_rp_id(),
expected_origin=_origin(),
)
return {
"credential_id": bytes_to_base64url(verification.credential_id),
"public_key": bytes_to_base64url(verification.credential_public_key),
"sign_count": verification.sign_count,
"attestation_type": verification.fmt if hasattr(verification, 'fmt') else "none",
}
def create_authentication_options(credentials: list[dict]) -> dict:
"""Generate WebAuthn authentication options for existing credentials.
Returns a dict with 'options' (serialized) and 'challenge' (to store in session).
"""
allow = []
for cred in credentials:
cred_id = cred.get("credential_id")
if cred_id:
allow.append(PublicKeyCredentialDescriptor(id=base64url_to_bytes(cred_id)))
options = generate_authentication_options(
rp_id=_rp_id(),
allow_credentials=allow,
user_verification=UserVerificationRequirement.PREFERRED,
)
from webauthn.helpers import options_to_json
return {
"options_json": options_to_json(options),
"challenge": bytes_to_base64url(options.challenge),
}
def verify_authentication(credential_json: str, challenge: str, stored_credential: dict) -> dict:
"""Verify a WebAuthn authentication response.
Returns updated credential data (new sign_count).
"""
verification = verify_authentication_response(
credential=credential_json,
expected_challenge=base64url_to_bytes(challenge),
expected_rp_id=_rp_id(),
expected_origin=_origin(),
credential_public_key=base64url_to_bytes(stored_credential["public_key"]),
credential_current_sign_count=stored_credential.get("sign_count", 0),
)
return {
"new_sign_count": verification.new_sign_count,
}

55
wiregui/config.py Normal file
View file

@ -0,0 +1,55 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="WG_", env_file=".env")
# Database
database_url: str = "postgresql+asyncpg://wiregui:wiregui@localhost/wiregui"
# Redis / Valkey
redis_url: str = "redis://localhost:6379/0"
# Secret key for JWT signing and Fernet encryption
secret_key: str = "change-me-in-production"
# WireGuard
wg_enabled: bool = False # set True in production (requires NET_ADMIN capability)
wg_interface: str = "wg0"
wg_endpoint_host: str = "localhost"
wg_endpoint_port: int = 51820
wg_ipv4_network: str = "10.3.2.0/24"
wg_ipv6_network: str = "fd00::3:2:0/120"
wg_dns: str = "1.1.1.1, 1.0.0.1"
wg_mtu: int = 1280
wg_persistent_keepalive: int = 25
wg_allowed_ips: str = "0.0.0.0/0, ::/0"
# Auth
admin_email: str = "admin@localhost"
admin_password: str | None = None
local_auth_enabled: bool = True
magic_link_enabled: bool = True
vpn_session_duration: int = 0 # seconds, 0 = unlimited
# SMTP
smtp_host: str | None = None
smtp_port: int = 587
smtp_user: str | None = None
smtp_password: str | None = None
smtp_from: str = "wiregui@localhost"
# Logging
log_to_file: bool = True # write timestamped log file to logs/ directory
# App
host: str = "0.0.0.0"
port: int = 13000
external_url: str = "http://localhost:13000"
@lru_cache
def get_settings() -> Settings:
return Settings()

22
wiregui/db.py Normal file
View file

@ -0,0 +1,22 @@
from collections.abc import AsyncGenerator
from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from wiregui.config import get_settings
engine = create_async_engine(get_settings().database_url)
async_session = async_sessionmaker(engine, expire_on_commit=False)
async def get_session() -> AsyncGenerator[AsyncSession]:
async with async_session() as session:
yield session
async def init_db() -> None:
"""Test database connectivity."""
async with engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.info("Database connection OK")

28
wiregui/logging.py Normal file
View file

@ -0,0 +1,28 @@
"""Loguru configuration for WireGUI."""
import sys
from datetime import datetime
from loguru import logger
def setup_logging(log_to_file: bool = False) -> None:
"""Configure loguru sinks. Call once at startup."""
# Remove default stderr sink and re-add with our format
logger.remove()
logger.add(
sys.stderr,
format="<green>{time:HH:mm:ss}</green> | <level>{level:<7}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan> - <level>{message}</level>",
level="DEBUG",
)
if log_to_file:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logger.add(
f"logs/wiregui_{timestamp}.log",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level:<7} | {name}:{function}:{line} - {message}",
level="DEBUG",
rotation="10 MB",
retention="7 days",
)
logger.info("File logging enabled: logs/wiregui_{}.log", timestamp)

95
wiregui/main.py Normal file
View file

@ -0,0 +1,95 @@
from loguru import logger
from nicegui import app, ui
from wiregui.api.v0 import router as api_router
from wiregui.auth.seed import ensure_server_keypair, seed_admin
from wiregui.config import get_settings
from wiregui.db import init_db
from wiregui.logging import setup_logging
# Mount REST API
app.include_router(api_router, prefix="/api")
@app.get("/api/health")
async def health():
return {"status": "ok"}
# Import pages so their @ui.page decorators register routes
import wiregui.pages.account # noqa: F401
import wiregui.pages.admin.devices # noqa: F401
import wiregui.pages.admin.diagnostics # noqa: F401
import wiregui.pages.admin.rules # noqa: F401
import wiregui.pages.admin.settings # noqa: F401
import wiregui.pages.admin.users # noqa: F401
import wiregui.pages.auth_magic # noqa: F401
import wiregui.pages.auth_oidc # noqa: F401
import wiregui.pages.auth_saml # noqa: F401
import wiregui.pages.devices # noqa: F401
import wiregui.pages.home # noqa: F401
import wiregui.pages.login # noqa: F401
import wiregui.pages.mfa_challenge # noqa: F401
async def startup() -> None:
settings = get_settings()
setup_logging(log_to_file=settings.log_to_file)
await init_db()
await seed_admin()
await ensure_server_keypair()
# Register OIDC providers from config
from wiregui.auth.oidc import register_providers
await register_providers()
from wiregui.tasks import register_task
from wiregui.tasks.oidc_refresh import oidc_refresh_loop
from wiregui.tasks.connectivity import connectivity_loop
from wiregui.tasks.vpn_session import vpn_session_loop
# Always run these tasks (even without WG for OIDC refresh and connectivity)
register_task(oidc_refresh_loop(), name="oidc-refresh")
register_task(connectivity_loop(), name="connectivity-check")
if settings.wg_enabled:
from wiregui.services.firewall import setup_base_tables, setup_masquerade
from wiregui.services.wireguard import configure_interface, ensure_interface
from wiregui.tasks.reconcile import reconcile
from wiregui.tasks.stats import stats_loop
await ensure_interface()
await configure_interface()
await setup_base_tables()
await setup_masquerade()
await reconcile()
register_task(stats_loop(), name="wg-stats")
register_task(vpn_session_loop(), name="vpn-session-expiry")
else:
logger.info("WireGuard disabled (WG_WG_ENABLED=false) — running in UI-only mode")
logger.info("WireGUI ready")
async def shutdown() -> None:
from wiregui.tasks import cancel_all
await cancel_all()
app.on_startup(startup)
app.on_shutdown(shutdown)
def main() -> None:
settings = get_settings()
ui.run(
host=settings.host,
port=settings.port,
title="WireGUI",
storage_secret=settings.secret_key,
reload=True,
)
if __name__ in {"__main__", "__mp_main__"}:
main()

View file

@ -0,0 +1,21 @@
"""All SQLModel table models — imported here so Alembic autogenerate can discover them."""
from wiregui.models.api_token import ApiToken
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.rule import Rule
from wiregui.models.user import User
__all__ = [
"ApiToken",
"Configuration",
"ConnectivityCheck",
"Device",
"MFAMethod",
"OIDCConnection",
"Rule",
"User",
]

View file

@ -0,0 +1,24 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, Relationship, SQLModel
class ApiToken(SQLModel, table=True):
__tablename__ = "api_tokens"
id: UUID = Field(default_factory=uuid4, primary_key=True)
token_hash: str = Field(unique=True, index=True)
expires_at: datetime | None = None
user_id: UUID = Field(foreign_key="users.id", index=True)
inserted_at: datetime = Field(default_factory=utcnow)
# Relationships
user: "User" = Relationship(back_populates="api_tokens")
from wiregui.models.user import User # noqa: E402, F401

View file

@ -0,0 +1,61 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, JSON, Column, SQLModel
class Configuration(SQLModel, table=True):
__tablename__ = "configurations"
id: UUID = Field(default_factory=uuid4, primary_key=True)
# Device management permissions
allow_unprivileged_device_management: bool = Field(default=True)
allow_unprivileged_device_configuration: bool = Field(default=True)
# Auth
local_auth_enabled: bool = Field(default=True)
disable_vpn_on_oidc_error: bool = Field(default=False)
# Client defaults
default_client_persistent_keepalive: int = Field(default=25)
default_client_mtu: int = Field(default=1280)
default_client_endpoint: str | None = None
default_client_dns: list[str] = Field(
default_factory=lambda: ["1.1.1.1", "1.0.0.1"],
sa_column=Column(JSON, default=["1.1.1.1", "1.0.0.1"]),
)
default_client_allowed_ips: list[str] = Field(
default_factory=lambda: ["0.0.0.0/0", "::/0"],
sa_column=Column(JSON, default=["0.0.0.0/0", "::/0"]),
)
# Server WireGuard keypair (generated on first startup)
server_private_key: str | None = None
server_public_key: str | None = None
# VPN session
vpn_session_duration: int = Field(default=0) # seconds, 0 = unlimited
# Logo
logo_url: str | None = None
logo_type: str | None = None # "url" | "file" | "upload" | None (default)
# OIDC providers (list of dicts)
# Each: {id, label, scope, response_type, client_id, client_secret,
# discovery_document_uri, redirect_uri, auto_create_users}
openid_connect_providers: list[dict] = Field(
default_factory=list, sa_column=Column(JSON, default=[])
)
# SAML identity providers (list of dicts)
# Each: {id, label, base_url, metadata, sign_requests, sign_metadata,
# signed_assertion_in_resp, signed_envelopes_in_resp, auto_create_users}
saml_identity_providers: list[dict] = Field(
default_factory=list, sa_column=Column(JSON, default=[])
)
inserted_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)

View file

@ -0,0 +1,18 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, JSON, Column, SQLModel
class ConnectivityCheck(SQLModel, table=True):
__tablename__ = "connectivity_checks"
id: UUID = Field(default_factory=uuid4, primary_key=True)
url: str
response_code: int | None = None
response_headers: dict | None = Field(default=None, sa_column=Column(JSON))
response_body: str | None = None
inserted_at: datetime = Field(default_factory=utcnow)

52
wiregui/models/device.py Normal file
View file

@ -0,0 +1,52 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, JSON, Column, Relationship, SQLModel
class Device(SQLModel, table=True):
__tablename__ = "devices"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
description: str | None = None
public_key: str = Field(unique=True, index=True)
preshared_key: str | None = None # encrypted at application level
# Client config: use server defaults or per-device overrides
use_default_allowed_ips: bool = Field(default=True)
use_default_dns: bool = Field(default=True)
use_default_endpoint: bool = Field(default=True)
use_default_mtu: bool = Field(default=True)
use_default_persistent_keepalive: bool = Field(default=True)
# Per-device overrides (used when use_default_* is False)
endpoint: str | None = None
mtu: int | None = None
persistent_keepalive: int | None = None
allowed_ips: list[str] = Field(default_factory=list, sa_column=Column(JSON, default=[]))
dns: list[str] = Field(default_factory=list, sa_column=Column(JSON, default=[]))
# Assigned tunnel addresses
ipv4: str | None = Field(default=None, unique=True)
ipv6: str | None = Field(default=None, unique=True)
# Peer stats (updated periodically from WireGuard)
remote_ip: str | None = None
rx_bytes: int | None = None
tx_bytes: int | None = None
latest_handshake: datetime | None = None
user_id: UUID = Field(foreign_key="users.id", index=True)
inserted_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)
# Relationships
user: "User" = Relationship(back_populates="devices")
from wiregui.models.user import User # noqa: E402, F401

View file

@ -0,0 +1,27 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, JSON, Column, Relationship, SQLModel
class MFAMethod(SQLModel, table=True):
__tablename__ = "mfa_methods"
id: UUID = Field(default_factory=uuid4, primary_key=True)
name: str
type: str # "totp" | "native" | "portable"
payload: dict = Field(default_factory=dict, sa_column=Column(JSON)) # encrypted at app level
last_used_at: datetime | None = None
user_id: UUID = Field(foreign_key="users.id", index=True)
inserted_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)
# Relationships
user: "User" = Relationship(back_populates="mfa_methods")
from wiregui.models.user import User # noqa: E402, F401

View file

@ -0,0 +1,27 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, JSON, Column, Relationship, SQLModel
class OIDCConnection(SQLModel, table=True):
__tablename__ = "oidc_connections"
id: UUID = Field(default_factory=uuid4, primary_key=True)
provider: str
refresh_token: str | None = None # encrypted at application level
refresh_response: dict | None = Field(default=None, sa_column=Column(JSON))
refreshed_at: datetime | None = None
user_id: UUID = Field(foreign_key="users.id", index=True)
inserted_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)
# Relationships
user: "User" = Relationship(back_populates="oidc_connections")
from wiregui.models.user import User # noqa: E402, F401

27
wiregui/models/rule.py Normal file
View file

@ -0,0 +1,27 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, Relationship, SQLModel
class Rule(SQLModel, table=True):
__tablename__ = "rules"
id: UUID = Field(default_factory=uuid4, primary_key=True)
action: str = Field(default="drop") # "drop" | "accept"
destination: str # CIDR notation, e.g. "10.0.0.0/8" or "0.0.0.0/0"
port_type: str | None = None # "tcp" | "udp" | None (any)
port_range: str | None = None # e.g. "80-443" or "22" or None (any)
user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
inserted_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)
# Relationships
user: "User" = Relationship(back_populates="rules")
from wiregui.models.user import User # noqa: E402, F401

41
wiregui/models/user.py Normal file
View file

@ -0,0 +1,41 @@
from datetime import datetime
from wiregui.utils.time import utcnow
from uuid import UUID, uuid4
from sqlmodel import Field, Relationship, SQLModel
class User(SQLModel, table=True):
__tablename__ = "users"
id: UUID = Field(default_factory=uuid4, primary_key=True)
email: str = Field(unique=True, index=True)
password_hash: str | None = None
role: str = Field(default="unprivileged") # "admin" | "unprivileged"
last_signed_in_at: datetime | None = None
last_signed_in_method: str | None = None
sign_in_token_hash: str | None = None
sign_in_token_created_at: datetime | None = None
disabled_at: datetime | None = None
inserted_at: datetime = Field(default_factory=utcnow)
updated_at: datetime = Field(default_factory=utcnow)
# Relationships
devices: list["Device"] = Relationship(back_populates="user")
oidc_connections: list["OIDCConnection"] = Relationship(back_populates="user")
api_tokens: list["ApiToken"] = Relationship(back_populates="user")
mfa_methods: list["MFAMethod"] = Relationship(back_populates="user")
rules: list["Rule"] = Relationship(back_populates="user")
# Avoid circular imports — these are resolved at runtime by SQLModel
from wiregui.models.api_token import ApiToken # noqa: E402, F401
from wiregui.models.device import Device # noqa: E402, F401
from wiregui.models.mfa_method import MFAMethod # noqa: E402, F401
from wiregui.models.oidc_connection import OIDCConnection # noqa: E402, F401
from wiregui.models.rule import Rule # noqa: E402, F401

View file

388
wiregui/pages/account.py Normal file
View file

@ -0,0 +1,388 @@
"""User account page — password change, MFA management, API tokens."""
from uuid import UUID
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
import json
from wiregui.auth.api_token import generate_api_token
from wiregui.auth.mfa import generate_totp_qr_svg, generate_totp_secret, get_totp_uri, verify_totp_code
from wiregui.auth.passwords import hash_password, verify_password
from wiregui.auth.webauthn import create_registration_options, verify_registration
from wiregui.db import async_session
from wiregui.models.api_token import ApiToken
from wiregui.models.mfa_method import MFAMethod
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.pages.layout import layout
from wiregui.utils.time import utcnow
@ui.page("/account")
async def account_page():
if not app.storage.user.get("authenticated"):
return ui.navigate.to("/login")
layout()
user_id = UUID(app.storage.user["user_id"])
async with async_session() as session:
user = await session.get(User, user_id)
with ui.column().classes("w-full p-4"):
ui.label("Account Settings").classes("text-h5 q-mb-md")
with ui.tabs().classes("w-full") as tabs:
profile_tab = ui.tab("Profile")
mfa_tab = ui.tab("Two-Factor Auth")
tokens_tab = ui.tab("API Tokens")
with ui.tab_panels(tabs, value=profile_tab).classes("w-full"):
# === Profile ===
with ui.tab_panel(profile_tab):
with ui.card().classes("w-full"):
ui.label("Account Details").classes("text-subtitle1 text-bold")
ui.separator()
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
ui.label("Email:").classes("text-bold")
ui.label(user.email)
ui.label("Role:").classes("text-bold")
ui.label(user.role)
ui.label("Last Sign-in:").classes("text-bold")
ui.label(str(user.last_signed_in_at)[:19] if user.last_signed_in_at else "-")
ui.label("Method:").classes("text-bold")
ui.label(user.last_signed_in_method or "-")
with ui.card().classes("w-full q-mt-md"):
ui.label("Change Password").classes("text-subtitle1 text-bold")
ui.separator()
current_pw = ui.input("Current Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
new_pw = ui.input("New Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
confirm_pw = ui.input("Confirm New Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
async def change_password():
if not current_pw.value or not new_pw.value:
ui.notify("Fill in all password fields", type="negative")
return
if new_pw.value != confirm_pw.value:
ui.notify("New passwords do not match", type="negative")
return
if len(new_pw.value) < 8:
ui.notify("Password must be at least 8 characters", type="negative")
return
async with async_session() as session:
u = await session.get(User, user_id)
if not verify_password(current_pw.value, u.password_hash):
ui.notify("Current password is incorrect", type="negative")
return
u.password_hash = hash_password(new_pw.value)
session.add(u)
await session.commit()
logger.info("Password changed for {}", user.email)
ui.notify("Password changed", type="positive")
current_pw.value = ""
new_pw.value = ""
confirm_pw.value = ""
ui.button("Change Password", on_click=change_password).props("color=primary").classes("q-mt-sm")
# OIDC connections
async with async_session() as session:
oidc_conns = (await session.execute(
select(OIDCConnection).where(OIDCConnection.user_id == user_id)
)).scalars().all()
if oidc_conns:
with ui.card().classes("w-full q-mt-md"):
ui.label("Connected SSO Providers").classes("text-subtitle1 text-bold")
ui.separator()
for conn in oidc_conns:
with ui.row().classes("w-full items-center justify-between q-pa-xs"):
ui.label(f"{conn.provider}").classes("text-bold")
ui.label(f"Last refreshed: {str(conn.refreshed_at)[:19] if conn.refreshed_at else 'Never'}")
# === MFA ===
with ui.tab_panel(mfa_tab):
await _render_mfa_panel(user_id, user.email)
# === API Tokens ===
with ui.tab_panel(tokens_tab):
await _render_tokens_panel(user_id)
async def _render_mfa_panel(user_id: UUID, email: str):
"""Render the MFA management tab."""
async def load_methods():
async with async_session() as session:
result = await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user_id).order_by(MFAMethod.inserted_at)
)
return result.scalars().all()
async def refresh_methods():
methods = await load_methods()
methods_container.clear()
with methods_container:
if methods:
for m in methods:
with ui.row().classes("w-full items-center justify-between q-pa-xs"):
with ui.row().classes("items-center gap-2"):
ui.icon("security").props("color=primary")
ui.label(m.name).classes("text-bold")
ui.label(f"({m.type})").classes("text-caption text-grey-7")
with ui.row().classes("items-center gap-2"):
ui.label(f"Last used: {str(m.last_used_at)[:19] if m.last_used_at else 'Never'}").classes("text-caption")
ui.button(icon="delete", on_click=lambda mid=m.id: delete_method(mid)).props("flat dense color=negative")
ui.separator()
else:
ui.label("No MFA methods configured.").classes("text-caption text-grey-7 q-pa-sm")
async def delete_method(method_id):
async with async_session() as session:
m = await session.get(MFAMethod, method_id)
if m and m.user_id == user_id:
await session.delete(m)
await session.commit()
logger.info("MFA method deleted for user {}", email)
ui.notify("MFA method removed")
await refresh_methods()
# Registration state
registration = {"secret": None}
def start_registration():
secret = generate_totp_secret()
registration["secret"] = secret
uri = get_totp_uri(secret, email)
svg = generate_totp_qr_svg(uri)
reg_container.clear()
with reg_container:
ui.label("Scan this QR code with your authenticator app:").classes("text-body2")
ui.html(svg).classes("w-64 q-my-sm")
ui.label(f"Or enter this secret manually: {secret}").classes("text-caption font-mono")
reg_name_input = ui.input("Method Name", value="Authenticator").props("outlined dense").classes("w-full")
reg_code_input = ui.input("Verification Code", placeholder="Enter 6-digit code").props("outlined dense maxlength=6").classes("w-full")
async def verify_and_save():
code = reg_code_input.value.strip()
name = reg_name_input.value.strip() or "Authenticator"
if not verify_totp_code(registration["secret"], code):
ui.notify("Invalid code — check your authenticator", type="negative")
return
async with async_session() as session:
method = MFAMethod(
name=name,
type="totp",
payload={"secret": registration["secret"]},
user_id=user_id,
)
session.add(method)
await session.commit()
logger.info("MFA TOTP registered for {}", email)
ui.notify("MFA method added!", type="positive")
registration["secret"] = None
reg_container.clear()
await refresh_methods()
ui.button("Verify & Save", on_click=verify_and_save).props("color=primary").classes("q-mt-sm")
ui.button("Cancel", on_click=lambda: reg_container.clear()).props("flat")
with ui.card().classes("w-full"):
ui.label("Two-Factor Authentication Methods").classes("text-subtitle1 text-bold")
ui.separator()
methods_container = ui.column().classes("w-full")
await refresh_methods()
with ui.row().classes("q-mt-sm gap-2"):
ui.button("Add TOTP Method", icon="add", on_click=start_registration).props("outline")
ui.button("Add Security Key", icon="key", on_click=lambda: start_webauthn_registration()).props("outline")
reg_container = ui.column().classes("w-full q-mt-md")
webauthn_state = {"challenge": None}
async def start_webauthn_registration():
# Get existing webauthn credentials to exclude
existing = []
async with async_session() as session:
from sqlmodel import select as sel
result = await session.execute(
sel(MFAMethod).where(MFAMethod.user_id == user_id, MFAMethod.type.in_(["native", "portable"]))
)
for m in result.scalars().all():
existing.append(m.payload)
try:
reg_data = create_registration_options(user_id, email, existing)
except Exception as e:
ui.notify(f"WebAuthn not available: {e}", type="negative")
return
webauthn_state["challenge"] = reg_data["challenge"]
options_json = reg_data["options_json"]
# Call browser's navigator.credentials.create() via JavaScript
js = f"""
async function() {{
try {{
const options = JSON.parse('{options_json}');
// Convert base64url strings to ArrayBuffers
options.challenge = Uint8Array.from(atob(options.challenge.replace(/-/g,'+').replace(/_/g,'/')), c => c.charCodeAt(0));
options.user.id = Uint8Array.from(atob(options.user.id.replace(/-/g,'+').replace(/_/g,'/')), c => c.charCodeAt(0));
if (options.excludeCredentials) {{
options.excludeCredentials = options.excludeCredentials.map(c => ({{
...c,
id: Uint8Array.from(atob(c.id.replace(/-/g,'+').replace(/_/g,'/')), ch => ch.charCodeAt(0))
}}));
}}
const credential = await navigator.credentials.create({{publicKey: options}});
// Serialize the response
const response = {{
id: credential.id,
rawId: btoa(String.fromCharCode(...new Uint8Array(credential.rawId))).replace(/\\+/g,'-').replace(/\\//g,'_').replace(/=/g,''),
type: credential.type,
response: {{
attestationObject: btoa(String.fromCharCode(...new Uint8Array(credential.response.attestationObject))).replace(/\\+/g,'-').replace(/\\//g,'_').replace(/=/g,''),
clientDataJSON: btoa(String.fromCharCode(...new Uint8Array(credential.response.clientDataJSON))).replace(/\\+/g,'-').replace(/\\//g,'_').replace(/=/g,''),
}},
}};
return JSON.stringify(response);
}} catch(e) {{
return JSON.stringify({{"error": e.message}});
}}
}}
"""
result = await ui.run_javascript(f"({js})()")
await _handle_webauthn_response(result)
async def _handle_webauthn_response(result_json: str):
try:
result = json.loads(result_json)
except (json.JSONDecodeError, TypeError):
ui.notify("WebAuthn response error", type="negative")
return
if "error" in result:
ui.notify(f"WebAuthn failed: {result['error']}", type="negative")
return
challenge = webauthn_state.get("challenge")
if not challenge:
ui.notify("No pending WebAuthn challenge", type="negative")
return
try:
credential_data = verify_registration(result_json, challenge)
except Exception as e:
ui.notify(f"Verification failed: {e}", type="negative")
return
async with async_session() as session:
method = MFAMethod(
name="Security Key",
type="portable",
payload=credential_data,
user_id=user_id,
)
session.add(method)
await session.commit()
logger.info("WebAuthn key registered for {}", email)
ui.notify("Security key registered!", type="positive")
webauthn_state["challenge"] = None
await refresh_methods()
async def _render_tokens_panel(user_id: UUID):
"""Render the API tokens tab."""
async def load_tokens():
async with async_session() as session:
result = await session.execute(
select(ApiToken).where(ApiToken.user_id == user_id).order_by(ApiToken.inserted_at.desc())
)
return result.scalars().all()
async def refresh_tokens():
tokens = await load_tokens()
token_table.rows = [
{
"id": str(t.id),
"created": str(t.inserted_at)[:19],
"expires": str(t.expires_at)[:19] if t.expires_at else "Never",
"status": "Expired" if t.expires_at and t.expires_at < utcnow() else "Active",
}
for t in tokens
]
token_table.update()
async def create_token():
from datetime import timedelta
days = int(token_days.value) if token_days.value else 30
plaintext, token_hash = generate_api_token()
expires_at = utcnow() + timedelta(days=days) if days > 0 else None
async with async_session() as session:
token = ApiToken(token_hash=token_hash, expires_at=expires_at, user_id=user_id)
session.add(token)
await session.commit()
logger.info("API token created (expires in {} days)", days)
# Show the token once
with ui.dialog(value=True) as token_dialog:
with ui.card().classes("w-96"):
ui.label("API Token Created").classes("text-h6")
ui.label("Copy this token now — it won't be shown again.").classes("text-caption text-negative")
ui.input(value=plaintext).props("readonly outlined dense").classes("w-full font-mono q-mt-sm")
ui.button("Close", on_click=token_dialog.close).props("flat").classes("w-full q-mt-sm")
await refresh_tokens()
async def delete_token(token_id: str):
async with async_session() as session:
t = await session.get(ApiToken, UUID(token_id))
if t and t.user_id == user_id:
await session.delete(t)
await session.commit()
ui.notify("Token deleted")
await refresh_tokens()
with ui.card().classes("w-full"):
ui.label("API Tokens").classes("text-subtitle1 text-bold")
ui.separator()
ui.label("Use API tokens for programmatic access to the REST API.").classes("text-caption text-grey-7")
token_columns = [
{"name": "created", "label": "Created", "field": "created", "align": "left"},
{"name": "expires", "label": "Expires", "field": "expires", "align": "left"},
{"name": "status", "label": "Status", "field": "status", "align": "left"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
token_table = ui.table(columns=token_columns, rows=[], row_key="id").classes("w-full")
token_table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
token_table.on("delete", lambda e: delete_token(e.args))
with ui.row().classes("items-center gap-2 q-mt-sm"):
token_days = ui.input("Expires in (days)", value="30").props("outlined dense").classes("w-40")
ui.button("Create Token", icon="add", on_click=create_token).props("color=primary")
await refresh_tokens()

View file

View file

@ -0,0 +1,350 @@
"""Admin device management — view and manage all devices across all users."""
import io
from uuid import UUID
import qrcode
import qrcode.image.svg
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.device import Device
from wiregui.models.user import User
from wiregui.pages.layout import layout
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
from wiregui.utils.server_key import get_server_public_key
from wiregui.utils.wg_conf import build_client_config
def _guard():
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
ui.navigate.to("/login")
return False
return True
def _format_bytes(b: int | None) -> str:
if b is None:
return "-"
for unit in ("B", "KB", "MB", "GB", "TB"):
if b < 1024:
return f"{b:.1f} {unit}"
b /= 1024
return f"{b:.1f} PB"
@ui.page("/admin/devices")
async def admin_devices_page():
if not _guard():
return
layout()
# Load users for filter and create form
async with async_session() as session:
users = (await session.execute(select(User).order_by(User.email))).scalars().all()
user_map = {str(u.id): u.email for u in users}
async def load_devices(user_filter: str | None = None) -> list[dict]:
async with async_session() as session:
stmt = select(Device).order_by(Device.inserted_at.desc())
if user_filter and user_filter != "all":
stmt = stmt.where(Device.user_id == UUID(user_filter))
result = await session.execute(stmt)
return [
{
"id": str(d.id),
"name": d.name,
"user": user_map.get(str(d.user_id), "Unknown"),
"ipv4": d.ipv4 or "-",
"ipv6": d.ipv6 or "-",
"public_key": d.public_key[:16] + "...",
"rx": _format_bytes(d.rx_bytes),
"tx": _format_bytes(d.tx_bytes),
"handshake": str(d.latest_handshake)[:19] if d.latest_handshake else "-",
}
for d in result.scalars().all()
]
async def refresh_table():
table.rows = await load_devices(user_filter_select.value)
table.update()
async def on_filter_change():
await refresh_table()
# --- Create device ---
async def create_device():
name = create_name.value.strip()
owner_id = create_user_select.value
if not name or not owner_id:
ui.notify("Name and user are required", type="negative")
return
try:
settings = get_settings()
private_key, public_key = generate_keypair()
psk = generate_preshared_key()
async with async_session() as session:
ipv4 = await allocate_ipv4(session, settings.wg_ipv4_network)
ipv6 = await allocate_ipv6(session, settings.wg_ipv6_network)
device = Device(
name=name,
description=create_desc.value.strip() or None,
public_key=public_key,
preshared_key=psk,
ipv4=ipv4,
ipv6=ipv6,
user_id=UUID(owner_id),
# Apply config overrides if not using defaults
use_default_allowed_ips=create_use_default_ips.value,
use_default_dns=create_use_default_dns.value,
use_default_endpoint=create_use_default_endpoint.value,
use_default_mtu=create_use_default_mtu.value,
use_default_persistent_keepalive=create_use_default_keepalive.value,
endpoint=create_endpoint.value.strip() or None if not create_use_default_endpoint.value else None,
dns=([s.strip() for s in create_dns.value.split(",") if s.strip()]
if not create_use_default_dns.value and create_dns.value else []),
mtu=int(create_mtu.value) if not create_use_default_mtu.value and create_mtu.value else None,
persistent_keepalive=(int(create_keepalive.value)
if not create_use_default_keepalive.value and create_keepalive.value else None),
allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()]
if not create_use_default_ips.value and create_allowed_ips.value else []),
)
session.add(device)
await session.commit()
await session.refresh(device)
logger.info("Admin created device: {} for {}", device.name, user_map.get(owner_id))
await on_device_created(device)
# Show config
server_pubkey = await get_server_public_key()
config_text = build_client_config(device, private_key, server_pubkey)
_show_config_dialog(device.name, config_text)
create_dialog.close()
_reset_create_form()
await refresh_table()
except Exception as e:
logger.error("Failed to create device: {}", e)
ui.notify(f"Error: {e}", type="negative")
def _reset_create_form():
create_name.value = ""
create_desc.value = ""
create_use_default_ips.value = True
create_use_default_dns.value = True
create_use_default_endpoint.value = True
create_use_default_mtu.value = True
create_use_default_keepalive.value = True
# --- Edit device ---
edit_device_id = {"value": None}
async def open_edit(device_id: str):
async with async_session() as session:
device = await session.get(Device, UUID(device_id))
if not device:
return
edit_device_id["value"] = device_id
edit_name.value = device.name
edit_desc.value = device.description or ""
edit_use_default_ips.value = device.use_default_allowed_ips
edit_use_default_dns.value = device.use_default_dns
edit_use_default_endpoint.value = device.use_default_endpoint
edit_use_default_mtu.value = device.use_default_mtu
edit_use_default_keepalive.value = device.use_default_persistent_keepalive
edit_endpoint.value = device.endpoint or ""
edit_dns.value = ", ".join(device.dns) if device.dns else ""
edit_mtu.value = str(device.mtu) if device.mtu else ""
edit_keepalive.value = str(device.persistent_keepalive) if device.persistent_keepalive else ""
edit_allowed_ips.value = ", ".join(device.allowed_ips) if device.allowed_ips else ""
edit_dialog.open()
async def save_edit():
did = edit_device_id["value"]
if not did:
return
async with async_session() as session:
device = await session.get(Device, UUID(did))
if not device:
return
device.name = edit_name.value.strip()
device.description = edit_desc.value.strip() or None
device.use_default_allowed_ips = edit_use_default_ips.value
device.use_default_dns = edit_use_default_dns.value
device.use_default_endpoint = edit_use_default_endpoint.value
device.use_default_mtu = edit_use_default_mtu.value
device.use_default_persistent_keepalive = edit_use_default_keepalive.value
if not device.use_default_endpoint:
device.endpoint = edit_endpoint.value.strip() or None
if not device.use_default_dns:
device.dns = [s.strip() for s in edit_dns.value.split(",") if s.strip()]
if not device.use_default_mtu:
device.mtu = int(edit_mtu.value) if edit_mtu.value else None
if not device.use_default_persistent_keepalive:
device.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None
if not device.use_default_allowed_ips:
device.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()]
session.add(device)
await session.commit()
await session.refresh(device)
await on_device_updated(device)
logger.info("Admin updated device: {}", edit_name.value)
ui.notify("Device updated")
edit_dialog.close()
await refresh_table()
# --- Delete device ---
async def delete_device(device_id: str):
async with async_session() as session:
device = await session.get(Device, UUID(device_id))
if device:
await session.delete(device)
await session.commit()
logger.info("Admin deleted device: {}", device.name)
await on_device_deleted(device)
ui.notify(f"Deleted {device.name}")
await refresh_table()
# --- Page content ---
with ui.column().classes("w-full p-4"):
with ui.row().classes("w-full items-center justify-between"):
ui.label("All Devices").classes("text-h5")
with ui.row().classes("items-center gap-4"):
filter_options = {"all": "All Users"}
filter_options.update(user_map)
user_filter_select = ui.select(
filter_options, value="all", label="Filter by User",
on_change=lambda: on_filter_change(),
).props("outlined dense").classes("w-48")
ui.button("Add Device", icon="add", on_click=lambda: create_dialog.open()).props("color=primary")
columns = [
{"name": "name", "label": "Name", "field": "name", "align": "left", "sortable": True},
{"name": "user", "label": "User", "field": "user", "align": "left", "sortable": True},
{"name": "ipv4", "label": "IPv4", "field": "ipv4", "align": "left"},
{"name": "ipv6", "label": "IPv6", "field": "ipv6", "align": "left"},
{"name": "public_key", "label": "Public Key", "field": "public_key", "align": "left"},
{"name": "rx", "label": "RX", "field": "rx", "align": "right"},
{"name": "tx", "label": "TX", "field": "tx", "align": "right"},
{"name": "handshake", "label": "Last Handshake", "field": "handshake", "align": "left"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="edit" color="primary"
@click.stop="() => $parent.$emit('edit', props.row.id)" />
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
table.on("edit", lambda e: open_edit(e.args))
table.on("delete", lambda e: delete_device(e.args))
# --- Create dialog (full form) ---
with ui.dialog() as create_dialog:
with ui.card().classes("w-[600px]"):
ui.label("New Device").classes("text-h6")
create_user_select = ui.select(
user_map, value=list(user_map.keys())[0] if user_map else None,
label="Owner",
).props("outlined dense").classes("w-full")
create_name = ui.input("Device Name").props("outlined dense").classes("w-full")
create_desc = ui.input("Description (optional)").props("outlined dense").classes("w-full")
ui.separator().classes("q-my-sm")
ui.label("Configuration Overrides").classes("text-subtitle2")
with ui.grid(columns=2).classes("w-full gap-2"):
create_use_default_ips = ui.switch("Use default Allowed IPs", value=True)
create_allowed_ips = ui.input("Allowed IPs", placeholder="0.0.0.0/0, ::/0").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v)
create_use_default_dns = ui.switch("Use default DNS", value=True)
create_dns = ui.input("DNS Servers", placeholder="1.1.1.1, 1.0.0.1").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v)
create_use_default_endpoint = ui.switch("Use default Endpoint", value=True)
create_endpoint = ui.input("Endpoint", placeholder="vpn.example.com").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v)
create_use_default_mtu = ui.switch("Use default MTU", value=True)
create_mtu = ui.input("MTU", placeholder="1280").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v)
create_use_default_keepalive = ui.switch("Use default Keepalive", value=True)
create_keepalive = ui.input("Persistent Keepalive", placeholder="25").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v)
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=create_dialog.close).props("flat")
ui.button("Create", on_click=create_device).props("color=primary")
# --- Edit dialog (full form) ---
with ui.dialog() as edit_dialog:
with ui.card().classes("w-[600px]"):
ui.label("Edit Device").classes("text-h6")
edit_name = ui.input("Device Name").props("outlined dense").classes("w-full")
edit_desc = ui.input("Description").props("outlined dense").classes("w-full")
ui.separator().classes("q-my-sm")
ui.label("Configuration Overrides").classes("text-subtitle2")
with ui.grid(columns=2).classes("w-full gap-2"):
edit_use_default_ips = ui.switch("Use default Allowed IPs", value=True)
edit_allowed_ips = ui.input("Allowed IPs").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_ips, "value", backward=lambda v: not v)
edit_use_default_dns = ui.switch("Use default DNS", value=True)
edit_dns = ui.input("DNS Servers").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_dns, "value", backward=lambda v: not v)
edit_use_default_endpoint = ui.switch("Use default Endpoint", value=True)
edit_endpoint = ui.input("Endpoint").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_endpoint, "value", backward=lambda v: not v)
edit_use_default_mtu = ui.switch("Use default MTU", value=True)
edit_mtu = ui.input("MTU").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_mtu, "value", backward=lambda v: not v)
edit_use_default_keepalive = ui.switch("Use default Keepalive", value=True)
edit_keepalive = ui.input("Persistent Keepalive").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_keepalive, "value", backward=lambda v: not v)
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=edit_dialog.close).props("flat")
ui.button("Save", on_click=save_edit).props("color=primary")
await refresh_table()
# Auto-refresh stats every 30 seconds
ui.timer(30, refresh_table)
def _show_config_dialog(device_name: str, config_text: str):
with ui.dialog(value=True) as dialog:
with ui.card().classes("w-96"):
ui.label(f"Config for {device_name}").classes("text-h6")
ui.label("Save this — the private key won't be shown again.").classes("text-caption text-negative")
ui.textarea(value=config_text).props("readonly outlined").classes("w-full font-mono text-xs q-mt-sm").style("min-height: 200px")
try:
qr = qrcode.make(config_text, image_factory=qrcode.image.svg.SvgPathImage)
buf = io.BytesIO()
qr.save(buf)
ui.html(buf.getvalue().decode()).classes("w-full q-mt-sm")
except Exception:
pass
ui.button("Download .conf", on_click=lambda: ui.download(config_text.encode(), f"{device_name}.conf")).props("color=primary outline").classes("w-full q-mt-sm")
ui.button("Close", on_click=dialog.close).props("flat").classes("w-full")

View file

@ -0,0 +1,162 @@
"""Admin diagnostics page — connectivity checks, WG status, peer stats."""
from nicegui import app, ui
from sqlmodel import select
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.models.device import Device
from wiregui.pages.layout import layout
from wiregui.services import notifications
def _guard():
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
ui.navigate.to("/login")
return False
return True
def _format_bytes(b: int | None) -> str:
if b is None:
return "-"
for unit in ("B", "KB", "MB", "GB", "TB"):
if b < 1024:
return f"{b:.1f} {unit}"
b /= 1024
return f"{b:.1f} PB"
@ui.page("/admin/diagnostics")
async def diagnostics_page():
if not _guard():
return
layout()
settings = get_settings()
with ui.column().classes("w-full p-4"):
ui.label("Diagnostics").classes("text-h5 q-mb-md")
# --- WireGuard Status ---
with ui.card().classes("w-full"):
ui.label("WireGuard Interface").classes("text-subtitle1 text-bold")
ui.separator()
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
ui.label("Interface:").classes("text-bold")
ui.label(settings.wg_interface)
ui.label("Status:").classes("text-bold")
ui.label("Enabled" if settings.wg_enabled else "Disabled (UI-only mode)").classes(
"text-positive" if settings.wg_enabled else "text-warning"
)
ui.label("IPv4 Network:").classes("text-bold")
ui.label(settings.wg_ipv4_network)
ui.label("IPv6 Network:").classes("text-bold")
ui.label(settings.wg_ipv6_network)
ui.label("Endpoint:").classes("text-bold")
ui.label(f"{settings.wg_endpoint_host}:{settings.wg_endpoint_port}")
# --- Active Peers ---
with ui.card().classes("w-full q-mt-md"):
ui.label("Active Peers (from DB)").classes("text-subtitle1 text-bold")
ui.separator()
async with async_session() as session:
result = await session.execute(
select(Device).where(Device.latest_handshake.is_not(None)).order_by(Device.latest_handshake.desc())
)
active_devices = result.scalars().all()
if active_devices:
peer_columns = [
{"name": "name", "label": "Name", "field": "name", "align": "left"},
{"name": "public_key", "label": "Public Key", "field": "public_key", "align": "left"},
{"name": "ipv4", "label": "IPv4", "field": "ipv4", "align": "left"},
{"name": "endpoint", "label": "Remote IP", "field": "endpoint", "align": "left"},
{"name": "handshake", "label": "Last Handshake", "field": "handshake", "align": "left"},
{"name": "rx", "label": "RX", "field": "rx", "align": "right"},
{"name": "tx", "label": "TX", "field": "tx", "align": "right"},
]
peer_rows = [
{
"name": d.name,
"public_key": d.public_key[:16] + "...",
"ipv4": d.ipv4 or "-",
"endpoint": d.remote_ip or "-",
"handshake": str(d.latest_handshake)[:19] if d.latest_handshake else "-",
"rx": _format_bytes(d.rx_bytes),
"tx": _format_bytes(d.tx_bytes),
}
for d in active_devices
]
ui.table(columns=peer_columns, rows=peer_rows, row_key="name").classes("w-full")
else:
ui.label("No active peers with recent handshakes.").classes("text-caption text-grey-7 q-pa-sm")
# --- Connectivity Checks ---
with ui.card().classes("w-full q-mt-md"):
ui.label("WAN Connectivity Checks").classes("text-subtitle1 text-bold")
ui.separator()
async with async_session() as session:
result = await session.execute(
select(ConnectivityCheck).order_by(ConnectivityCheck.inserted_at.desc()).limit(20)
)
checks = result.scalars().all()
if checks:
check_columns = [
{"name": "time", "label": "Checked At", "field": "time", "align": "left"},
{"name": "url", "label": "URL", "field": "url", "align": "left"},
{"name": "status", "label": "Status", "field": "status", "align": "center"},
{"name": "body", "label": "Response", "field": "body", "align": "left"},
]
check_rows = [
{
"time": str(c.inserted_at)[:19],
"url": c.url,
"status": str(c.response_code or "Error"),
"body": (c.response_body or "")[:50],
}
for c in checks
]
ui.table(columns=check_columns, rows=check_rows, row_key="time").classes("w-full")
else:
ui.label("No connectivity checks recorded yet.").classes("text-caption text-grey-7 q-pa-sm")
# --- Notifications ---
with ui.card().classes("w-full q-mt-md"):
ui.label("System Notifications").classes("text-subtitle1 text-bold")
ui.separator()
notifs = notifications.current()
if notifs:
for n in notifs:
color = {"error": "negative", "warning": "warning", "info": "info"}.get(n.severity, "grey")
with ui.row().classes("w-full items-center q-pa-xs"):
ui.icon("error" if n.severity == "error" else "warning" if n.severity == "warning" else "info").props(f"color={color}")
ui.label(f"{n.timestamp.strftime('%H:%M:%S')}{n.message}").classes("text-sm")
if n.user:
ui.label(f"({n.user})").classes("text-caption text-grey-7")
ui.button(icon="close", on_click=lambda nid=n.id: _clear_notif(nid)).props("flat dense size=xs")
else:
ui.label("No notifications.").classes("text-caption text-grey-7 q-pa-sm")
if notifs:
ui.button("Clear All", on_click=lambda: _clear_all_notifs()).props("flat color=negative").classes("q-mt-sm")
def _clear_notif(nid: str):
notifications.clear(nid)
ui.navigate.to("/admin/diagnostics")
def _clear_all_notifs():
notifications.clear_all()
ui.navigate.to("/admin/diagnostics")

View file

@ -0,0 +1,228 @@
"""Admin firewall rules management page."""
from uuid import UUID
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.pages.layout import layout
from wiregui.services.events import on_rule_created, on_rule_deleted, on_rule_updated
@ui.page("/admin/rules")
async def rules_page():
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
return ui.navigate.to("/login")
layout()
# Load users for the dropdown
async with async_session() as session:
users = (await session.execute(select(User).order_by(User.email))).scalars().all()
user_options = {str(u.id): u.email for u in users}
async def load_rules() -> list[dict]:
async with async_session() as session:
result = await session.execute(select(Rule).order_by(Rule.inserted_at.desc()))
rules = result.scalars().all()
return [
{
"id": str(r.id),
"action": r.action,
"destination": r.destination,
"port_type": r.port_type or "any",
"port_range": r.port_range or "any",
"user": user_options.get(str(r.user_id), "Global") if r.user_id else "Global",
}
for r in rules
]
async def refresh_table():
table.rows = await load_rules()
table.update()
async def create_rule():
dest = dest_input.value.strip()
if not dest:
ui.notify("Destination is required", type="negative")
return
action_val = action_select.value
port_type_val = port_type_select.value if port_type_select.value != "any" else None
port_range_val = port_range_input.value.strip() or None
user_id_val = user_select.value if user_select.value != "global" else None
async with async_session() as session:
rule = Rule(
action=action_val,
destination=dest,
port_type=port_type_val,
port_range=port_range_val,
user_id=UUID(user_id_val) if user_id_val else None,
)
session.add(rule)
await session.commit()
await session.refresh(rule)
logger.info("Rule created: {} {} -> {}", rule.action, rule.destination, user_id_val or "global")
await on_rule_created(rule)
create_dialog.close()
_reset_form()
await refresh_table()
# --- Edit rule ---
edit_rule_id = {"value": None}
async def open_edit(rule_id: str):
async with async_session() as session:
rule = await session.get(Rule, UUID(rule_id))
if not rule:
return
edit_rule_id["value"] = rule_id
edit_action.value = rule.action
edit_dest.value = rule.destination
edit_port_type.value = rule.port_type or "any"
edit_port_range.value = rule.port_range or ""
edit_user.value = str(rule.user_id) if rule.user_id else "global"
edit_dialog.open()
async def save_edit():
rid = edit_rule_id["value"]
if not rid:
return
async with async_session() as session:
rule = await session.get(Rule, UUID(rid))
if not rule:
return
rule.action = edit_action.value
rule.destination = edit_dest.value.strip()
rule.port_type = edit_port_type.value if edit_port_type.value != "any" else None
rule.port_range = edit_port_range.value.strip() or None
rule.user_id = UUID(edit_user.value) if edit_user.value != "global" else None
session.add(rule)
await session.commit()
await session.refresh(rule)
await on_rule_updated(rule)
logger.info("Rule updated: {} {}", edit_action.value, edit_dest.value)
ui.notify("Rule updated")
edit_dialog.close()
await refresh_table()
async def delete_rule(rule_id: str):
async with async_session() as session:
rule = await session.get(Rule, UUID(rule_id))
if rule:
await session.delete(rule)
await session.commit()
logger.info("Rule deleted: {} {}", rule.action, rule.destination)
await on_rule_deleted(rule)
await refresh_table()
def _reset_form():
dest_input.value = ""
action_select.value = "accept"
port_type_select.value = "any"
port_range_input.value = ""
user_select.value = "global"
# Page content
with ui.column().classes("w-full p-4"):
with ui.row().classes("w-full items-center justify-between"):
ui.label("Firewall Rules").classes("text-h5")
ui.button("Add Rule", icon="add", on_click=lambda: create_dialog.open()).props("color=primary")
columns = [
{"name": "action", "label": "Action", "field": "action", "align": "left", "sortable": True},
{"name": "destination", "label": "Destination", "field": "destination", "align": "left", "sortable": True},
{"name": "port_type", "label": "Protocol", "field": "port_type", "align": "left"},
{"name": "port_range", "label": "Port(s)", "field": "port_range", "align": "left"},
{"name": "user", "label": "User", "field": "user", "align": "left"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="edit" color="primary"
@click.stop="() => $parent.$emit('edit', props.row.id)" />
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
table.on("edit", lambda e: open_edit(e.args))
table.on("delete", lambda e: delete_rule(e.args))
# Create rule dialog
with ui.dialog() as create_dialog:
with ui.card().classes("w-96"):
ui.label("New Firewall Rule").classes("text-h6")
action_select = ui.select(
["accept", "drop"], value="accept", label="Action",
).props("outlined dense").classes("w-full")
dest_input = ui.input("Destination (CIDR)", placeholder="e.g. 10.0.0.0/8 or 0.0.0.0/0").props(
"outlined dense"
).classes("w-full")
port_type_select = ui.select(
["any", "tcp", "udp"], value="any", label="Protocol",
).props("outlined dense").classes("w-full")
port_range_input = ui.input("Port Range", placeholder="e.g. 80 or 80-443 (optional)").props(
"outlined dense"
).classes("w-full")
user_options_list = [{"label": "Global (all users)", "value": "global"}] + [
{"label": email, "value": uid} for uid, email in user_options.items()
]
user_select = ui.select(
{item["value"]: item["label"] for item in user_options_list},
value="global",
label="Applies to",
).props("outlined dense").classes("w-full")
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=create_dialog.close).props("flat")
ui.button("Create", on_click=create_rule).props("color=primary")
# Edit rule dialog
user_options_map = {"global": "Global (all users)"}
user_options_map.update(user_options)
with ui.dialog() as edit_dialog:
with ui.card().classes("w-96"):
ui.label("Edit Firewall Rule").classes("text-h6")
edit_action = ui.select(
["accept", "drop"], value="accept", label="Action",
).props("outlined dense").classes("w-full")
edit_dest = ui.input("Destination (CIDR)").props("outlined dense").classes("w-full")
edit_port_type = ui.select(
["any", "tcp", "udp"], value="any", label="Protocol",
).props("outlined dense").classes("w-full")
edit_port_range = ui.input("Port Range").props("outlined dense").classes("w-full")
edit_user = ui.select(
user_options_map, value="global", label="Applies to",
).props("outlined dense").classes("w-full")
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=edit_dialog.close).props("flat")
ui.button("Save", on_click=save_edit).props("color=primary")
await refresh_table()

View file

@ -0,0 +1,367 @@
"""Admin settings pages — tabbed interface for configuration management."""
from uuid import UUID
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.pages.layout import layout
from wiregui.utils.time import utcnow
def _guard():
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
ui.navigate.to("/login")
return False
return True
async def _get_or_create_config() -> Configuration:
async with async_session() as session:
result = await session.execute(select(Configuration).limit(1))
config = result.scalar_one_or_none()
if not config:
config = Configuration()
session.add(config)
await session.commit()
await session.refresh(config)
return config
VPN_SESSION_OPTIONS = {
0: "Never (unlimited)",
3600: "Every Hour",
86400: "Every Day",
604800: "Every Week",
2592000: "Every 30 Days",
7776000: "Every 90 Days",
}
@ui.page("/admin/settings")
async def settings_page():
if not _guard():
return
layout()
config = await _get_or_create_config()
# --- Client Defaults tab ---
async def save_defaults():
async with async_session() as session:
c = await session.get(Configuration, config.id)
c.default_client_endpoint = defaults_endpoint.value.strip() or None
c.default_client_dns = [s.strip() for s in defaults_dns.value.split(",") if s.strip()]
c.default_client_mtu = int(defaults_mtu.value) if defaults_mtu.value else 1280
c.default_client_persistent_keepalive = int(defaults_keepalive.value) if defaults_keepalive.value else 25
c.default_client_allowed_ips = [s.strip() for s in defaults_allowed_ips.value.split(",") if s.strip()]
c.updated_at = utcnow()
session.add(c)
await session.commit()
logger.info("Client defaults updated")
ui.notify("Client defaults saved", type="positive")
# --- Security tab ---
async def save_security():
async with async_session() as session:
c = await session.get(Configuration, config.id)
c.vpn_session_duration = security_vpn_duration.value
c.local_auth_enabled = security_local_auth.value
c.allow_unprivileged_device_management = security_unpriv_mgmt.value
c.allow_unprivileged_device_configuration = security_unpriv_config.value
c.disable_vpn_on_oidc_error = security_disable_vpn_oidc.value
c.updated_at = utcnow()
session.add(c)
await session.commit()
logger.info("Security settings updated")
ui.notify("Security settings saved", type="positive")
# --- OIDC provider management ---
async def save_oidc_provider():
provider = {
"id": oidc_id.value.strip(),
"label": oidc_label.value.strip(),
"scope": oidc_scope.value.strip(),
"response_type": "code",
"client_id": oidc_client_id.value.strip(),
"client_secret": oidc_client_secret.value.strip(),
"discovery_document_uri": oidc_discovery.value.strip(),
"auto_create_users": oidc_auto_create.value,
}
if not all([provider["id"], provider["label"], provider["client_id"],
provider["client_secret"], provider["discovery_document_uri"]]):
ui.notify("All required fields must be filled", type="negative")
return
async with async_session() as session:
c = await session.get(Configuration, config.id)
providers = list(c.openid_connect_providers or [])
# Replace existing or add new
providers = [p for p in providers if p.get("id") != provider["id"]]
providers.append(provider)
c.openid_connect_providers = providers
c.updated_at = utcnow()
session.add(c)
await session.commit()
logger.info("OIDC provider saved: {}", provider["id"])
ui.notify(f"OIDC provider '{provider['label']}' saved", type="positive")
oidc_dialog.close()
await refresh_oidc_table()
async def delete_oidc_provider(provider_id: str):
async with async_session() as session:
c = await session.get(Configuration, config.id)
c.openid_connect_providers = [p for p in (c.openid_connect_providers or []) if p.get("id") != provider_id]
c.updated_at = utcnow()
session.add(c)
await session.commit()
logger.info("OIDC provider deleted: {}", provider_id)
ui.notify("OIDC provider deleted")
await refresh_oidc_table()
async def refresh_oidc_table():
async with async_session() as session:
c = await session.get(Configuration, config.id)
providers = c.openid_connect_providers or []
oidc_table.rows = [
{
"id": p.get("id", ""),
"label": p.get("label", ""),
"client_id": p.get("client_id", ""),
"discovery": p.get("discovery_document_uri", "")[:50] + "...",
"auto_create": "Yes" if p.get("auto_create_users") else "No",
}
for p in providers
]
oidc_table.update()
# --- Page content ---
with ui.column().classes("w-full p-4"):
ui.label("Settings").classes("text-h5 q-mb-md")
with ui.tabs().classes("w-full") as tabs:
defaults_tab = ui.tab("Client Defaults")
security_tab = ui.tab("Security")
auth_tab = ui.tab("Authentication")
with ui.tab_panels(tabs, value=defaults_tab).classes("w-full"):
# === Client Defaults ===
with ui.tab_panel(defaults_tab):
with ui.card().classes("w-full"):
ui.label("Default Client Configuration").classes("text-subtitle1 text-bold")
ui.label("These defaults apply to new devices unless overridden per-device.").classes("text-caption text-grey-7")
ui.separator()
defaults_endpoint = ui.input(
"Endpoint", value=config.default_client_endpoint or "",
placeholder="vpn.example.com",
).props("outlined dense").classes("w-full")
ui.label("IPv4/IPv6 address or FQDN clients connect to").classes("text-caption text-grey-7")
defaults_dns = ui.input(
"DNS Servers", value=", ".join(config.default_client_dns),
placeholder="1.1.1.1, 1.0.0.1",
).props("outlined dense").classes("w-full q-mt-sm")
ui.label("Comma-separated. Leave blank to omit.").classes("text-caption text-grey-7")
defaults_allowed_ips = ui.input(
"Allowed IPs", value=", ".join(config.default_client_allowed_ips),
placeholder="0.0.0.0/0, ::/0",
).props("outlined dense").classes("w-full q-mt-sm")
ui.label("CIDR ranges for split or full tunnel.").classes("text-caption text-grey-7")
with ui.row().classes("w-full gap-4 q-mt-sm"):
defaults_mtu = ui.input(
"MTU", value=str(config.default_client_mtu),
placeholder="1280",
).props("outlined dense").classes("w-48")
defaults_keepalive = ui.input(
"Persistent Keepalive", value=str(config.default_client_persistent_keepalive),
placeholder="25",
).props("outlined dense").classes("w-48")
ui.button("Save Defaults", on_click=save_defaults).props("color=primary").classes("q-mt-md")
# === Security ===
with ui.tab_panel(security_tab):
with ui.card().classes("w-full"):
ui.label("Authentication & Access").classes("text-subtitle1 text-bold")
ui.separator()
security_vpn_duration = ui.select(
VPN_SESSION_OPTIONS,
value=config.vpn_session_duration,
label="VPN Session Duration",
).props("outlined dense").classes("w-full")
ui.label("How often users must re-authenticate to maintain VPN access.").classes("text-caption text-grey-7")
ui.separator().classes("q-my-md")
security_local_auth = ui.switch("Local Authentication (email/password)", value=config.local_auth_enabled)
security_unpriv_mgmt = ui.switch("Allow Unprivileged Device Management", value=config.allow_unprivileged_device_management)
security_unpriv_config = ui.switch("Allow Unprivileged Device Configuration", value=config.allow_unprivileged_device_configuration)
ui.separator().classes("q-my-md")
ui.label("SSO Behavior").classes("text-subtitle2")
security_disable_vpn_oidc = ui.switch("Auto-disable VPN on OIDC refresh error", value=config.disable_vpn_on_oidc_error)
ui.button("Save Security Settings", on_click=save_security).props("color=primary").classes("q-mt-md")
# === Authentication (OIDC/SAML) ===
with ui.tab_panel(auth_tab):
with ui.card().classes("w-full"):
ui.label("OpenID Connect Providers").classes("text-subtitle1 text-bold")
ui.separator()
oidc_columns = [
{"name": "id", "label": "Config ID", "field": "id", "align": "left"},
{"name": "label", "label": "Label", "field": "label", "align": "left"},
{"name": "client_id", "label": "Client ID", "field": "client_id", "align": "left"},
{"name": "discovery", "label": "Discovery URI", "field": "discovery", "align": "left"},
{"name": "auto_create", "label": "Auto-create", "field": "auto_create", "align": "center"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
oidc_table = ui.table(columns=oidc_columns, rows=[], row_key="id").classes("w-full")
oidc_table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
oidc_table.on("delete", lambda e: delete_oidc_provider(e.args))
ui.button("Add OIDC Provider", icon="add", on_click=lambda: oidc_dialog.open()).props("outline").classes("q-mt-sm")
with ui.card().classes("w-full q-mt-md"):
ui.label("SAML Identity Providers").classes("text-subtitle1 text-bold")
ui.separator()
saml_columns = [
{"name": "id", "label": "Config ID", "field": "id", "align": "left"},
{"name": "label", "label": "Label", "field": "label", "align": "left"},
{"name": "metadata", "label": "Metadata", "field": "metadata", "align": "left"},
{"name": "auto_create", "label": "Auto-create", "field": "auto_create", "align": "center"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
saml_table = ui.table(columns=saml_columns, rows=[], row_key="id").classes("w-full")
saml_table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
saml_table.on("delete", lambda e: delete_saml_provider(e.args))
ui.button("Add SAML Provider", icon="add", on_click=lambda: saml_dialog.open()).props("outline").classes("q-mt-sm")
# --- SAML provider management ---
async def save_saml_provider():
provider = {
"id": saml_id.value.strip(),
"label": saml_label.value.strip(),
"metadata": saml_metadata_input.value.strip(),
"base_url": f"{get_settings().external_url}/auth/saml",
"sign_requests": saml_sign_requests.value,
"sign_metadata": saml_sign_metadata.value,
"signed_assertion_in_resp": saml_signed_assertion.value,
"signed_envelopes_in_resp": saml_signed_envelopes.value,
"auto_create_users": saml_auto_create.value,
}
if not all([provider["id"], provider["label"], provider["metadata"]]):
ui.notify("Config ID, Label, and Metadata are required", type="negative")
return
async with async_session() as session:
c = await session.get(Configuration, config.id)
providers = list(c.saml_identity_providers or [])
providers = [p for p in providers if p.get("id") != provider["id"]]
providers.append(provider)
c.saml_identity_providers = providers
c.updated_at = utcnow()
session.add(c)
await session.commit()
logger.info("SAML provider saved: {}", provider["id"])
ui.notify(f"SAML provider '{provider['label']}' saved", type="positive")
saml_dialog.close()
await refresh_saml_table()
async def delete_saml_provider(provider_id: str):
async with async_session() as session:
c = await session.get(Configuration, config.id)
c.saml_identity_providers = [p for p in (c.saml_identity_providers or []) if p.get("id") != provider_id]
c.updated_at = utcnow()
session.add(c)
await session.commit()
logger.info("SAML provider deleted: {}", provider_id)
ui.notify("SAML provider deleted")
await refresh_saml_table()
async def refresh_saml_table():
async with async_session() as session:
c = await session.get(Configuration, config.id)
providers = c.saml_identity_providers or []
saml_table.rows = [
{
"id": p.get("id", ""),
"label": p.get("label", ""),
"metadata": (p.get("metadata", ""))[:40] + "..." if len(p.get("metadata", "")) > 40 else p.get("metadata", ""),
"auto_create": "Yes" if p.get("auto_create_users") else "No",
}
for p in providers
]
saml_table.update()
# --- OIDC provider dialog ---
with ui.dialog() as oidc_dialog:
with ui.card().classes("w-[500px]"):
ui.label("OIDC Provider").classes("text-h6")
oidc_id = ui.input("Config ID", placeholder="google").props("outlined dense").classes("w-full")
oidc_label = ui.input("Label", placeholder="Sign in with Google").props("outlined dense").classes("w-full")
oidc_scope = ui.input("Scope", value="openid email profile").props("outlined dense").classes("w-full")
oidc_client_id = ui.input("Client ID").props("outlined dense").classes("w-full")
oidc_client_secret = ui.input("Client Secret", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
oidc_discovery = ui.input("Discovery Document URI", placeholder="https://accounts.google.com/.well-known/openid-configuration").props("outlined dense").classes("w-full")
oidc_auto_create = ui.switch("Auto-create users", value=False)
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=oidc_dialog.close).props("flat")
ui.button("Save", on_click=save_oidc_provider).props("color=primary")
# --- SAML provider dialog ---
with ui.dialog() as saml_dialog:
with ui.card().classes("w-[500px]"):
ui.label("SAML Identity Provider").classes("text-h6")
saml_id = ui.input("Config ID", placeholder="okta-saml").props("outlined dense").classes("w-full")
saml_label = ui.input("Label", placeholder="Sign in with Okta").props("outlined dense").classes("w-full")
saml_metadata_input = ui.textarea("IdP Metadata (XML)").props("outlined").classes("w-full").style("min-height: 120px")
ui.label("Paste the full XML metadata from your identity provider.").classes("text-caption text-grey-7")
ui.separator().classes("q-my-sm")
ui.label("Security Options").classes("text-subtitle2")
saml_sign_requests = ui.switch("Sign authentication requests", value=True)
saml_sign_metadata = ui.switch("Sign SP metadata", value=True)
saml_signed_assertion = ui.switch("Require signed assertions in response", value=True)
saml_signed_envelopes = ui.switch("Require signed envelopes in response", value=True)
ui.separator().classes("q-my-sm")
saml_auto_create = ui.switch("Auto-create users", value=False)
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=saml_dialog.close).props("flat")
ui.button("Save", on_click=save_saml_provider).props("color=primary")
await refresh_oidc_table()
await refresh_saml_table()

View file

@ -0,0 +1,236 @@
"""Admin user management page."""
from uuid import UUID
from loguru import logger
from nicegui import app, ui
from sqlalchemy.orm import selectinload
from sqlmodel import func, select
from wiregui.auth.passwords import hash_password
from wiregui.db import async_session
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.pages.layout import layout
from wiregui.services.events import on_device_deleted
from wiregui.utils.time import utcnow
def _guard():
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
ui.navigate.to("/login")
return False
return True
@ui.page("/admin/users")
async def users_page():
if not _guard():
return
layout()
async def load_users() -> list[dict]:
async with async_session() as session:
# Get users with device counts via subquery
device_count_sq = (
select(Device.user_id, func.count().label("device_count"))
.group_by(Device.user_id)
.subquery()
)
result = await session.execute(
select(User).order_by(User.email)
)
users = result.scalars().all()
# Get device counts separately
counts_result = await session.execute(
select(Device.user_id, func.count().label("cnt")).group_by(Device.user_id)
)
counts = {str(row[0]): row[1] for row in counts_result.all()}
return [
{
"id": str(u.id),
"email": u.email,
"role": u.role,
"devices": counts.get(str(u.id), 0),
"last_signed_in": str(u.last_signed_in_at or "-"),
"method": u.last_signed_in_method or "-",
"status": "Disabled" if u.disabled_at else "Active",
"created": str(u.inserted_at)[:19],
}
for u in users
]
async def refresh_table():
table.rows = await load_users()
table.update()
# --- Create user ---
async def create_user():
email = create_email.value.strip()
pwd = create_password.value
role = create_role.value
if not email or not pwd:
ui.notify("Email and password are required", type="negative")
return
async with async_session() as session:
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
if existing:
ui.notify(f"User {email} already exists", type="negative")
return
user = User(email=email, password_hash=hash_password(pwd), role=role)
session.add(user)
await session.commit()
logger.info("Admin created user: {} ({})", email, role)
ui.notify(f"User {email} created")
create_dialog.close()
create_email.value = ""
create_password.value = ""
create_role.value = "unprivileged"
await refresh_table()
# --- Edit user ---
edit_user_id = {"value": None}
async def open_edit(user_id: str):
async with async_session() as session:
user = await session.get(User, UUID(user_id))
if not user:
return
edit_user_id["value"] = user_id
edit_email.value = user.email
edit_role.value = user.role
edit_password.value = ""
edit_disabled.value = user.disabled_at is not None
edit_dialog.open()
async def save_edit():
uid = edit_user_id["value"]
if not uid:
return
async with async_session() as session:
user = await session.get(User, UUID(uid))
if not user:
return
user.email = edit_email.value.strip()
user.role = edit_role.value
if edit_password.value:
user.password_hash = hash_password(edit_password.value)
if edit_disabled.value and not user.disabled_at:
user.disabled_at = utcnow()
elif not edit_disabled.value and user.disabled_at:
user.disabled_at = None
session.add(user)
await session.commit()
logger.info("Admin updated user: {}", edit_email.value)
ui.notify("User updated")
edit_dialog.close()
await refresh_table()
# --- Delete user ---
async def delete_user(user_id: str):
current_user_id = app.storage.user.get("user_id")
if user_id == current_user_id:
ui.notify("Cannot delete your own account", type="negative")
return
async with async_session() as session:
user = await session.get(User, UUID(user_id))
if not user:
return
# Delete user's devices (and fire WG events)
devices_result = await session.execute(
select(Device).where(Device.user_id == user.id)
)
for device in devices_result.scalars().all():
await session.delete(device)
await on_device_deleted(device)
# Delete user's rules
await session.execute(
select(Rule).where(Rule.user_id == user.id)
)
rules_result = await session.execute(select(Rule).where(Rule.user_id == user.id))
for rule in rules_result.scalars().all():
await session.delete(rule)
await session.delete(user)
await session.commit()
logger.info("Admin deleted user: {}", user.email)
ui.notify(f"User {user.email} deleted")
await refresh_table()
def on_row_click(e):
open_edit(e.args["id"])
# --- Page content ---
with ui.column().classes("w-full p-4"):
with ui.row().classes("w-full items-center justify-between"):
ui.label("Users").classes("text-h5")
ui.button("Add User", icon="person_add", on_click=lambda: create_dialog.open()).props("color=primary")
columns = [
{"name": "email", "label": "Email", "field": "email", "align": "left", "sortable": True},
{"name": "role", "label": "Role", "field": "role", "align": "left", "sortable": True},
{"name": "devices", "label": "Devices", "field": "devices", "align": "center"},
{"name": "status", "label": "Status", "field": "status", "align": "left"},
{"name": "last_signed_in", "label": "Last Sign-in", "field": "last_signed_in", "align": "left"},
{"name": "method", "label": "Method", "field": "method", "align": "left"},
{"name": "created", "label": "Created", "field": "created", "align": "left"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
table.on("rowClick", on_row_click)
table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="edit" color="primary"
@click.stop="() => $parent.$emit('edit', props.row.id)" />
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
table.on("edit", lambda e: open_edit(e.args))
table.on("delete", lambda e: delete_user(e.args))
# --- Create dialog ---
with ui.dialog() as create_dialog:
with ui.card().classes("w-96"):
ui.label("New User").classes("text-h6")
create_email = ui.input("Email").props("outlined dense").classes("w-full")
create_password = ui.input("Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
create_role = ui.select(["unprivileged", "admin"], value="unprivileged", label="Role").props("outlined dense").classes("w-full")
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=create_dialog.close).props("flat")
ui.button("Create", on_click=create_user).props("color=primary")
# --- Edit dialog ---
with ui.dialog() as edit_dialog:
with ui.card().classes("w-96"):
ui.label("Edit User").classes("text-h6")
edit_email = ui.input("Email").props("outlined dense").classes("w-full")
edit_role = ui.select(["unprivileged", "admin"], value="unprivileged", label="Role").props("outlined dense").classes("w-full")
edit_password = ui.input("New Password (leave blank to keep)", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
edit_disabled = ui.switch("Disabled")
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=edit_dialog.close).props("flat")
ui.button("Save", on_click=save_edit).props("color=primary")
await refresh_table()

View file

@ -0,0 +1,91 @@
"""Magic link authentication — request and verify signed JWT email links."""
from datetime import timedelta
from uuid import UUID
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
from wiregui.auth.jwt import create_access_token, decode_access_token
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.user import User
from wiregui.services.email import send_magic_link
from wiregui.utils.time import utcnow
@ui.page("/auth/magic-link")
async def magic_link_request_page():
"""Page to request a magic link email."""
if app.storage.user.get("authenticated"):
return ui.navigate.to("/")
async def request_link():
email_val = email_input.value.strip()
if not email_val:
ui.notify("Enter your email", type="negative")
return
# Always show success to avoid user enumeration
ui.notify("If an account exists, a sign-in link has been sent.", type="positive")
async with async_session() as session:
result = await session.execute(select(User).where(User.email == email_val))
user = result.scalar_one_or_none()
if user and user.disabled_at is None:
settings = get_settings()
token = create_access_token(
user_id=str(user.id),
role=user.role,
expires_delta=timedelta(minutes=15),
)
link = f"{settings.external_url}/auth/magic/{user.id}/{token}"
await send_magic_link(email_val, link)
logger.info("Magic link sent to {}", email_val)
with ui.column().classes("absolute-center items-center"):
ui.label("WireGUI").classes("text-h4 text-bold")
ui.label("Sign in with magic link").classes("text-subtitle1 q-mb-md")
with ui.card().classes("w-80"):
email_input = ui.input("Email").props("outlined dense").classes("w-full")
ui.button("Send Magic Link", on_click=request_link).classes("w-full q-mt-sm")
email_input.on("keydown.enter", request_link)
ui.button("Back to login", on_click=lambda: ui.navigate.to("/login")).props("flat").classes("q-mt-md")
@ui.page("/auth/magic/{user_id}/{token}")
async def magic_link_verify_page(user_id: str, token: str):
"""Verify a magic link token and sign the user in."""
payload = decode_access_token(token)
if not payload or payload.get("sub") != user_id:
with ui.column().classes("absolute-center items-center"):
ui.label("Invalid or expired link").classes("text-h5 text-negative")
ui.button("Back to login", on_click=lambda: ui.navigate.to("/login")).props("flat")
return
async with async_session() as session:
user = await session.get(User, UUID(user_id))
if not user or user.disabled_at is not None:
with ui.column().classes("absolute-center items-center"):
ui.label("Account not found or disabled").classes("text-h5 text-negative")
ui.button("Back to login", on_click=lambda: ui.navigate.to("/login")).props("flat")
return
user.last_signed_in_at = utcnow()
user.last_signed_in_method = "magic_link"
session.add(user)
await session.commit()
logger.info("Magic link login: {}", user.email)
app.storage.user.update(
authenticated=True,
user_id=str(user.id),
email=user.email,
role=user.role,
)
ui.navigate.to("/")

120
wiregui/pages/auth_oidc.py Normal file
View file

@ -0,0 +1,120 @@
"""OIDC authentication routes — redirect to provider and handle callback."""
from loguru import logger
from nicegui import app
from fastapi import Request
from fastapi.responses import RedirectResponse
from wiregui.auth.oidc import get_client, get_provider_config
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.utils.time import utcnow
from sqlmodel import select
@app.get("/auth/oidc/{provider_id}")
async def oidc_redirect(provider_id: str, request: Request):
"""Redirect user to the OIDC provider's authorization endpoint."""
try:
client = get_client(provider_id)
except ValueError:
return RedirectResponse(url="/login")
settings = get_settings()
redirect_uri = f"{settings.external_url}/auth/oidc/{provider_id}/callback"
return await client.authorize_redirect(request, redirect_uri)
@app.get("/auth/oidc/{provider_id}/callback")
async def oidc_callback(provider_id: str, request: Request):
"""Handle the OIDC provider callback — exchange code for tokens and create session."""
try:
client = get_client(provider_id)
except ValueError:
return RedirectResponse(url="/login")
try:
token = await client.authorize_access_token(request)
except Exception as e:
logger.error("OIDC token exchange failed for {}: {}", provider_id, e)
return RedirectResponse(url="/login")
userinfo = token.get("userinfo")
if not userinfo:
try:
userinfo = await client.userinfo()
except Exception as e:
logger.error("OIDC userinfo failed for {}: {}", provider_id, e)
return RedirectResponse(url="/login")
email = userinfo.get("email")
if not email:
logger.error("OIDC provider {} did not return email", provider_id)
return RedirectResponse(url="/login")
provider_config = await get_provider_config(provider_id)
auto_create = provider_config.get("auto_create_users", False) if provider_config else False
async with async_session() as session:
# Find or create user
result = await session.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if user is None:
if not auto_create:
logger.warning("OIDC: user {} not found and auto-create disabled for {}", email, provider_id)
return RedirectResponse(url="/login")
user = User(email=email, role="unprivileged")
session.add(user)
await session.flush()
logger.info("OIDC: auto-created user {} via {}", email, provider_id)
if user.disabled_at is not None:
logger.warning("OIDC: disabled user {} attempted login via {}", email, provider_id)
return RedirectResponse(url="/login")
# Update sign-in tracking
user.last_signed_in_at = utcnow()
user.last_signed_in_method = f"oidc:{provider_id}"
session.add(user)
# Store/update OIDC connection with refresh token
refresh_token = token.get("refresh_token")
existing_conn = (await session.execute(
select(OIDCConnection).where(
OIDCConnection.user_id == user.id,
OIDCConnection.provider == provider_id,
)
)).scalar_one_or_none()
if existing_conn:
existing_conn.refresh_token = refresh_token
existing_conn.refreshed_at = utcnow()
existing_conn.refresh_response = dict(token)
session.add(existing_conn)
else:
conn = OIDCConnection(
provider=provider_id,
refresh_token=refresh_token,
refresh_response=dict(token),
refreshed_at=utcnow(),
user_id=user.id,
)
session.add(conn)
await session.commit()
logger.info("OIDC login: {} via {}", email, provider_id)
# Set NiceGUI session — store in Starlette session since we're in a plain route
request.session["authenticated"] = True
request.session["user_id"] = str(user.id)
request.session["email"] = user.email
request.session["role"] = user.role
return RedirectResponse(url="/")

129
wiregui/pages/auth_saml.py Normal file
View file

@ -0,0 +1,129 @@
"""SAML authentication routes — SP-initiated SSO redirect and ACS callback."""
from urllib.parse import urlparse
from fastapi import Request
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from loguru import logger
from nicegui import app
from sqlmodel import select
from wiregui.auth.saml import create_saml_auth, get_login_url, get_metadata, process_response
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.models.user import User
from wiregui.utils.time import utcnow
async def _get_saml_provider(provider_id: str) -> dict | None:
async with async_session() as session:
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
if not config:
return None
for p in config.saml_identity_providers or []:
if p.get("id") == provider_id:
return p
return None
def _request_data_from_fastapi(request: Request) -> dict:
settings = get_settings()
parsed = urlparse(settings.external_url)
return {
"http_host": parsed.hostname,
"script_name": "",
"server_port": parsed.port or (443 if parsed.scheme == "https" else 80),
"get_data": dict(request.query_params),
"post_data": {},
"https": "on" if parsed.scheme == "https" else "off",
}
@app.get("/auth/saml/{provider_id}")
async def saml_redirect(provider_id: str, request: Request):
"""Redirect user to the SAML IdP."""
provider = await _get_saml_provider(provider_id)
if not provider:
return RedirectResponse(url="/login")
try:
req_data = _request_data_from_fastapi(request)
auth = create_saml_auth(provider, req_data)
login_url = get_login_url(auth)
return RedirectResponse(url=login_url)
except Exception as e:
logger.error("SAML redirect failed for {}: {}", provider_id, e)
return RedirectResponse(url="/login")
@app.post("/auth/saml/{provider_id}/callback")
async def saml_callback(provider_id: str, request: Request):
"""Handle the SAML ACS callback (POST with SAMLResponse)."""
provider = await _get_saml_provider(provider_id)
if not provider:
return RedirectResponse(url="/login")
try:
form_data = await request.form()
req_data = _request_data_from_fastapi(request)
req_data["post_data"] = dict(form_data)
auth = create_saml_auth(provider, req_data)
user_data = process_response(auth)
if not user_data or not user_data.get("email"):
logger.warning("SAML callback: no valid user data from {}", provider_id)
return RedirectResponse(url="/login")
email = user_data["email"]
auto_create = provider.get("auto_create_users", False)
async with async_session() as session:
result = await session.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if user is None:
if not auto_create:
logger.warning("SAML: user {} not found, auto-create disabled for {}", email, provider_id)
return RedirectResponse(url="/login")
user = User(email=email, role="unprivileged")
session.add(user)
await session.flush()
logger.info("SAML: auto-created user {} via {}", email, provider_id)
if user.disabled_at is not None:
logger.warning("SAML: disabled user {} attempted login via {}", email, provider_id)
return RedirectResponse(url="/login")
user.last_signed_in_at = utcnow()
user.last_signed_in_method = f"saml:{provider_id}"
session.add(user)
await session.commit()
request.session["authenticated"] = True
request.session["user_id"] = str(user.id)
request.session["email"] = user.email
request.session["role"] = user.role
logger.info("SAML login: {} via {}", email, provider_id)
return RedirectResponse(url="/", status_code=303)
except Exception as e:
logger.error("SAML callback failed for {}: {}", provider_id, e)
return RedirectResponse(url="/login")
@app.get("/auth/saml/{provider_id}/metadata")
async def saml_metadata(provider_id: str):
"""Return SP metadata XML for the SAML provider."""
provider = await _get_saml_provider(provider_id)
if not provider:
return Response(status_code=404)
try:
metadata_xml = get_metadata(provider)
return Response(content=metadata_xml, media_type="application/xml")
except Exception as e:
logger.error("SAML metadata generation failed for {}: {}", provider_id, e)
return Response(status_code=500)

463
wiregui/pages/devices.py Normal file
View file

@ -0,0 +1,463 @@
"""User-facing device management pages."""
import io
from uuid import UUID
import qrcode
import qrcode.image.svg
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.device import Device
from wiregui.pages.layout import layout
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
from wiregui.utils.server_key import get_server_public_key
from wiregui.utils.wg_conf import build_client_config
def _format_bytes(b: int | None) -> str:
if b is None:
return "-"
for unit in ("B", "KB", "MB", "GB", "TB"):
if b < 1024:
return f"{b:.1f} {unit}"
b /= 1024
return f"{b:.1f} PB"
@ui.page("/devices")
async def devices_page():
if not app.storage.user.get("authenticated"):
return ui.navigate.to("/login")
layout()
user_id = UUID(app.storage.user["user_id"])
async def load_devices() -> list[Device]:
async with async_session() as session:
result = await session.execute(
select(Device).where(Device.user_id == user_id).order_by(Device.inserted_at.desc())
)
return list(result.scalars().all())
async def refresh_table():
devices = await load_devices()
table.rows = [
{
"id": str(d.id),
"name": d.name,
"description": d.description or "",
"ipv4": d.ipv4 or "-",
"ipv6": d.ipv6 or "-",
"public_key": d.public_key[:16] + "...",
"rx": _format_bytes(d.rx_bytes),
"tx": _format_bytes(d.tx_bytes),
"handshake": str(d.latest_handshake)[:19] if d.latest_handshake else "-",
}
for d in devices
]
table.update()
# --- Create device ---
async def create_device():
name = create_name.value.strip()
if not name:
ui.notify("Device name is required", type="negative")
return
try:
settings = get_settings()
private_key, public_key = generate_keypair()
psk = generate_preshared_key()
async with async_session() as session:
ipv4 = await allocate_ipv4(session, settings.wg_ipv4_network)
ipv6 = await allocate_ipv6(session, settings.wg_ipv6_network)
device = Device(
name=name,
description=create_desc.value.strip() or None,
public_key=public_key,
preshared_key=psk,
ipv4=ipv4,
ipv6=ipv6,
user_id=user_id,
use_default_allowed_ips=create_use_default_ips.value,
use_default_dns=create_use_default_dns.value,
use_default_endpoint=create_use_default_endpoint.value,
use_default_mtu=create_use_default_mtu.value,
use_default_persistent_keepalive=create_use_default_keepalive.value,
endpoint=(create_endpoint.value.strip() or None
if not create_use_default_endpoint.value else None),
dns=([s.strip() for s in create_dns.value.split(",") if s.strip()]
if not create_use_default_dns.value and create_dns.value else []),
mtu=(int(create_mtu.value)
if not create_use_default_mtu.value and create_mtu.value else None),
persistent_keepalive=(int(create_keepalive.value)
if not create_use_default_keepalive.value and create_keepalive.value else None),
allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()]
if not create_use_default_ips.value and create_allowed_ips.value else []),
)
session.add(device)
await session.commit()
await session.refresh(device)
logger.info("Device created: {} ({})", device.name, device.ipv4)
await on_device_created(device)
server_pubkey = await get_server_public_key()
config_text = build_client_config(device, private_key, server_pubkey)
_show_config_dialog(device.name, config_text)
create_dialog.close()
_reset_create_form()
await refresh_table()
except Exception as e:
logger.error("Failed to create device: {}", e)
ui.notify(f"Error: {e}", type="negative")
def _reset_create_form():
create_name.value = ""
create_desc.value = ""
create_use_default_ips.value = True
create_use_default_dns.value = True
create_use_default_endpoint.value = True
create_use_default_mtu.value = True
create_use_default_keepalive.value = True
create_endpoint.value = ""
create_dns.value = ""
create_mtu.value = ""
create_keepalive.value = ""
create_allowed_ips.value = ""
# --- Delete device ---
async def delete_device(device_id: str):
async with async_session() as session:
device = await session.get(Device, UUID(device_id))
if device and device.user_id == user_id:
await session.delete(device)
await session.commit()
logger.info("Device deleted: {}", device.name)
await on_device_deleted(device)
ui.notify(f"Deleted {device.name}")
await refresh_table()
def on_row_click(e):
ui.navigate.to(f"/devices/{e.args['id']}")
# --- Page content ---
with ui.column().classes("w-full p-4"):
with ui.row().classes("w-full items-center justify-between"):
ui.label("My Devices").classes("text-h5")
ui.button("Add Device", icon="add", on_click=lambda: create_dialog.open()).props("color=primary")
columns = [
{"name": "name", "label": "Name", "field": "name", "align": "left", "sortable": True},
{"name": "ipv4", "label": "IPv4", "field": "ipv4", "align": "left"},
{"name": "ipv6", "label": "IPv6", "field": "ipv6", "align": "left"},
{"name": "public_key", "label": "Public Key", "field": "public_key", "align": "left"},
{"name": "rx", "label": "RX", "field": "rx", "align": "right"},
{"name": "tx", "label": "TX", "field": "tx", "align": "right"},
{"name": "handshake", "label": "Last Handshake", "field": "handshake", "align": "left"},
{"name": "actions", "label": "", "field": "id", "align": "center"},
]
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
table.on("rowClick", on_row_click)
table.add_slot(
"body-cell-actions",
'''
<q-td :props="props">
<q-btn flat dense icon="delete" color="negative"
@click.stop="() => $parent.$emit('delete', props.row.id)" />
</q-td>
''',
)
table.on("delete", lambda e: delete_device(e.args))
# --- Create device dialog (full form) ---
with ui.dialog() as create_dialog:
with ui.card().classes("w-[600px]"):
ui.label("New Device").classes("text-h6")
create_name = ui.input("Device Name").props("outlined dense").classes("w-full")
create_desc = ui.input("Description (optional)").props("outlined dense").classes("w-full")
ui.separator().classes("q-my-sm")
ui.label("Configuration Overrides").classes("text-subtitle2")
ui.label("Toggle off to set custom values instead of server defaults.").classes("text-caption text-grey-7")
with ui.grid(columns=2).classes("w-full gap-2"):
create_use_default_ips = ui.switch("Use default Allowed IPs", value=True)
create_allowed_ips = ui.input("Allowed IPs", placeholder="0.0.0.0/0, ::/0").props(
"outlined dense"
).classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v)
create_use_default_dns = ui.switch("Use default DNS", value=True)
create_dns = ui.input("DNS Servers", placeholder="1.1.1.1, 1.0.0.1").props(
"outlined dense"
).classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v)
create_use_default_endpoint = ui.switch("Use default Endpoint", value=True)
create_endpoint = ui.input("Endpoint", placeholder="vpn.example.com").props(
"outlined dense"
).classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v)
create_use_default_mtu = ui.switch("Use default MTU", value=True)
create_mtu = ui.input("MTU", placeholder="1280").props(
"outlined dense"
).classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v)
create_use_default_keepalive = ui.switch("Use default Keepalive", value=True)
create_keepalive = ui.input("Persistent Keepalive", placeholder="25").props(
"outlined dense"
).classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v)
with ui.row().classes("w-full justify-end q-mt-md"):
ui.button("Cancel", on_click=create_dialog.close).props("flat")
ui.button("Create", on_click=create_device).props("color=primary")
await refresh_table()
# Auto-refresh stats every 30 seconds
ui.timer(30, refresh_table)
@ui.page("/devices/{device_id}")
async def device_detail_page(device_id: str):
if not app.storage.user.get("authenticated"):
return ui.navigate.to("/login")
layout()
user_id = UUID(app.storage.user["user_id"])
async with async_session() as sess:
device = await sess.get(Device, UUID(device_id))
if not device or device.user_id != user_id:
ui.label("Device not found").classes("text-h5 text-negative p-4")
return
# --- Edit handlers ---
async def save_edit():
async with async_session() as session:
d = await session.get(Device, UUID(device_id))
if not d:
return
d.name = edit_name.value.strip()
d.description = edit_desc.value.strip() or None
d.use_default_allowed_ips = edit_use_default_ips.value
d.use_default_dns = edit_use_default_dns.value
d.use_default_endpoint = edit_use_default_endpoint.value
d.use_default_mtu = edit_use_default_mtu.value
d.use_default_persistent_keepalive = edit_use_default_keepalive.value
if not d.use_default_endpoint:
d.endpoint = edit_endpoint.value.strip() or None
if not d.use_default_dns:
d.dns = [s.strip() for s in edit_dns.value.split(",") if s.strip()]
if not d.use_default_mtu:
d.mtu = int(edit_mtu.value) if edit_mtu.value else None
if not d.use_default_persistent_keepalive:
d.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None
if not d.use_default_allowed_ips:
d.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()]
session.add(d)
await session.commit()
await session.refresh(d)
await on_device_updated(d)
logger.info("Device updated: {}", edit_name.value)
ui.notify("Device updated", type="positive")
ui.navigate.to(f"/devices/{device_id}")
async def delete_and_redirect():
async with async_session() as session:
d = await session.get(Device, UUID(device_id))
if d:
await session.delete(d)
await session.commit()
logger.info("Device deleted: {}", d.name)
await on_device_deleted(d)
ui.navigate.to("/devices")
settings = get_settings()
# --- Page content ---
with ui.column().classes("w-full p-4"):
with ui.row().classes("items-center gap-2"):
ui.button(icon="arrow_back", on_click=lambda: ui.navigate.to("/devices")).props("flat")
ui.label(device.name).classes("text-h5")
if device.description:
ui.label(f"{device.description}").classes("text-subtitle1 text-grey-7")
# Device info card
with ui.card().classes("w-full q-mt-md"):
ui.label("Device Details").classes("text-subtitle1 text-bold")
ui.separator()
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
ui.label("Public Key:").classes("text-bold")
ui.label(device.public_key).classes("font-mono text-sm")
ui.label("IPv4:").classes("text-bold")
ui.label(device.ipv4 or "-")
ui.label("IPv6:").classes("text-bold")
ui.label(device.ipv6 or "-")
ui.label("Created:").classes("text-bold")
ui.label(str(device.inserted_at)[:19])
# Traffic stats (live-updating)
with ui.card().classes("w-full q-mt-md"):
ui.label("Traffic Stats").classes("text-subtitle1 text-bold")
ui.label("Auto-refreshes every 30s").classes("text-caption text-grey-7")
ui.separator()
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
ui.label("RX:").classes("text-bold")
stat_rx = ui.label(_format_bytes(device.rx_bytes))
ui.label("TX:").classes("text-bold")
stat_tx = ui.label(_format_bytes(device.tx_bytes))
ui.label("Last Handshake:").classes("text-bold")
stat_handshake = ui.label(str(device.latest_handshake)[:19] if device.latest_handshake else "-")
ui.label("Remote IP:").classes("text-bold")
stat_remote = ui.label(device.remote_ip or "-")
async def refresh_stats():
async with async_session() as session:
d = await session.get(Device, UUID(device_id))
if not d:
return
stat_rx.text = _format_bytes(d.rx_bytes)
stat_tx.text = _format_bytes(d.tx_bytes)
stat_handshake.text = str(d.latest_handshake)[:19] if d.latest_handshake else "-"
stat_remote.text = d.remote_ip or "-"
ui.timer(30, refresh_stats)
# Active configuration
with ui.card().classes("w-full q-mt-md"):
ui.label("Active Configuration").classes("text-subtitle1 text-bold")
ui.separator()
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
_ips = device.allowed_ips if not device.use_default_allowed_ips else settings.wg_allowed_ips
ui.label("Allowed IPs:").classes("text-bold")
ui.label(str(_ips) if isinstance(_ips, str) else ", ".join(_ips) if _ips else "-")
_dns = device.dns if not device.use_default_dns else settings.wg_dns
ui.label("DNS:").classes("text-bold")
ui.label(str(_dns) if isinstance(_dns, str) else ", ".join(_dns) if _dns else "-")
_ep = device.endpoint if not device.use_default_endpoint else settings.wg_endpoint_host
ui.label("Endpoint:").classes("text-bold")
ui.label(f"{_ep}:{settings.wg_endpoint_port}" if _ep else "-")
_mtu = device.mtu if not device.use_default_mtu else settings.wg_mtu
ui.label("MTU:").classes("text-bold")
ui.label(str(_mtu) if _mtu else "-")
_ka = device.persistent_keepalive if not device.use_default_persistent_keepalive else settings.wg_persistent_keepalive
ui.label("Persistent Keepalive:").classes("text-bold")
ui.label(str(_ka) if _ka else "-")
# Edit form
with ui.card().classes("w-full q-mt-md"):
ui.label("Edit Device").classes("text-subtitle1 text-bold")
ui.separator()
edit_name = ui.input("Device Name", value=device.name).props("outlined dense").classes("w-full")
edit_desc = ui.input("Description", value=device.description or "").props("outlined dense").classes("w-full")
ui.separator().classes("q-my-sm")
ui.label("Configuration Overrides").classes("text-subtitle2")
with ui.grid(columns=2).classes("w-full gap-2"):
edit_use_default_ips = ui.switch("Use default Allowed IPs", value=device.use_default_allowed_ips)
edit_allowed_ips = ui.input(
"Allowed IPs", value=", ".join(device.allowed_ips) if device.allowed_ips else "",
).props("outlined dense").classes("w-full").bind_enabled_from(
edit_use_default_ips, "value", backward=lambda v: not v
)
edit_use_default_dns = ui.switch("Use default DNS", value=device.use_default_dns)
edit_dns = ui.input(
"DNS Servers", value=", ".join(device.dns) if device.dns else "",
).props("outlined dense").classes("w-full").bind_enabled_from(
edit_use_default_dns, "value", backward=lambda v: not v
)
edit_use_default_endpoint = ui.switch("Use default Endpoint", value=device.use_default_endpoint)
edit_endpoint = ui.input(
"Endpoint", value=device.endpoint or "",
).props("outlined dense").classes("w-full").bind_enabled_from(
edit_use_default_endpoint, "value", backward=lambda v: not v
)
edit_use_default_mtu = ui.switch("Use default MTU", value=device.use_default_mtu)
edit_mtu = ui.input(
"MTU", value=str(device.mtu) if device.mtu else "",
).props("outlined dense").classes("w-full").bind_enabled_from(
edit_use_default_mtu, "value", backward=lambda v: not v
)
edit_use_default_keepalive = ui.switch("Use default Keepalive", value=device.use_default_persistent_keepalive)
edit_keepalive = ui.input(
"Persistent Keepalive", value=str(device.persistent_keepalive) if device.persistent_keepalive else "",
).props("outlined dense").classes("w-full").bind_enabled_from(
edit_use_default_keepalive, "value", backward=lambda v: not v
)
ui.button("Save Changes", on_click=save_edit).props("color=primary").classes("q-mt-md")
# Danger zone
with ui.card().classes("w-full q-mt-md"):
ui.label("Danger Zone").classes("text-subtitle1 text-bold text-negative")
ui.separator()
ui.button("Delete Device", icon="delete", on_click=lambda: confirm_dialog.open()).props(
"color=negative outline"
)
# Confirm delete dialog
with ui.dialog() as confirm_dialog:
with ui.card().classes("w-80"):
ui.label("Delete Device?").classes("text-h6")
ui.label(f"This will permanently remove '{device.name}' and its WireGuard peer.").classes("text-body2")
with ui.row().classes("w-full justify-end q-mt-sm"):
ui.button("Cancel", on_click=confirm_dialog.close).props("flat")
ui.button("Delete", on_click=delete_and_redirect).props("color=negative")
def _show_config_dialog(device_name: str, config_text: str):
"""Show a dialog with the WireGuard client configuration and QR code."""
with ui.dialog(value=True) as dialog:
with ui.card().classes("w-96"):
ui.label(f"Config for {device_name}").classes("text-h6")
ui.label("Save this — the private key won't be shown again.").classes("text-caption text-negative")
ui.textarea(value=config_text).props("readonly outlined").classes(
"w-full font-mono text-xs q-mt-sm"
).style("min-height: 200px")
try:
qr = qrcode.make(config_text, image_factory=qrcode.image.svg.SvgPathImage)
buf = io.BytesIO()
qr.save(buf)
ui.html(buf.getvalue().decode()).classes("w-full q-mt-sm")
except Exception:
ui.label("QR code generation failed").classes("text-caption text-grey")
ui.button(
"Download .conf",
on_click=lambda: ui.download(config_text.encode(), f"{device_name}.conf"),
).props("color=primary outline").classes("w-full q-mt-sm")
ui.button("Close", on_click=dialog.close).props("flat").classes("w-full")

9
wiregui/pages/home.py Normal file
View file

@ -0,0 +1,9 @@
from nicegui import app, ui
@ui.page("/")
def home_page():
if not app.storage.user.get("authenticated"):
return ui.navigate.to("/login")
# Redirect to devices as the main landing page
return ui.navigate.to("/devices")

48
wiregui/pages/layout.py Normal file
View file

@ -0,0 +1,48 @@
"""Shared layout — sidebar navigation + header."""
from nicegui import app, ui
from wiregui.services import notifications
def layout(title: str = "WireGUI"):
"""Render the shared app chrome (header + sidebar). Call at the top of each page."""
user_email = app.storage.user.get("email", "")
role = app.storage.user.get("role", "")
def logout():
app.storage.user.clear()
ui.navigate.to("/login")
# Header
with ui.header().classes("items-center justify-between"):
with ui.row().classes("items-center"):
ui.button(icon="menu", on_click=lambda: drawer.toggle()).props("flat color=white")
ui.label("WireGUI").classes("text-h6")
with ui.row().classes("items-center"):
if role == "admin":
notif_count = notifications.count()
with ui.button(
icon="notifications",
on_click=lambda: ui.navigate.to("/admin/diagnostics"),
).props("flat color=white"):
if notif_count > 0:
ui.badge(str(notif_count), color="red").props("floating")
ui.label(f"{user_email}").classes("text-subtitle2")
ui.button("Logout", on_click=logout).props("flat color=white")
# Sidebar
with ui.left_drawer(value=True, bordered=True).classes("bg-grey-1") as drawer:
ui.label("Navigation").classes("text-subtitle2 q-pa-sm text-grey-7")
ui.separator()
ui.item("Devices", on_click=lambda: ui.navigate.to("/devices")).classes("cursor-pointer")
ui.item("Account", on_click=lambda: ui.navigate.to("/account")).classes("cursor-pointer")
if role == "admin":
ui.separator()
ui.label("Admin").classes("text-subtitle2 q-pa-sm text-grey-7")
ui.item("Users", on_click=lambda: ui.navigate.to("/admin/users")).classes("cursor-pointer")
ui.item("All Devices", on_click=lambda: ui.navigate.to("/admin/devices")).classes("cursor-pointer")
ui.item("Rules", on_click=lambda: ui.navigate.to("/admin/rules")).classes("cursor-pointer")
ui.item("Settings", on_click=lambda: ui.navigate.to("/admin/settings")).classes("cursor-pointer")
ui.item("Diagnostics", on_click=lambda: ui.navigate.to("/admin/diagnostics")).classes("cursor-pointer")

82
wiregui/pages/login.py Normal file
View file

@ -0,0 +1,82 @@
"""Login page — email/password, MFA redirect, OIDC provider buttons."""
from nicegui import app, ui
from sqlmodel import select
from wiregui.auth.oidc import load_providers
from wiregui.auth.session import authenticate_user
from wiregui.db import async_session
from wiregui.models.mfa_method import MFAMethod
from wiregui.utils.time import utcnow
@ui.page("/login")
async def login_page():
if app.storage.user.get("authenticated"):
return ui.navigate.to("/")
# Load OIDC providers for SSO buttons
oidc_providers = await load_providers()
async def try_login():
user = await authenticate_user(email.value, password.value)
if user is None:
ui.notify("Invalid email or password", type="negative")
return
# Check if user has MFA methods
async with async_session() as session:
result = await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user.id)
)
mfa_methods = result.scalars().all()
# Update sign-in tracking
user_record = await session.get(type(user), user.id)
user_record.last_signed_in_at = utcnow()
user_record.last_signed_in_method = "local"
session.add(user_record)
await session.commit()
if mfa_methods:
# Store pending auth and redirect to MFA challenge
app.storage.user["pending_mfa"] = {
"user_id": str(user.id),
"email": user.email,
"role": user.role,
}
ui.navigate.to("/mfa")
else:
# No MFA — complete login directly
app.storage.user.update(
authenticated=True,
user_id=str(user.id),
email=user.email,
role=user.role,
)
ui.navigate.to("/")
with ui.column().classes("absolute-center items-center"):
ui.label("WireGUI").classes("text-h4 text-bold")
ui.label("Sign in to your account").classes("text-subtitle1 q-mb-md")
with ui.card().classes("w-80"):
email = ui.input("Email").props("outlined dense").classes("w-full")
password = ui.input("Password", password=True, password_toggle_button=True).props(
"outlined dense"
).classes("w-full")
ui.button("Sign in", on_click=try_login).classes("w-full q-mt-sm")
password.on("keydown.enter", try_login)
# OIDC provider buttons
if oidc_providers:
ui.separator().classes("q-my-md")
ui.label("Or sign in with").classes("text-caption text-center w-full")
for provider in oidc_providers:
pid = provider.get("id", "")
label = provider.get("label", pid)
ui.button(
label,
on_click=lambda p=pid: ui.navigate.to(f"/auth/oidc/{p}"),
).props("outline").classes("w-full q-mt-xs")

View file

@ -0,0 +1,93 @@
"""MFA challenge page — presented after password login when user has MFA enabled."""
from uuid import UUID
from loguru import logger
from nicegui import app, ui
from sqlmodel import select
from wiregui.auth.mfa import verify_totp_code
from wiregui.db import async_session
from wiregui.models.mfa_method import MFAMethod
@ui.page("/mfa")
async def mfa_challenge_page():
# Must have passed password auth (pending_mfa set by login page)
pending = app.storage.user.get("pending_mfa")
if not pending:
return ui.navigate.to("/login")
user_id = UUID(pending["user_id"])
# Load user's MFA methods
async with async_session() as session:
result = await session.execute(
select(MFAMethod).where(MFAMethod.user_id == user_id)
)
methods = result.scalars().all()
totp_methods = [m for m in methods if m.type == "totp"]
if not totp_methods:
# No MFA methods — shouldn't be here, complete login
_complete_login(pending)
return
async def verify_code():
code = code_input.value.strip()
if not code:
ui.notify("Enter your authentication code", type="negative")
return
for method in totp_methods:
secret = method.payload.get("secret")
if secret and verify_totp_code(secret, code):
# Update last used
async with async_session() as session:
m = await session.get(MFAMethod, method.id)
if m:
from wiregui.utils.time import utcnow
m.last_used_at = utcnow()
session.add(m)
await session.commit()
logger.info("MFA verified for user {}", pending["email"])
_complete_login(pending)
return
ui.notify("Invalid code", type="negative")
with ui.column().classes("absolute-center items-center"):
ui.label("Two-Factor Authentication").classes("text-h5")
ui.label(f"Enter the code from your authenticator app for {pending['email']}").classes(
"text-subtitle1 q-mb-md"
)
with ui.card().classes("w-80"):
code_input = ui.input("Authentication Code").props(
"outlined dense maxlength=6"
).classes("w-full text-center font-mono text-lg")
ui.button("Verify", on_click=verify_code).classes("w-full q-mt-sm")
code_input.on("keydown.enter", verify_code)
ui.button("Cancel", on_click=lambda: _cancel_mfa()).props("flat").classes("q-mt-md")
def _complete_login(pending: dict):
"""Complete the login by setting full auth state."""
app.storage.user.update(
authenticated=True,
user_id=pending["user_id"],
email=pending["email"],
role=pending["role"],
)
# Clear pending state
app.storage.user.pop("pending_mfa", None)
ui.navigate.to("/")
def _cancel_mfa():
app.storage.user.clear()
ui.navigate.to("/login")

9
wiregui/redis.py Normal file
View file

@ -0,0 +1,9 @@
import redis.asyncio as redis
from wiregui.config import get_settings
pool = redis.ConnectionPool.from_url(get_settings().redis_url)
def get_redis() -> redis.Redis:
return redis.Redis(connection_pool=pool)

View file

View file

@ -0,0 +1,37 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
class ConfigurationRead(BaseModel):
id: UUID
allow_unprivileged_device_management: bool
allow_unprivileged_device_configuration: bool
local_auth_enabled: bool
disable_vpn_on_oidc_error: bool
default_client_persistent_keepalive: int
default_client_mtu: int
default_client_endpoint: str | None
default_client_dns: list[str]
default_client_allowed_ips: list[str]
vpn_session_duration: int
logo_url: str | None
logo_type: str | None
inserted_at: datetime
updated_at: datetime
class ConfigurationUpdate(BaseModel):
allow_unprivileged_device_management: bool | None = None
allow_unprivileged_device_configuration: bool | None = None
local_auth_enabled: bool | None = None
disable_vpn_on_oidc_error: bool | None = None
default_client_persistent_keepalive: int | None = None
default_client_mtu: int | None = None
default_client_endpoint: str | None = None
default_client_dns: list[str] | None = None
default_client_allowed_ips: list[str] | None = None
vpn_session_duration: int | None = None
logo_url: str | None = None
logo_type: str | None = None

50
wiregui/schemas/device.py Normal file
View file

@ -0,0 +1,50 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
class DeviceRead(BaseModel):
id: UUID
name: str
description: str | None
public_key: str
ipv4: str | None
ipv6: str | None
use_default_allowed_ips: bool
use_default_dns: bool
use_default_endpoint: bool
use_default_mtu: bool
use_default_persistent_keepalive: bool
endpoint: str | None
mtu: int | None
persistent_keepalive: int | None
allowed_ips: list[str]
dns: list[str]
rx_bytes: int | None
tx_bytes: int | None
latest_handshake: datetime | None
user_id: UUID
inserted_at: datetime
updated_at: datetime
class DeviceCreate(BaseModel):
name: str
description: str | None = None
user_id: UUID | None = None # admin can assign to another user
class DeviceUpdate(BaseModel):
name: str | None = None
description: str | None = None
use_default_allowed_ips: bool | None = None
use_default_dns: bool | None = None
use_default_endpoint: bool | None = None
use_default_mtu: bool | None = None
use_default_persistent_keepalive: bool | None = None
endpoint: str | None = None
mtu: int | None = None
persistent_keepalive: int | None = None
allowed_ips: list[str] | None = None
dns: list[str] | None = None

31
wiregui/schemas/rule.py Normal file
View file

@ -0,0 +1,31 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
class RuleRead(BaseModel):
id: UUID
action: str
destination: str
port_type: str | None
port_range: str | None
user_id: UUID | None
inserted_at: datetime
updated_at: datetime
class RuleCreate(BaseModel):
action: str = "drop"
destination: str
port_type: str | None = None
port_range: str | None = None
user_id: UUID | None = None
class RuleUpdate(BaseModel):
action: str | None = None
destination: str | None = None
port_type: str | None = None
port_range: str | None = None
user_id: UUID | None = None

28
wiregui/schemas/user.py Normal file
View file

@ -0,0 +1,28 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, EmailStr
class UserRead(BaseModel):
id: UUID
email: str
role: str
disabled_at: datetime | None
last_signed_in_at: datetime | None
last_signed_in_method: str | None
inserted_at: datetime
updated_at: datetime
class UserCreate(BaseModel):
email: str
password: str
role: str = "unprivileged"
class UserUpdate(BaseModel):
email: str | None = None
password: str | None = None
role: str | None = None
disabled_at: datetime | None = None

View file

51
wiregui/services/email.py Normal file
View file

@ -0,0 +1,51 @@
"""Email sending via aiosmtplib for magic links and notifications."""
import aiosmtplib
from email.message import EmailMessage
from loguru import logger
from wiregui.config import get_settings
async def send_email(to: str, subject: str, body: str) -> bool:
"""Send an email via configured SMTP. Returns True on success."""
settings = get_settings()
if not settings.smtp_host:
logger.warning("SMTP not configured — email to {} not sent", to)
return False
msg = EmailMessage()
msg["From"] = settings.smtp_from
msg["To"] = to
msg["Subject"] = subject
msg.set_content(body)
try:
await aiosmtplib.send(
msg,
hostname=settings.smtp_host,
port=settings.smtp_port,
username=settings.smtp_user,
password=settings.smtp_password,
start_tls=True,
)
logger.info("Email sent to {}: {}", to, subject)
return True
except Exception as e:
logger.error("Failed to send email to {}: {}", to, e)
return False
async def send_magic_link(to: str, link: str) -> bool:
"""Send a magic link sign-in email."""
subject = "WireGUI — Sign in link"
body = f"""You requested a sign-in link for WireGUI.
Click here to sign in:
{link}
This link expires in 15 minutes. If you didn't request this, you can safely ignore this email.
"""
return await send_email(to, subject, body)

136
wiregui/services/events.py Normal file
View file

@ -0,0 +1,136 @@
"""Event bridge — propagates database changes to WireGuard and firewall."""
from uuid import UUID
from loguru import logger
from wiregui.config import get_settings
from wiregui.db import async_session
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.services import firewall, wireguard
def _device_allowed_ips(device: Device) -> list[str]:
"""Build the allowed-ips list for a device peer (its tunnel addresses)."""
ips = []
if device.ipv4:
ips.append(f"{device.ipv4}/32")
if device.ipv6:
ips.append(f"{device.ipv6}/128")
return ips
# --- Device events ---
async def on_device_created(device: Device) -> None:
"""Configure WireGuard peer and firewall after a new device is created."""
settings = get_settings()
if not settings.wg_enabled:
return
try:
await wireguard.add_peer(
public_key=device.public_key,
allowed_ips=_device_allowed_ips(device),
preshared_key=device.preshared_key,
)
except Exception as e:
logger.error("Failed to add WG peer for device {}: {}", device.name, e)
try:
# Ensure user chain exists before adding jump rules
await firewall.add_user_chain(str(device.user_id))
await firewall.add_device_jump_rule(
str(device.user_id), device.ipv4, device.ipv6,
)
except Exception as e:
logger.error("Failed to add firewall jump rule for device {}: {}", device.name, e)
async def on_device_deleted(device: Device) -> None:
"""Remove WireGuard peer after a device is deleted."""
if not get_settings().wg_enabled:
return
try:
await wireguard.remove_peer(public_key=device.public_key)
except Exception as e:
logger.error("Failed to remove WG peer for device {}: {}", device.name, e)
# Firewall jump rules are cleaned up on next rebuild
async def on_device_updated(device: Device) -> None:
"""Update WireGuard peer after a device is modified."""
if not get_settings().wg_enabled:
return
try:
await wireguard.add_peer(
public_key=device.public_key,
allowed_ips=_device_allowed_ips(device),
preshared_key=device.preshared_key,
)
except Exception as e:
logger.error("Failed to update WG peer for device {}: {}", device.name, e)
# --- Rule events ---
async def on_rule_created(rule: Rule) -> None:
"""Apply a new firewall rule."""
if not get_settings().wg_enabled:
return
if rule.user_id is None:
return # global rules handled via rebuild
try:
await firewall.apply_rule(
str(rule.user_id), rule.destination, rule.action,
rule.port_type, rule.port_range,
)
except Exception as e:
logger.error("Failed to apply firewall rule {}: {}", rule.id, e)
async def on_rule_updated(rule: Rule) -> None:
"""Firewall rule updated — rebuild the user's chain."""
if not get_settings().wg_enabled:
return
if rule.user_id is None:
return
await _rebuild_user_chain(str(rule.user_id))
async def on_rule_deleted(rule: Rule) -> None:
"""Firewall rule removed — rebuild the user's chain."""
if not get_settings().wg_enabled:
return
if rule.user_id is None:
return
await _rebuild_user_chain(str(rule.user_id))
async def _rebuild_user_chain(user_id: str) -> None:
"""Flush and rebuild a single user's firewall chain from current DB rules."""
try:
from sqlmodel import select as sel
async with async_session() as session:
rules = (await session.execute(
sel(Rule).where(Rule.user_id == UUID(user_id))
)).scalars().all()
devices = (await session.execute(
sel(Device).where(Device.user_id == UUID(user_id))
)).scalars().all()
await firewall.rebuild_all_rules([{
"user_id": user_id,
"devices": [{"ipv4": d.ipv4, "ipv6": d.ipv6} for d in devices],
"rules": [
{"destination": r.destination, "action": r.action,
"port_type": r.port_type, "port_range": r.port_range}
for r in rules
],
}])
except Exception as e:
logger.error("Failed to rebuild firewall chain for user {}: {}", user_id, e)

View file

@ -0,0 +1,191 @@
"""nftables firewall management — per-user chains and sets for device traffic filtering."""
import asyncio
import json
from loguru import logger
from wiregui.config import get_settings
TABLE_NAME = "wiregui"
async def _nft(cmd: str) -> str:
"""Run an nft command and return stdout."""
proc = await asyncio.create_subprocess_exec(
"nft", *cmd.split(),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise RuntimeError(f"nft {cmd} failed: {stderr.decode().strip()}")
return stdout.decode().strip()
async def _nft_batch(commands: list[str]) -> None:
"""Run multiple nft commands in a single atomic batch."""
batch = "\n".join(commands)
proc = await asyncio.create_subprocess_exec(
"nft", "-f", "-",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate(batch.encode())
if proc.returncode != 0:
raise RuntimeError(f"nft batch failed: {stderr.decode().strip()}")
async def setup_base_tables() -> None:
"""Create the base wiregui table with forward and postrouting chains."""
commands = [
f"add table inet {TABLE_NAME}",
# Forward chain for filtering device traffic
f"add chain inet {TABLE_NAME} forward {{ type filter hook forward priority 0; policy accept; }}",
# Postrouting for NAT/masquerade
f"add chain inet {TABLE_NAME} postrouting {{ type nat hook postrouting priority 100; policy accept; }}",
]
try:
await _nft_batch(commands)
logger.info("Base nftables table '{}' created", TABLE_NAME)
except RuntimeError as e:
# Table may already exist
if "File exists" not in str(e):
raise
logger.debug("Base nftables table '{}' already exists", TABLE_NAME)
async def setup_masquerade(iface: str | None = None) -> None:
"""Add masquerade rules for VPN traffic — NAT only traffic originating from WG subnets."""
settings = get_settings()
iface = iface or settings.wg_interface
v4_net = settings.wg_ipv4_network
v6_net = settings.wg_ipv6_network
commands = [
f"flush chain inet {TABLE_NAME} postrouting",
f'add rule inet {TABLE_NAME} postrouting ip saddr {v4_net} oifname != "{iface}" masquerade',
f'add rule inet {TABLE_NAME} postrouting ip6 saddr {v6_net} oifname != "{iface}" masquerade',
]
try:
await _nft_batch(commands)
logger.info("Masquerade rule added for {}", iface)
except RuntimeError as e:
logger.debug("Masquerade setup: {}", e)
async def add_user_chain(user_id: str) -> None:
"""Create a per-user chain for firewall rules."""
chain = _user_chain_name(user_id)
commands = [
f"add chain inet {TABLE_NAME} {chain}",
]
try:
await _nft_batch(commands)
logger.debug("User chain created: {}", chain)
except RuntimeError as e:
if "File exists" not in str(e):
raise
async def remove_user_chain(user_id: str) -> None:
"""Remove a per-user chain and all its rules."""
chain = _user_chain_name(user_id)
try:
await _nft_batch([
f"flush chain inet {TABLE_NAME} {chain}",
f"delete chain inet {TABLE_NAME} {chain}",
])
logger.debug("User chain removed: {}", chain)
except RuntimeError as e:
logger.debug("Remove user chain {}: {}", chain, e)
async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv6: str | None) -> None:
"""Add jump rules in the forward chain to route device traffic to the user chain."""
chain = _user_chain_name(user_id)
commands = []
if device_ipv4:
commands.append(
f"add rule inet {TABLE_NAME} forward ip saddr {device_ipv4} jump {chain}"
)
if device_ipv6:
commands.append(
f"add rule inet {TABLE_NAME} forward ip6 saddr {device_ipv6} jump {chain}"
)
if commands:
await _nft_batch(commands)
logger.debug("Jump rules added for device {}/{} -> {}", device_ipv4, device_ipv6, chain)
async def apply_rule(user_id: str, destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> None:
"""Add a filter rule to a user's chain."""
chain = _user_chain_name(user_id)
rule = _build_rule_expr(destination, action, port_type, port_range)
await _nft_batch([f"add rule inet {TABLE_NAME} {chain} {rule}"])
logger.debug("Rule applied in {}: {} -> {}", chain, destination, action)
async def rebuild_all_rules(users_devices_rules: list[dict]) -> None:
"""Full reconciliation: flush and rebuild all per-user chains from DB state.
Args:
users_devices_rules: list of dicts with keys:
user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range})
"""
commands = []
for entry in users_devices_rules:
user_id = entry["user_id"]
chain = _user_chain_name(user_id)
# Create/flush user chain
commands.append(f"add chain inet {TABLE_NAME} {chain}")
commands.append(f"flush chain inet {TABLE_NAME} {chain}")
# Add rules
for rule in entry.get("rules", []):
expr = _build_rule_expr(
rule["destination"], rule["action"],
rule.get("port_type"), rule.get("port_range"),
)
commands.append(f"add rule inet {TABLE_NAME} {chain} {expr}")
# Flush forward chain jump rules and re-add
commands.append(f"flush chain inet {TABLE_NAME} forward")
for entry in users_devices_rules:
user_id = entry["user_id"]
chain = _user_chain_name(user_id)
for dev in entry.get("devices", []):
if dev.get("ipv4"):
commands.append(f"add rule inet {TABLE_NAME} forward ip saddr {dev['ipv4']} jump {chain}")
if dev.get("ipv6"):
commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}")
if commands:
await _nft_batch(commands)
logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules))
def _user_chain_name(user_id: str) -> str:
"""Generate a deterministic chain name from a user ID."""
# Use first 12 chars of UUID (without hyphens) to keep names short
short = user_id.replace("-", "")[:12]
return f"user_{short}"
def _build_rule_expr(destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> str:
"""Build an nftables rule expression string."""
# Determine IP version from destination
if ":" in destination:
addr_match = f"ip6 daddr {destination}"
else:
addr_match = f"ip daddr {destination}"
parts = [addr_match]
if port_type and port_range:
parts.append(f"{port_type} dport {port_range}")
parts.append(action)
return " ".join(parts)

View file

@ -0,0 +1,65 @@
"""In-memory notification queue with severity levels."""
from collections import deque
from datetime import datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from wiregui.utils.time import utcnow
MAX_NOTIFICATIONS = 100
class Notification:
__slots__ = ("id", "severity", "message", "user", "timestamp")
def __init__(self, severity: str, message: str, user: str | None = None):
self.id = str(uuid4())
self.severity = severity # "info" | "warning" | "error"
self.message = message
self.user = user
self.timestamp = utcnow()
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"severity": self.severity,
"message": self.message,
"user": self.user,
"timestamp": self.timestamp.isoformat(),
}
_notifications: deque[Notification] = deque(maxlen=MAX_NOTIFICATIONS)
def add(severity: str, message: str, user: str | None = None) -> Notification:
"""Add a notification to the queue."""
n = Notification(severity, message, user)
_notifications.appendleft(n)
logger.debug("Notification added: [{}] {}", severity, message)
return n
def current() -> list[Notification]:
"""Return all current notifications (newest first)."""
return list(_notifications)
def clear(notification_id: str) -> None:
"""Remove a specific notification by ID."""
for i, n in enumerate(_notifications):
if n.id == notification_id:
del _notifications[i]
return
def clear_all() -> None:
"""Remove all notifications."""
_notifications.clear()
def count() -> int:
return len(_notifications)

View file

@ -0,0 +1,188 @@
"""WireGuard interface management via subprocess calls to `wg` and `ip`."""
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
from loguru import logger
from wiregui.config import get_settings
@dataclass
class PeerInfo:
public_key: str
endpoint: str | None = None
allowed_ips: list[str] = field(default_factory=list)
latest_handshake: datetime | None = None
rx_bytes: int = 0
tx_bytes: int = 0
async def _run(args: list[str], input_data: str | None = None) -> str:
"""Run a subprocess and return stdout. Raises on non-zero exit."""
proc = await asyncio.create_subprocess_exec(
*args,
stdin=asyncio.subprocess.PIPE if input_data else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate(input_data.encode() if input_data else None)
if proc.returncode != 0:
raise RuntimeError(f"{' '.join(args)} failed (rc={proc.returncode}): {stderr.decode().strip()}")
return stdout.decode().strip()
async def ensure_interface(iface: str | None = None) -> None:
"""Create WireGuard interface if it doesn't exist, assign server IPs and bring it up."""
settings = get_settings()
iface = iface or settings.wg_interface
# Check if interface exists
try:
await _run(["ip", "link", "show", iface])
logger.debug("Interface {} already exists", iface)
return
except RuntimeError:
pass
logger.info("Creating WireGuard interface {}", iface)
await _run(["ip", "link", "add", iface, "type", "wireguard"])
# Assign server IP (first host in each network)
from ipaddress import IPv4Network, IPv6Network
v4_net = IPv4Network(settings.wg_ipv4_network, strict=False)
v4_server = str(list(v4_net.hosts())[0])
await _run(["ip", "address", "add", f"{v4_server}/{v4_net.prefixlen}", "dev", iface])
v6_net = IPv6Network(settings.wg_ipv6_network, strict=False)
v6_server = str(list(v6_net.hosts())[0])
await _run(["ip", "address", "add", f"{v6_server}/{v6_net.prefixlen}", "dev", iface])
await _run(["ip", "link", "set", iface, "up"])
logger.info("Interface {} is up with {} and {}", iface, v4_server, v6_server)
async def configure_interface(iface: str | None = None) -> None:
"""Set the server private key and listen port on the WireGuard interface from DB config."""
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
settings = get_settings()
iface = iface or settings.wg_interface
async with async_session() as session:
result = await session.execute(select(Configuration).limit(1))
config = result.scalar_one_or_none()
if not config or not config.server_private_key:
logger.error("No server private key in Configuration — WG interface not configured")
return
# Write private key to a temp file (stdin piping has issues with uvloop)
import tempfile
import os
key_fd, key_path = tempfile.mkstemp()
try:
os.write(key_fd, config.server_private_key.encode())
os.close(key_fd)
os.chmod(key_path, 0o600)
await _run(["wg", "set", iface, "private-key", key_path, "listen-port", str(settings.wg_endpoint_port)])
finally:
os.unlink(key_path)
logger.info("WireGuard interface {} configured (listen-port={})", iface, settings.wg_endpoint_port)
async def set_private_key(private_key_path: str, iface: str | None = None) -> None:
"""Set the WireGuard private key from a file."""
settings = get_settings()
iface = iface or settings.wg_interface
await _run(["wg", "set", iface, "private-key", private_key_path])
async def set_listen_port(port: int, iface: str | None = None) -> None:
"""Set the WireGuard listen port."""
settings = get_settings()
iface = iface or settings.wg_interface
await _run(["wg", "set", iface, "listen-port", str(port)])
async def add_peer(
public_key: str,
allowed_ips: list[str],
preshared_key: str | None = None,
iface: str | None = None,
) -> None:
"""Add or update a WireGuard peer."""
settings = get_settings()
iface = iface or settings.wg_interface
args = ["wg", "set", iface, "peer", public_key, "allowed-ips", ",".join(allowed_ips)]
if preshared_key:
import tempfile
import os
psk_fd, psk_path = tempfile.mkstemp()
try:
os.write(psk_fd, preshared_key.encode())
os.close(psk_fd)
os.chmod(psk_path, 0o600)
await _run([
"wg", "set", iface, "peer", public_key,
"allowed-ips", ",".join(allowed_ips),
"preshared-key", psk_path,
])
finally:
os.unlink(psk_path)
else:
await _run(args)
logger.info("Peer added/updated: {} -> {}", public_key[:20], allowed_ips)
async def remove_peer(public_key: str, iface: str | None = None) -> None:
"""Remove a WireGuard peer."""
settings = get_settings()
iface = iface or settings.wg_interface
await _run(["wg", "set", iface, "peer", public_key, "remove"])
logger.info("Peer removed: {}", public_key[:20])
async def get_peers(iface: str | None = None) -> list[PeerInfo]:
"""Parse `wg show <iface> dump` and return peer information."""
settings = get_settings()
iface = iface or settings.wg_interface
try:
output = await _run(["wg", "show", iface, "dump"])
except RuntimeError:
return []
peers = []
for line in output.splitlines()[1:]: # skip the interface line
parts = line.split("\t")
if len(parts) < 8:
continue
pub_key = parts[0]
# parts: public_key, preshared_key, endpoint, allowed_ips, latest_handshake, rx, tx, keepalive
endpoint = parts[2] if parts[2] != "(none)" else None
allowed_ips = parts[3].split(",") if parts[3] != "(none)" else []
handshake_ts = int(parts[4]) if parts[4] != "0" else None
latest_handshake = datetime.utcfromtimestamp(handshake_ts) if handshake_ts else None
rx_bytes = int(parts[5])
tx_bytes = int(parts[6])
peers.append(PeerInfo(
public_key=pub_key,
endpoint=endpoint,
allowed_ips=allowed_ips,
latest_handshake=latest_handshake,
rx_bytes=rx_bytes,
tx_bytes=tx_bytes,
))
return peers

28
wiregui/tasks/__init__.py Normal file
View file

@ -0,0 +1,28 @@
"""Background tasks — registered on app startup."""
import asyncio
from loguru import logger
from wiregui.config import get_settings
_tasks: list[asyncio.Task] = []
def register_task(coro, name: str) -> None:
"""Schedule a coroutine as a background task."""
task = asyncio.create_task(coro, name=name)
_tasks.append(task)
logger.info("Background task registered: {}", name)
async def cancel_all() -> None:
"""Cancel all registered background tasks (called on shutdown)."""
for task in _tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
_tasks.clear()
logger.info("All background tasks cancelled")

View file

@ -0,0 +1,60 @@
"""Periodic WAN connectivity checks — fetch a URL and log the result."""
import asyncio
import httpx
from loguru import logger
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.configuration import Configuration
from wiregui.models.connectivity_check import ConnectivityCheck
from wiregui.services import notifications
from wiregui.utils.time import utcnow
DEFAULT_URL = "https://ping-dev.firezone.dev"
DEFAULT_INTERVAL = 300 # 5 minutes
async def connectivity_loop() -> None:
"""Run forever: perform connectivity checks at a configurable interval."""
logger.info("Connectivity check task started")
await asyncio.sleep(60) # Initial delay to avoid startup spam
while True:
try:
await _check_connectivity()
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("Connectivity check failed: {}", e)
await asyncio.sleep(DEFAULT_INTERVAL)
async def _check_connectivity() -> None:
"""Fetch the connectivity check URL and store the result."""
url = DEFAULT_URL
try:
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.get(url)
check = ConnectivityCheck(
url=url,
response_code=resp.status_code,
response_headers=dict(resp.headers),
response_body=resp.text[:500],
)
logger.debug("Connectivity check: {} -> {}", url, resp.status_code)
except Exception as e:
check = ConnectivityCheck(
url=url,
response_code=None,
response_body=str(e)[:500],
)
logger.warning("Connectivity check failed: {}", e)
notifications.add("warning", f"WAN connectivity check failed: {e}")
async with async_session() as session:
session.add(check)
await session.commit()

View file

@ -0,0 +1,108 @@
"""Periodically refresh OIDC tokens for all active connections."""
import asyncio
from loguru import logger
from sqlmodel import select
from wiregui.db import async_session
from wiregui.models.oidc_connection import OIDCConnection
from wiregui.models.user import User
from wiregui.services import notifications
from wiregui.utils.time import utcnow
INTERVAL_SECONDS = 600 # 10 minutes
async def oidc_refresh_loop() -> None:
"""Run forever: refresh OIDC tokens every INTERVAL_SECONDS."""
logger.info("OIDC refresh task started (interval={}s)", INTERVAL_SECONDS)
await asyncio.sleep(60) # Initial delay to avoid startup spam
while True:
try:
await _refresh_all()
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("OIDC refresh cycle failed: {}", e)
await asyncio.sleep(INTERVAL_SECONDS)
async def _refresh_all() -> None:
"""Attempt to refresh all stored OIDC tokens."""
from authlib.integrations.httpx_client import AsyncOAuth2Client
from wiregui.auth.oidc import load_providers
providers = await load_providers()
provider_map = {p["id"]: p for p in providers}
async with async_session() as session:
result = await session.execute(
select(OIDCConnection).where(OIDCConnection.refresh_token.is_not(None))
)
connections = result.scalars().all()
if not connections:
return
refreshed = 0
failed = 0
for conn in connections:
provider_config = provider_map.get(conn.provider)
if not provider_config:
continue
try:
async with AsyncOAuth2Client(
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
) as client:
# Load server metadata to get token endpoint
discovery_url = provider_config.get("discovery_document_uri")
if discovery_url:
import httpx
resp = await client.get(discovery_url)
metadata = resp.json()
token_endpoint = metadata.get("token_endpoint")
else:
continue
new_token = await client.refresh_token(
url=token_endpoint,
refresh_token=conn.refresh_token,
)
# Update connection
async with async_session() as session:
c = await session.get(OIDCConnection, conn.id)
if c:
c.refresh_token = new_token.get("refresh_token", c.refresh_token)
c.refresh_response = dict(new_token)
c.refreshed_at = utcnow()
session.add(c)
await session.commit()
refreshed += 1
except Exception as e:
failed += 1
logger.warning("OIDC refresh failed for connection {} (provider={}): {}",
conn.id, conn.provider, e)
# Check if we should disable VPN
from wiregui.models.configuration import Configuration
async with async_session() as session:
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
if config and config.disable_vpn_on_oidc_error:
user = await session.get(User, conn.user_id)
if user:
notifications.add(
"error",
f"OIDC refresh failed for {user.email} ({conn.provider}). VPN access may be affected.",
user=user.email,
)
if refreshed or failed:
logger.info("OIDC refresh: {} succeeded, {} failed", refreshed, failed)

Some files were not shown because too many files have changed in this diff Show more