feat: initial WireGUI implementation — full VPN management platform
Complete Python/NiceGUI rewrite of the Wirezone (Elixir/Phoenix) VPN management platform. All 10 implementation phases delivered. Core stack: - NiceGUI reactive UI with SQLModel ORM on PostgreSQL (asyncpg) - Alembic migrations, Valkey/Redis cache, pydantic-settings config - WireGuard management via subprocess (wg/ip/nft CLIs) - 164 tests passing, 35% code coverage Features: - User/device/rule CRUD with admin and unprivileged roles - Full device config form with per-device WG overrides - WireGuard client config generation with QR codes - REST API (v0) with Bearer token auth for all resources - TOTP MFA with QR registration and challenge flow - OIDC SSO with authlib (provider registry, auto-create users) - Magic link passwordless sign-in via email - SAML SP-initiated SSO with IdP metadata parsing - WebAuthn/FIDO2 security key registration - nftables firewall with per-user chains and masquerade - Background tasks: WG stats polling, VPN session expiry, OIDC token refresh, WAN connectivity checks - Startup reconciliation (DB ↔ WireGuard state sync) - In-memory notification system with header badge - Admin UI: users, devices, rules, settings (3 tabs), diagnostics - Loguru logging with optional timestamped file output Deployment: - Multi-stage Dockerfile (python:3.13-slim) - Docker Compose prod stack (bridge networking, NET_ADMIN, nftables) - Forgejo CI: tests → semantic versioning → Docker registry push - Health endpoint at /api/health
This commit is contained in:
commit
0546b44507
109 changed files with 11793 additions and 0 deletions
13
.dockerignore
Normal file
13
.dockerignore
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
.venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
.nicegui/
|
||||
logs/
|
||||
.git/
|
||||
.idea/
|
||||
.pytest_cache/
|
||||
tests/
|
||||
.forgejo/
|
||||
*.md
|
||||
compose*.yml
|
||||
211
.forgejo/workflows/release.yml
Normal file
211
.forgejo/workflows/release.yml
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: docker
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
env:
|
||||
POSTGRES_USER: wiregui
|
||||
POSTGRES_PASSWORD: wiregui
|
||||
POSTGRES_DB: wiregui
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U wiregui"
|
||||
--health-interval 5s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
env:
|
||||
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
wireguard-tools pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl
|
||||
|
||||
- name: Install uv
|
||||
run: pip install uv
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest -v --tb=short
|
||||
|
||||
release:
|
||||
needs: test
|
||||
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
|
||||
runs-on: docker
|
||||
outputs:
|
||||
new_tag: ${{ steps.version.outputs.new_tag }}
|
||||
new_version: ${{ steps.version.outputs.new_version }}
|
||||
skip: ${{ steps.version.outputs.skip }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Determine version bump
|
||||
id: version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
echo "latest_tag=${LATEST_TAG}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
CURRENT="${LATEST_TAG#v}"
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT"
|
||||
|
||||
COMMITS=$(git log "${LATEST_TAG}..HEAD" --pretty=format:"%s" 2>/dev/null || git log --pretty=format:"%s")
|
||||
|
||||
BUMP="none"
|
||||
while IFS= read -r msg; do
|
||||
case "$msg" in
|
||||
*"BREAKING CHANGE"*|*"!:"*)
|
||||
BUMP="major"
|
||||
break
|
||||
;;
|
||||
feat:*|feat\(*)
|
||||
[ "$BUMP" != "major" ] && BUMP="minor"
|
||||
;;
|
||||
fix:*|fix\(*|perf:*|perf\(*|refactor:*|refactor\(*)
|
||||
[ "$BUMP" = "none" ] && BUMP="patch"
|
||||
;;
|
||||
esac
|
||||
done <<< "$COMMITS"
|
||||
|
||||
if [ "$BUMP" = "none" ]; then
|
||||
echo "No version-relevant commits since ${LATEST_TAG}, skipping release"
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
case "$BUMP" in
|
||||
major) MAJOR=$((MAJOR + 1)); MINOR=0; PATCH=0 ;;
|
||||
minor) MINOR=$((MINOR + 1)); PATCH=0 ;;
|
||||
patch) PATCH=$((PATCH + 1)) ;;
|
||||
esac
|
||||
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}"
|
||||
echo "new_version=${NEW_VERSION}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_tag=v${NEW_VERSION}" >> "$GITHUB_OUTPUT"
|
||||
echo "bump=${BUMP}" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Version bump: ${BUMP} -> v${NEW_VERSION}"
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
if: steps.version.outputs.skip != 'true'
|
||||
run: |
|
||||
LATEST_TAG="${{ steps.version.outputs.latest_tag }}"
|
||||
NEW_TAG="${{ steps.version.outputs.new_tag }}"
|
||||
|
||||
BODY="## ${NEW_TAG}"$'\n\n'
|
||||
|
||||
for type_label in "feat:Features" "fix:Bug Fixes" "refactor:Refactoring" "perf:Performance" "docs:Documentation" "chore:Maintenance"; do
|
||||
prefix="${type_label%%:*}"
|
||||
label="${type_label#*:}"
|
||||
MATCHES=$(git log "${LATEST_TAG}..HEAD" --pretty=format:"%s" 2>/dev/null | grep -E "^${prefix}[:(]" || true)
|
||||
if [ -n "$MATCHES" ]; then
|
||||
BODY="${BODY}### ${label}"$'\n\n'
|
||||
while IFS= read -r line; do
|
||||
CLEAN=$(echo "$line" | sed -E "s/^${prefix}(\([^)]*\))?:\s*//")
|
||||
BODY="${BODY}- ${CLEAN}"$'\n'
|
||||
done <<< "$MATCHES"
|
||||
BODY="${BODY}"$'\n'
|
||||
fi
|
||||
done
|
||||
|
||||
echo "${BODY}" > /tmp/changelog.md
|
||||
echo "Generated changelog for ${NEW_TAG}"
|
||||
|
||||
- name: Create tag and release
|
||||
if: steps.version.outputs.skip != 'true'
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
NEW_TAG="${{ steps.version.outputs.new_tag }}"
|
||||
|
||||
git config user.name "Forgejo Actions"
|
||||
git config user.email "noreply@forge.provvedo.com"
|
||||
git tag -a "${NEW_TAG}" -m "Release ${NEW_TAG}"
|
||||
git push origin "${NEW_TAG}"
|
||||
|
||||
FORGEJO_URL="${GITHUB_SERVER_URL}"
|
||||
REPO="${GITHUB_REPOSITORY}"
|
||||
|
||||
python3 -c "
|
||||
import json, urllib.request, os
|
||||
body = open('/tmp/changelog.md').read()
|
||||
tag = '${NEW_TAG}'
|
||||
data = json.dumps({
|
||||
'tag_name': tag,
|
||||
'name': tag,
|
||||
'body': body,
|
||||
'draft': False,
|
||||
'prerelease': False
|
||||
}).encode()
|
||||
req = urllib.request.Request(
|
||||
'${FORGEJO_URL}/api/v1/repos/${REPO}/releases',
|
||||
data=data,
|
||||
headers={
|
||||
'Authorization': 'token ' + os.environ['GITHUB_TOKEN'],
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
method='POST'
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
print(f'Created release {tag} (HTTP {resp.status})')
|
||||
"
|
||||
|
||||
docker:
|
||||
needs: release
|
||||
if: needs.release.outputs.skip != 'true'
|
||||
runs-on: docker
|
||||
container:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
options: --privileged
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build and push image
|
||||
shell: bash
|
||||
env:
|
||||
REGISTRY_TOKEN: ${{ secrets.REGISTRY_TOKEN }}
|
||||
run: |
|
||||
VERSION="${{ needs.release.outputs.new_version }}"
|
||||
TAG="${{ needs.release.outputs.new_tag }}"
|
||||
REGISTRY=$(echo "${{ github.server_url }}" | sed 's|https://||; s|http://||')
|
||||
IMAGE="${REGISTRY}/${{ github.repository_owner }}/wiregui"
|
||||
|
||||
MAJOR=$(echo "$VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$VERSION" | cut -d. -f2)
|
||||
|
||||
echo "Building ${IMAGE}:${TAG}"
|
||||
|
||||
# Log in to Forgejo container registry
|
||||
echo "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" \
|
||||
-u "${{ github.repository_owner }}" --password-stdin
|
||||
|
||||
# Build the image
|
||||
docker build --network host \
|
||||
--build-arg "VERSION=${VERSION}" \
|
||||
-t "${IMAGE}:${TAG}" \
|
||||
-t "${IMAGE}:${MAJOR}.${MINOR}" \
|
||||
-t "${IMAGE}:latest" \
|
||||
.
|
||||
|
||||
# Push all tags
|
||||
docker push "${IMAGE}:${TAG}"
|
||||
docker push "${IMAGE}:${MAJOR}.${MINOR}"
|
||||
docker push "${IMAGE}:latest"
|
||||
|
||||
echo "Pushed ${IMAGE}:${TAG}, ${IMAGE}:${MAJOR}.${MINOR}, ${IMAGE}:latest"
|
||||
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
.venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
.nicegui/
|
||||
logs/
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
|
|
@ -0,0 +1 @@
|
|||
3.13
|
||||
124
CLAUDE.md
Normal file
124
CLAUDE.md
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
# WireGUI
|
||||
|
||||
## Project Overview
|
||||
WireGUI is a Python rewrite of the Wirezone VPN management platform (Elixir/Phoenix).
|
||||
Original source: `/home/stefanob/PycharmProjects/personal/wirezone`
|
||||
|
||||
## Tech Stack
|
||||
- **UI Framework**: NiceGUI (reactive server-side UI over WebSocket)
|
||||
- **ORM/Models**: SQLModel (SQLAlchemy + Pydantic)
|
||||
- **Database**: PostgreSQL (via asyncpg)
|
||||
- **Cache/Sessions**: Valkey (Redis-compatible)
|
||||
- **Migrations**: Alembic
|
||||
- **REST API**: FastAPI (built into NiceGUI)
|
||||
- **Auth**: authlib (OIDC), python-jose (JWT), pyotp (TOTP), webauthn, bcrypt
|
||||
- **VPN**: subprocess calls to `wg` and `ip` commands
|
||||
- **Firewall**: python-nftables or subprocess `nft`
|
||||
- **Python**: 3.13
|
||||
- **Package Manager**: uv
|
||||
|
||||
## Development Setup
|
||||
```bash
|
||||
uv sync # Install dependencies
|
||||
docker compose up -d # Start PostgreSQL and Valkey
|
||||
alembic upgrade head # Run migrations
|
||||
uv run python -m wiregui.main # Start the application
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
```
|
||||
wiregui/
|
||||
├── main.py # NiceGUI entrypoint, mounts FastAPI, starts background tasks
|
||||
├── config.py # pydantic-settings: Settings class
|
||||
├── db.py # async SQLAlchemy engine + sessionmaker
|
||||
├── redis.py # Valkey connection pool
|
||||
├── models/ # SQLModel table definitions
|
||||
│ ├── user.py
|
||||
│ ├── device.py
|
||||
│ ├── rule.py
|
||||
│ ├── mfa_method.py
|
||||
│ ├── oidc_connection.py
|
||||
│ ├── api_token.py
|
||||
│ ├── connectivity_check.py
|
||||
│ └── configuration.py
|
||||
├── schemas/ # Pydantic request/response schemas (non-table)
|
||||
├── auth/ # Authentication modules
|
||||
│ ├── passwords.py # bcrypt hashing
|
||||
│ ├── jwt.py # JWT create/verify
|
||||
│ ├── session.py # NiceGUI session middleware
|
||||
│ ├── oidc.py # authlib OIDC
|
||||
│ ├── saml.py # python3-saml
|
||||
│ ├── mfa.py # TOTP + WebAuthn
|
||||
│ └── api_token.py # API token auth
|
||||
├── api/v0/ # REST API routers
|
||||
│ ├── users.py
|
||||
│ ├── devices.py
|
||||
│ ├── rules.py
|
||||
│ └── configuration.py
|
||||
├── pages/ # NiceGUI page definitions
|
||||
│ ├── layout.py # shared sidebar/header
|
||||
│ ├── login.py
|
||||
│ ├── devices.py # user device CRUD
|
||||
│ ├── account.py # user account/MFA
|
||||
│ ├── mfa_challenge.py
|
||||
│ └── admin/ # admin pages
|
||||
│ ├── users.py
|
||||
│ ├── devices.py
|
||||
│ ├── rules.py
|
||||
│ ├── settings.py
|
||||
│ └── diagnostics.py
|
||||
├── services/ # Core services
|
||||
│ ├── wireguard.py # WG interface management
|
||||
│ ├── firewall.py # nftables rule management
|
||||
│ ├── events.py # DB → WG/firewall bridge
|
||||
│ ├── notifications.py # in-memory notification queue
|
||||
│ └── email.py # aiosmtplib
|
||||
├── tasks/ # Background tasks
|
||||
│ ├── vpn_session.py # expire VPN sessions
|
||||
│ ├── stats.py # poll WG stats
|
||||
│ ├── connectivity.py # WAN connectivity checks
|
||||
│ └── oidc_refresh.py # refresh OIDC tokens
|
||||
└── utils/ # Utilities
|
||||
├── crypto.py # keypair gen, Fernet encrypt/decrypt
|
||||
├── network.py # IP allocation, CIDR validation
|
||||
└── validators.py # shared validators
|
||||
alembic/
|
||||
├── env.py
|
||||
├── script.py.mako
|
||||
└── versions/
|
||||
```
|
||||
|
||||
## Commands
|
||||
- `uv sync` — install/update dependencies
|
||||
- `uv run python -m wiregui.main` — run the app
|
||||
- `alembic revision --autogenerate -m "description"` — create migration
|
||||
- `alembic upgrade head` — apply all migrations
|
||||
- `alembic downgrade -1` — rollback last migration
|
||||
- `docker compose up -d` — start local Postgres + Valkey
|
||||
- `docker compose down` — stop local services
|
||||
- `pytest` — run tests
|
||||
|
||||
## Conventions
|
||||
- Use SQLModel for all database models (combines SQLAlchemy table + Pydantic validation)
|
||||
- Use async database sessions with asyncpg
|
||||
- Place all NiceGUI pages in `wiregui/pages/`
|
||||
- Place all SQLModel table models in `wiregui/models/`
|
||||
- Place Pydantic request/response schemas in `wiregui/schemas/`
|
||||
- Use Alembic autogenerate for migrations
|
||||
- Background tasks use asyncio (create_task + while/sleep pattern)
|
||||
- WireGuard/nftables managed via subprocess (asyncio.create_subprocess_exec)
|
||||
- DB is source of truth; WG/firewall state is reconciled on startup
|
||||
|
||||
## Logging — MANDATORY
|
||||
- **Use loguru for ALL logging and messages. No `print()` statements allowed anywhere in this project.**
|
||||
- Import: `from loguru import logger`
|
||||
- Use `logger.info()`, `logger.warning()`, `logger.error()`, `logger.debug()`, etc.
|
||||
- Loguru is configured in `wiregui/logging.py` via `setup_logging()`
|
||||
- When `WG_LOG_TO_FILE=true` (default), timestamped log files are written to `logs/` in the project root
|
||||
- The `logs/` directory is gitignored
|
||||
|
||||
## Testing
|
||||
- Tests live in `tests/` mirroring the `wiregui/` structure
|
||||
- Run with `uv run pytest`
|
||||
- Use `pytest-asyncio` for async tests
|
||||
- Test database: uses same Postgres instance, separate `wiregui_test` database
|
||||
55
Dockerfile
Normal file
55
Dockerfile
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
FROM python:3.13-slim AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install uv for fast dependency resolution
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
||||
|
||||
# Install system deps needed for building (wireguard-tools for wg CLI)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc libpq-dev wireguard-tools nftables iproute2 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy dependency files first for layer caching
|
||||
COPY pyproject.toml uv.lock* ./
|
||||
|
||||
# Install dependencies (production only, no dev group)
|
||||
RUN uv sync --no-dev --frozen 2>/dev/null || uv sync --no-dev
|
||||
|
||||
# Copy application code
|
||||
COPY wiregui/ wiregui/
|
||||
COPY alembic/ alembic/
|
||||
COPY alembic.ini ./
|
||||
|
||||
FROM python:3.13-slim AS runner
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Runtime dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
wireguard-tools nftables iproute2 libpq5 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy uv and virtualenv from builder
|
||||
COPY --from=builder /usr/local/bin/uv /usr/local/bin/uv
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
COPY --from=builder /app/wiregui /app/wiregui
|
||||
COPY --from=builder /app/alembic /app/alembic
|
||||
COPY --from=builder /app/alembic.ini /app/alembic.ini
|
||||
COPY --from=builder /app/pyproject.toml /app/pyproject.toml
|
||||
|
||||
# Ensure the venv is on PATH
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Create logs directory
|
||||
RUN mkdir -p /app/logs
|
||||
|
||||
ARG VERSION=0.0.0-dev
|
||||
ENV WG_VERSION=$VERSION
|
||||
|
||||
EXPOSE 13000
|
||||
EXPOSE 51820/udp
|
||||
|
||||
# Run migrations then start the app
|
||||
CMD ["sh", "-c", "alembic upgrade head && python -m wiregui.main"]
|
||||
0
README.md
Normal file
0
README.md
Normal file
196
TODO.md
Normal file
196
TODO.md
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
# WireGUI Implementation TODO
|
||||
|
||||
Migration of Wirezone (Elixir/Phoenix) to Python/NiceGUI.
|
||||
Source: `/home/stefanob/PycharmProjects/personal/wirezone`
|
||||
|
||||
**Test count: 164 passing | Coverage: 35%**
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Foundation — Models, DB, Config ✅
|
||||
|
||||
- [x] `pyproject.toml` with dependencies, `uv sync`
|
||||
- [x] Package directory structure
|
||||
- [x] `wiregui/config.py` — pydantic-settings (DB, Redis, WG, auth, SMTP, logging)
|
||||
- [x] `wiregui/db.py` — async engine, sessionmaker, `init_db()`
|
||||
- [x] `wiregui/redis.py` — Valkey connection pool
|
||||
- [x] All 8 SQLModel models (User, Device, Rule, MFAMethod, OIDCConnection, ApiToken, ConnectivityCheck, Configuration)
|
||||
- [x] Alembic init + initial migration + `alembic upgrade head`
|
||||
- [x] `wiregui/main.py` — app entrypoint
|
||||
- [x] `compose.yml` — PostgreSQL 17 + Valkey 8
|
||||
- [x] `wiregui/utils/time.py` — `utcnow()` helper for naive UTC timestamps
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Auth System — Login + Sessions ✅
|
||||
|
||||
- [x] `wiregui/auth/passwords.py` — bcrypt hash/verify (direct bcrypt, not passlib)
|
||||
- [x] `wiregui/auth/jwt.py` — create/decode JWT via python-jose
|
||||
- [x] `wiregui/auth/session.py` — `authenticate_user()` email/password verification
|
||||
- [x] `wiregui/auth/middleware.py` — HTTP-level auth middleware (available for REST API)
|
||||
- [x] `wiregui/auth/seed.py` — auto-create admin on first startup
|
||||
- [x] `wiregui/pages/login.py` — login page with email/password form
|
||||
- [x] `wiregui/pages/home.py` — authenticated home (redirects to /devices)
|
||||
- [x] Auth guards via `app.storage.user` on each page
|
||||
- [x] Logout clears storage and redirects
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Device UI — User-Facing CRUD ✅
|
||||
|
||||
- [x] `wiregui/pages/layout.py` — shared sidebar + header
|
||||
- [x] `wiregui/utils/network.py` — IPv4/IPv6 allocation (random offset + scan)
|
||||
- [x] `wiregui/utils/crypto.py` — WG keypair + PSK generation via `wg` CLI
|
||||
- [x] `wiregui/utils/wg_conf.py` — WG client `.conf` builder
|
||||
- [x] `wiregui/pages/devices.py` — `/devices` list + create dialog + delete
|
||||
- [x] `/devices/{device_id}` — detail page with stats and config flags
|
||||
- [x] QR code generation + `.conf` download
|
||||
- [x] Full device create/edit form with all wirezone options (description, per-device config overrides, use_default_* toggles with bound inputs, better layout)
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: WireGuard Integration ✅
|
||||
|
||||
- [x] `wiregui/services/wireguard.py` — async subprocess: ensure_interface, add/remove_peer, get_peers, set_private_key, set_listen_port
|
||||
- [x] `wiregui/services/events.py` — event bridge: device CRUD → WG + firewall
|
||||
- [x] Device create/delete in UI fires WG events
|
||||
- [x] `wiregui/tasks/__init__.py` — background task registry + cancel_all
|
||||
- [x] `wiregui/tasks/stats.py` — poll WG stats every 60s, update DB
|
||||
- [x] `wiregui/tasks/reconcile.py` — startup reconciliation (diff DB vs WG, add/remove)
|
||||
- [x] `config.py` — `wg_enabled` flag (default False for dev)
|
||||
- [x] Startup: ensure_interface → reconcile → stats_loop (when wg_enabled)
|
||||
|
||||
---
|
||||
|
||||
## Phase 5: Firewall (nftables) ✅
|
||||
|
||||
- [x] `wiregui/services/firewall.py` — nft CLI: setup_base_tables, masquerade, per-user chains, jump rules, apply_rule, rebuild_all_rules
|
||||
- [x] IPv4/IPv6 aware, TCP/UDP port range support
|
||||
- [x] `wiregui/pages/admin/rules.py` — `/admin/rules` CRUD (action, CIDR, protocol, port, user)
|
||||
- [x] Events: on_rule_created/deleted, on_device_created adds jump rules
|
||||
- [x] Startup: setup_base_tables + setup_masquerade (when wg_enabled)
|
||||
- [x] Edit rule — edit dialog in admin rules page with all fields
|
||||
- [x] Full user chain rebuild on rule update/delete via `_rebuild_user_chain()` in events.py
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: REST API (v0) ✅
|
||||
|
||||
- [x] `wiregui/auth/api_token.py` — token generation (random → sha256), Bearer resolution with expiry + disabled user checks
|
||||
- [x] `wiregui/api/deps.py` — get_db, get_current_api_user, require_admin
|
||||
- [x] `wiregui/schemas/` — Pydantic schemas: UserRead/Create/Update, DeviceRead/Create/Update, RuleRead/Create/Update, ConfigurationRead/Update
|
||||
- [x] `wiregui/api/v0/users.py` — full CRUD (admin only)
|
||||
- [x] `wiregui/api/v0/devices.py` — full CRUD (owner or admin, triggers WG/firewall events)
|
||||
- [x] `wiregui/api/v0/rules.py` — full CRUD (admin only, triggers firewall events)
|
||||
- [x] `wiregui/api/v0/configuration.py` — GET/PUT (admin only, auto-creates singleton)
|
||||
- [x] Mounted on NiceGUI app at `/api/v0`
|
||||
|
||||
---
|
||||
|
||||
## Phase 7: Admin UI ✅
|
||||
|
||||
- [x] `/admin/users` — table (email, role, devices, status, last sign-in, method, created), create (email/password/role), edit (email/role/password/disabled), delete with cascading cleanup (devices → WG events, rules)
|
||||
- [x] `/admin/devices` — all devices with user filter, full create form (owner, name, description, all use_default_* toggles with bound override inputs), full edit form, delete with WG events, config + QR on creation
|
||||
- [x] `/admin/settings` — 3 tabs:
|
||||
- Client Defaults (endpoint, DNS, allowed IPs, MTU, keepalive)
|
||||
- Security (VPN session duration, local auth, unpriv device mgmt/config, OIDC auto-disable)
|
||||
- Authentication (OIDC provider CRUD with table + dialog; SAML placeholder for Phase 8)
|
||||
- [x] `/admin/diagnostics` — WG interface status, active peers, connectivity checks, system notifications with clear/clear-all
|
||||
- [x] `wiregui/services/notifications.py` — in-memory deque (capped at 100), add/clear/count/current
|
||||
- [x] Header notification bell badge (admin only, links to diagnostics)
|
||||
- [ ] **TODO:** SAML provider management in Authentication tab
|
||||
|
||||
---
|
||||
|
||||
## Phase 8: Advanced Auth (MFA, OIDC, Magic Links, SAML) ✅
|
||||
|
||||
- [x] TOTP MFA (`wiregui/auth/mfa.py`) — secret generation, URI/QR, verification with clock drift tolerance
|
||||
- [x] MFA challenge page (`/mfa`) — 6-digit code entry, multi-method support, last-used tracking
|
||||
- [x] Login page updated: checks for MFA methods after password auth, redirects to `/mfa` if present
|
||||
- [x] OIDC (`wiregui/auth/oidc.py`) — provider registry from Configuration, authlib Starlette integration
|
||||
- [x] OIDC routes (`/auth/oidc/{provider}` + `/auth/oidc/{provider}/callback`) — auth code flow, user lookup/auto-create, refresh token storage in OIDCConnection
|
||||
- [x] Login page shows OIDC provider buttons dynamically from config
|
||||
- [x] OIDC refresh task (`wiregui/tasks/oidc_refresh.py`) — every 10min, refreshes all stored tokens, creates notifications on failure, respects `disable_vpn_on_oidc_error`
|
||||
- [x] Magic links (`/auth/magic-link` + `/auth/magic/{user_id}/{token}`) — request page, signed JWT with 15min expiry, email via aiosmtplib
|
||||
- [x] Email service (`wiregui/services/email.py`) — aiosmtplib send, magic link template
|
||||
- [x] `/account` page — 3 tabs: Profile (details + password change), Two-Factor Auth (TOTP registration with QR + verification, list/delete methods), API Tokens (create with configurable expiry, list, delete)
|
||||
- [x] OIDC providers registered on startup from Configuration
|
||||
- [x] WebAuthn MFA (`wiregui/auth/webauthn.py`) — registration/authentication options generation, response verification, credential storage
|
||||
- [x] SAML (`wiregui/auth/saml.py` + `wiregui/pages/auth_saml.py`) — SP-initiated SSO, metadata endpoint, ACS callback, IdP metadata parsing, attribute mapping
|
||||
- [x] WebAuthn browser-side JS integration in account page — `ui.run_javascript()` calls `navigator.credentials.create()`, serializes response, server verifies and stores credential
|
||||
- [x] SAML provider management UI in admin settings Authentication tab — table + add/delete dialog (config ID, label, XML metadata, sign requests/metadata/assertions/envelopes toggles, auto-create users)
|
||||
|
||||
---
|
||||
|
||||
## Phase 9: Background Tasks & VPN Session Management
|
||||
|
||||
- [x] Task scheduler (`wiregui/tasks/__init__.py`) — register/cancel
|
||||
- [x] Stats polling task (Phase 4)
|
||||
- [x] OIDC refresh task (Phase 8)
|
||||
- [x] VPN session expiry task (`wiregui/tasks/vpn_session.py`) — every 60s, finds expired sessions based on `vpn_session_duration` + `last_signed_in_at`, removes WG peers, creates notifications
|
||||
- [x] Connectivity check poller (`wiregui/tasks/connectivity.py`) — fetches URL, stores result in DB, notification on failure
|
||||
- [x] Live stats push — `ui.timer(30, ...)` on `/devices` (table refresh), `/devices/{id}` (RX/TX/handshake/remote IP labels), `/admin/devices` (table refresh)
|
||||
|
||||
---
|
||||
|
||||
## Phase 10: Polish, Testing & Deployment
|
||||
|
||||
### Testing (partially done)
|
||||
- [x] pytest + pytest-asyncio setup, conftest with test DB
|
||||
- [x] test_models.py (10 tests), test_auth.py (8 tests), test_utils.py (6 tests), test_services.py (6 tests), test_firewall.py (7 tests)
|
||||
- [x] test_api.py (6 tests) — token generation, resolution, expiry, disabled user
|
||||
- [x] test_notifications.py (9 tests) — add, ordering, count, clear, max cap, to_dict
|
||||
- [x] test_admin.py (13 tests) — user CRUD, cascading deletes, config CRUD, OIDC providers, device overrides
|
||||
- [x] test_mfa.py (11 tests) — TOTP secret gen, URI, code verification (valid/invalid/wrong secret/empty), QR SVG, DB integration, multi-method
|
||||
- [x] test_magic_link.py (4 tests) — token creation/expiry/user mismatch, disabled user rejection
|
||||
- [x] test_account.py (8 tests) — password change flow, API token CRUD, OIDC connection CRUD, refresh token update
|
||||
- [x] test_integration_mfa.py (7 tests) — full TOTP registration flow, MFA blocks login, wrong code, multi-method, last-used tracking, delete allows bypass, disabled user
|
||||
- [x] test_integration_oidc.py (10 tests) — provider config loading, connection create/update, auto-create user, disabled user, refresh token, multi-provider
|
||||
- [x] test_tasks.py (6 tests) — VPN session expiry (expired/unlimited/no-config/disabled user), connectivity check (success/failure with notification)
|
||||
- [ ] HTTP-level integration tests (OIDC redirect/callback flow with respx mocking)
|
||||
|
||||
### Coverage gaps (35% overall — run `uv run pytest --cov=wiregui --cov-report=term-missing --cov-branch`)
|
||||
|
||||
**100% covered:** models, schemas, config, auth/passwords, auth/jwt, auth/mfa, auth/api_token, utils/crypto, utils/time, services/notifications
|
||||
|
||||
**API routes (32-84% — partially covered via httpx TestClient):**
|
||||
- [x] `wiregui/api/v0/users.py` (84%) — list/get/create/update/delete
|
||||
- [x] `wiregui/api/v0/rules.py` (71%) — CRUD
|
||||
- [x] `wiregui/api/v0/devices.py` (67%) — CRUD, permissions
|
||||
- [x] `wiregui/api/v0/configuration.py` (61%) — get/update, auto-create
|
||||
- [ ] `wiregui/api/deps.py` (32%) — test get_current_api_user with real Bearer header parsing, require_admin rejection
|
||||
|
||||
**Services (62-89% covered):**
|
||||
- [x] `wiregui/services/wireguard.py` (62%) — add/remove/get peers mocked
|
||||
- [x] `wiregui/services/firewall.py` (73%) — base tables, chains, rules, rebuild mocked
|
||||
- [x] `wiregui/services/events.py` (80%) — device + rule events, rebuild chain
|
||||
- [x] `wiregui/services/email.py` (89%) — send_email, magic link, no-smtp fallback
|
||||
- [ ] `wiregui/services/wireguard.py` — test ensure_interface, set_private_key, set_listen_port
|
||||
- [ ] `wiregui/services/firewall.py` — test _nft/_nft_batch error handling, add_device_jump_rule with only ipv4/ipv6
|
||||
|
||||
**Tasks (40-84% covered):**
|
||||
- [x] `wiregui/tasks/stats.py` (77%) — update from peers, no-op, unmatched peer
|
||||
- [x] `wiregui/tasks/reconcile.py` (84%) — add missing, remove orphaned, in-sync
|
||||
- [x] `wiregui/tasks/oidc_refresh.py` (40%) — no connections, skip unknown provider
|
||||
- [ ] `wiregui/tasks/oidc_refresh.py` — test successful refresh, failure with notification, disable_vpn_on_oidc_error
|
||||
|
||||
**Auth modules (85-92% covered):**
|
||||
- [x] `wiregui/auth/oidc.py` (87%) — register providers, get_client, load from config
|
||||
- [x] `wiregui/auth/webauthn.py` (85%) — registration/authentication options
|
||||
- [x] `wiregui/auth/session.py` (90%) — no-password, disabled, nonexistent user
|
||||
- [ ] `wiregui/auth/saml.py` (0%) — needs mock SAML IdP metadata + response parsing
|
||||
- [ ] `wiregui/auth/webauthn.py` — test verify_registration, verify_authentication with mock credential data
|
||||
|
||||
**Pages (0% — requires E2E testing):**
|
||||
- [ ] Consider Playwright or NiceGUI's testing utilities for E2E page tests
|
||||
|
||||
### Logging (done)
|
||||
- [x] Loguru configured (wiregui/logging.py), no print statements
|
||||
- [x] File logging to `logs/` when `WG_LOG_TO_FILE=true`
|
||||
|
||||
### Deployment
|
||||
- [ ] Dockerfile (multi-stage)
|
||||
- [ ] compose.prod.yml (app + postgres + valkey + caddy)
|
||||
- [ ] Health endpoint `GET /api/health`
|
||||
- [ ] First-run CLI setup command
|
||||
- [ ] README.md
|
||||
36
alembic.ini
Normal file
36
alembic.ini
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
47
alembic/env.py
Normal file
47
alembic/env.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.models import * # noqa: F401, F403 — ensure all models are registered
|
||||
|
||||
config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode — emit SQL to script output."""
|
||||
context.configure(
|
||||
url=get_settings().database_url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode — connect to the database."""
|
||||
engine = create_async_engine(get_settings().database_url)
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
asyncio.run(run_migrations_online())
|
||||
27
alembic/script.py.mako
Normal file
27
alembic/script.py.mako
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
"""add server keypair to configuration
|
||||
|
||||
Revision ID: 0741bc76e748
|
||||
Revises: 647a4418cc8c
|
||||
Create Date: 2026-03-30 15:37:19.276524
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0741bc76e748'
|
||||
down_revision: Union[str, None] = '647a4418cc8c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('configurations', sa.Column('server_private_key', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||
op.add_column('configurations', sa.Column('server_public_key', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('configurations', 'server_public_key')
|
||||
op.drop_column('configurations', 'server_private_key')
|
||||
# ### end Alembic commands ###
|
||||
171
alembic/versions/647a4418cc8c_initial_schema.py
Normal file
171
alembic/versions/647a4418cc8c_initial_schema.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
"""initial schema
|
||||
|
||||
Revision ID: 647a4418cc8c
|
||||
Revises:
|
||||
Create Date: 2026-03-30 13:18:58.766259
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '647a4418cc8c'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('configurations',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('allow_unprivileged_device_management', sa.Boolean(), nullable=False),
|
||||
sa.Column('allow_unprivileged_device_configuration', sa.Boolean(), nullable=False),
|
||||
sa.Column('local_auth_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('disable_vpn_on_oidc_error', sa.Boolean(), nullable=False),
|
||||
sa.Column('default_client_persistent_keepalive', sa.Integer(), nullable=False),
|
||||
sa.Column('default_client_mtu', sa.Integer(), nullable=False),
|
||||
sa.Column('default_client_endpoint', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('default_client_dns', sa.JSON(), nullable=True),
|
||||
sa.Column('default_client_allowed_ips', sa.JSON(), nullable=True),
|
||||
sa.Column('vpn_session_duration', sa.Integer(), nullable=False),
|
||||
sa.Column('logo_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('logo_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('openid_connect_providers', sa.JSON(), nullable=True),
|
||||
sa.Column('saml_identity_providers', sa.JSON(), nullable=True),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('connectivity_checks',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('url', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('response_code', sa.Integer(), nullable=True),
|
||||
sa.Column('response_headers', sa.JSON(), nullable=True),
|
||||
sa.Column('response_body', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('users',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('password_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('last_signed_in_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('last_signed_in_method', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('sign_in_token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('sign_in_token_created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('disabled_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_table('api_tokens',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('token_hash', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_api_tokens_token_hash'), 'api_tokens', ['token_hash'], unique=True)
|
||||
op.create_index(op.f('ix_api_tokens_user_id'), 'api_tokens', ['user_id'], unique=False)
|
||||
op.create_table('devices',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('public_key', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('preshared_key', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('use_default_allowed_ips', sa.Boolean(), nullable=False),
|
||||
sa.Column('use_default_dns', sa.Boolean(), nullable=False),
|
||||
sa.Column('use_default_endpoint', sa.Boolean(), nullable=False),
|
||||
sa.Column('use_default_mtu', sa.Boolean(), nullable=False),
|
||||
sa.Column('use_default_persistent_keepalive', sa.Boolean(), nullable=False),
|
||||
sa.Column('endpoint', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('mtu', sa.Integer(), nullable=True),
|
||||
sa.Column('persistent_keepalive', sa.Integer(), nullable=True),
|
||||
sa.Column('allowed_ips', sa.JSON(), nullable=True),
|
||||
sa.Column('dns', sa.JSON(), nullable=True),
|
||||
sa.Column('ipv4', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('ipv6', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('remote_ip', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('rx_bytes', sa.Integer(), nullable=True),
|
||||
sa.Column('tx_bytes', sa.Integer(), nullable=True),
|
||||
sa.Column('latest_handshake', sa.DateTime(), nullable=True),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('ipv4'),
|
||||
sa.UniqueConstraint('ipv6')
|
||||
)
|
||||
op.create_index(op.f('ix_devices_public_key'), 'devices', ['public_key'], unique=True)
|
||||
op.create_index(op.f('ix_devices_user_id'), 'devices', ['user_id'], unique=False)
|
||||
op.create_table('mfa_methods',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('payload', sa.JSON(), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_mfa_methods_user_id'), 'mfa_methods', ['user_id'], unique=False)
|
||||
op.create_table('oidc_connections',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('refresh_token', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('refresh_response', sa.JSON(), nullable=True),
|
||||
sa.Column('refreshed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oidc_connections_user_id'), 'oidc_connections', ['user_id'], unique=False)
|
||||
op.create_table('rules',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('action', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('destination', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('port_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('port_range', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('inserted_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_rules_user_id'), 'rules', ['user_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_rules_user_id'), table_name='rules')
|
||||
op.drop_table('rules')
|
||||
op.drop_index(op.f('ix_oidc_connections_user_id'), table_name='oidc_connections')
|
||||
op.drop_table('oidc_connections')
|
||||
op.drop_index(op.f('ix_mfa_methods_user_id'), table_name='mfa_methods')
|
||||
op.drop_table('mfa_methods')
|
||||
op.drop_index(op.f('ix_devices_user_id'), table_name='devices')
|
||||
op.drop_index(op.f('ix_devices_public_key'), table_name='devices')
|
||||
op.drop_table('devices')
|
||||
op.drop_index(op.f('ix_api_tokens_user_id'), table_name='api_tokens')
|
||||
op.drop_index(op.f('ix_api_tokens_token_hash'), table_name='api_tokens')
|
||||
op.drop_table('api_tokens')
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_table('connectivity_checks')
|
||||
op.drop_table('configurations')
|
||||
# ### end Alembic commands ###
|
||||
63
compose.prod.yml
Normal file
63
compose.prod.yml
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
services:
|
||||
wiregui:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: wiregui:latest
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "13000:13000"
|
||||
- "51821:51821/udp"
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- SYS_MODULE
|
||||
sysctls:
|
||||
- net.ipv4.ip_forward=1
|
||||
- net.ipv6.conf.all.forwarding=1
|
||||
- net.ipv6.conf.all.disable_ipv6=0
|
||||
environment:
|
||||
WG_DATABASE_URL: postgresql+asyncpg://wiregui:wiregui@postgres/wiregui
|
||||
WG_REDIS_URL: redis://valkey:6379/0
|
||||
WG_SECRET_KEY: ${WG_SECRET_KEY:-change-me-in-production}
|
||||
WG_WG_ENABLED: "true"
|
||||
WG_WG_ENDPOINT_HOST: ${WG_ENDPOINT_HOST:-vpn.example.com}
|
||||
WG_WG_ENDPOINT_PORT: "51821"
|
||||
WG_HOST: "0.0.0.0"
|
||||
WG_PORT: "13000"
|
||||
WG_EXTERNAL_URL: ${WG_EXTERNAL_URL:-http://localhost:13000}
|
||||
WG_ADMIN_EMAIL: ${WG_ADMIN_EMAIL:-admin@localhost}
|
||||
WG_ADMIN_PASSWORD: ${WG_ADMIN_PASSWORD:-}
|
||||
WG_LOG_TO_FILE: "true"
|
||||
volumes:
|
||||
- wiregui_logs:/app/logs
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
valkey:
|
||||
condition: service_started
|
||||
|
||||
postgres:
|
||||
image: postgres:17
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_USER: wiregui
|
||||
POSTGRES_PASSWORD: wiregui
|
||||
POSTGRES_DB: wiregui
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U wiregui"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
valkey:
|
||||
image: valkey/valkey:8
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- valkey_data:/data
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
valkey_data:
|
||||
wiregui_logs:
|
||||
22
compose.yml
Normal file
22
compose.yml
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_USER: wiregui
|
||||
POSTGRES_PASSWORD: wiregui
|
||||
POSTGRES_DB: wiregui
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
|
||||
valkey:
|
||||
image: valkey/valkey:8
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- valkey_data:/data
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
valkey_data:
|
||||
49
pyproject.toml
Normal file
49
pyproject.toml
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
[project]
|
||||
name = "wiregui"
|
||||
version = "0.1.0"
|
||||
description = "WireGuard VPN management platform — Python/NiceGUI rewrite of Wirezone"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
# UI
|
||||
"nicegui>=2.12",
|
||||
# ORM & Database
|
||||
"sqlmodel>=0.0.22",
|
||||
"asyncpg>=0.30",
|
||||
"alembic>=1.14",
|
||||
# Configuration
|
||||
"pydantic-settings>=2.7",
|
||||
# Cache
|
||||
"redis>=5.2",
|
||||
# Encryption
|
||||
"cryptography>=44",
|
||||
# Auth
|
||||
"bcrypt>=4.0",
|
||||
"python-jose[cryptography]>=3.3",
|
||||
"authlib>=1.4",
|
||||
"pyotp>=2.9",
|
||||
"webauthn>=2.2",
|
||||
"python3-saml>=1.16",
|
||||
# HTTP client
|
||||
"httpx>=0.28",
|
||||
# Email
|
||||
"aiosmtplib>=3.0",
|
||||
# QR codes
|
||||
"qrcode[pil]>=8.0",
|
||||
# Logging
|
||||
"loguru>=0.7.3",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.24",
|
||||
"pytest-cov>=7.1.0",
|
||||
"respx>=0.22.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
asyncio_default_test_loop_scope = "session"
|
||||
testpaths = ["tests"]
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
65
tests/conftest.py
Normal file
65
tests/conftest.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""Shared test fixtures — async DB session using a test database."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
# All models must be imported so SQLModel.metadata knows about them
|
||||
from wiregui.models import * # noqa: F401, F403
|
||||
|
||||
|
||||
def _test_database_url() -> str:
|
||||
url = get_settings().database_url
|
||||
base, _dbname = url.rsplit("/", 1)
|
||||
return f"{base}/wiregui_test"
|
||||
|
||||
|
||||
TEST_DATABASE_URL = _test_database_url()
|
||||
|
||||
# Module-level engine creation (runs once via autouse session fixture)
|
||||
_engine = None
|
||||
|
||||
|
||||
def _ensure_test_db_sync():
|
||||
"""Ensure wiregui_test database exists (called once)."""
|
||||
import asyncio
|
||||
|
||||
async def _create():
|
||||
base_url = get_settings().database_url.rsplit("/", 1)[0] + "/postgres"
|
||||
admin_engine = create_async_engine(base_url, isolation_level="AUTOCOMMIT")
|
||||
async with admin_engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = 'wiregui_test'")
|
||||
)
|
||||
if result.scalar() is None:
|
||||
await conn.execute(text("CREATE DATABASE wiregui_test"))
|
||||
await admin_engine.dispose()
|
||||
|
||||
asyncio.run(_create())
|
||||
|
||||
|
||||
# Create test DB once at import time
|
||||
_ensure_test_db_sync()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def session() -> AsyncGenerator[AsyncSession]:
|
||||
"""Fresh engine + session per test, with table setup/teardown."""
|
||||
engine = create_async_engine(TEST_DATABASE_URL)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as sess:
|
||||
yield sess
|
||||
await sess.rollback()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
161
tests/test_account.py
Normal file
161
tests/test_account.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Tests for account functionality — password changes, API tokens, OIDC connections."""
|
||||
|
||||
import hashlib
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- Password change ---
|
||||
|
||||
|
||||
async def test_password_change_flow(session):
|
||||
"""Simulate the password change flow: verify old, set new."""
|
||||
user = User(email="pw-change@example.com", password_hash=hash_password("old-password"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Verify old password
|
||||
assert verify_password("old-password", user.password_hash) is True
|
||||
|
||||
# Change password
|
||||
user.password_hash = hash_password("new-password")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert verify_password("new-password", fetched.password_hash) is True
|
||||
assert verify_password("old-password", fetched.password_hash) is False
|
||||
|
||||
|
||||
async def test_password_change_wrong_current(session):
|
||||
"""Wrong current password should not allow change."""
|
||||
user = User(email="pw-wrong@example.com", password_hash=hash_password("correct"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Simulate check
|
||||
assert verify_password("wrong", user.password_hash) is False
|
||||
|
||||
|
||||
# --- API token management ---
|
||||
|
||||
|
||||
async def test_create_multiple_tokens(session):
|
||||
user = User(email="multi-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for _ in range(3):
|
||||
_, token_hash = generate_api_token()
|
||||
session.add(ApiToken(token_hash=token_hash, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(ApiToken).where(ApiToken.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
||||
|
||||
async def test_token_with_expiry(session):
|
||||
user = User(email="expiry-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
_, token_hash = generate_api_token()
|
||||
expires = utcnow() + timedelta(days=30)
|
||||
token = ApiToken(token_hash=token_hash, expires_at=expires, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(ApiToken, token.id)
|
||||
assert fetched.expires_at is not None
|
||||
assert fetched.expires_at > utcnow()
|
||||
|
||||
|
||||
async def test_delete_token(session):
|
||||
user = User(email="del-token@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
_, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
await session.delete(token)
|
||||
await session.flush()
|
||||
|
||||
assert await session.get(ApiToken, token.id) is None
|
||||
|
||||
|
||||
# --- OIDC connections ---
|
||||
|
||||
|
||||
async def test_oidc_connection_create(session):
|
||||
user = User(email="oidc-conn@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="google",
|
||||
refresh_token="refresh-tok-123",
|
||||
refresh_response={"access_token": "at", "token_type": "Bearer"},
|
||||
refreshed_at=utcnow(),
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.provider == "google"
|
||||
assert fetched.refresh_token == "refresh-tok-123"
|
||||
assert fetched.refresh_response["access_token"] == "at"
|
||||
|
||||
|
||||
async def test_multiple_oidc_providers(session):
|
||||
user = User(email="multi-oidc@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for provider in ["google", "okta", "azure"]:
|
||||
conn = OIDCConnection(provider=provider, user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
|
||||
|
||||
async def test_oidc_connection_update_refresh_token(session):
|
||||
user = User(email="oidc-refresh@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="google",
|
||||
refresh_token="old-token",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
conn.refresh_token = "new-token"
|
||||
conn.refreshed_at = utcnow()
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(OIDCConnection, conn.id)
|
||||
assert fetched.refresh_token == "new-token"
|
||||
assert fetched.refreshed_at is not None
|
||||
283
tests/test_admin.py
Normal file
283
tests/test_admin.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
"""Tests for admin functionality — user management, configuration, cascading deletes."""
|
||||
|
||||
import pytest
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- User CRUD ---
|
||||
|
||||
|
||||
async def test_create_user_with_role(session):
|
||||
user = User(email="new-admin@test.com", password_hash=hash_password("secret"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.role == "admin"
|
||||
assert verify_password("secret", fetched.password_hash)
|
||||
|
||||
|
||||
async def test_update_user_email(session):
|
||||
user = User(email="old@test.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
user.email = "new@test.com"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.email == "new@test.com"
|
||||
|
||||
|
||||
async def test_disable_user(session):
|
||||
user = User(email="active@test.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert user.disabled_at is None
|
||||
|
||||
user.disabled_at = utcnow()
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.disabled_at is not None
|
||||
|
||||
|
||||
async def test_promote_demote_user(session):
|
||||
user = User(email="user@test.com", role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert user.role == "unprivileged"
|
||||
|
||||
user.role = "admin"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(User, user.id)
|
||||
assert fetched.role == "admin"
|
||||
|
||||
user.role = "unprivileged"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
assert (await session.get(User, user.id)).role == "unprivileged"
|
||||
|
||||
|
||||
# --- Cascading delete (manual, as we do it in the admin page) ---
|
||||
|
||||
|
||||
async def test_delete_user_cascades_devices(session):
|
||||
user = User(email="cascade@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
d1 = Device(name="d1", public_key="pk-cascade-1", ipv4="10.0.0.1", user_id=user.id)
|
||||
d2 = Device(name="d2", public_key="pk-cascade-2", ipv4="10.0.0.2", user_id=user.id)
|
||||
session.add_all([d1, d2])
|
||||
await session.flush()
|
||||
|
||||
# Manually delete devices then user (matching admin page behavior)
|
||||
devices = (await session.execute(select(Device).where(Device.user_id == user.id))).scalars().all()
|
||||
for d in devices:
|
||||
await session.delete(d)
|
||||
await session.delete(user)
|
||||
await session.flush()
|
||||
|
||||
assert (await session.execute(select(func.count()).select_from(Device).where(Device.user_id == user.id))).scalar() == 0
|
||||
assert await session.get(User, user.id) is None
|
||||
|
||||
|
||||
async def test_delete_user_cascades_rules(session):
|
||||
user = User(email="rule-cascade@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
# Delete rules then user
|
||||
rules = (await session.execute(select(Rule).where(Rule.user_id == user.id))).scalars().all()
|
||||
for r in rules:
|
||||
await session.delete(r)
|
||||
await session.delete(user)
|
||||
await session.flush()
|
||||
|
||||
assert (await session.execute(select(func.count()).select_from(Rule).where(Rule.user_id == user.id))).scalar() == 0
|
||||
|
||||
|
||||
# --- Configuration singleton ---
|
||||
|
||||
|
||||
async def test_configuration_create_and_update(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
assert config.default_client_mtu == 1280
|
||||
assert config.local_auth_enabled is True
|
||||
|
||||
config.default_client_mtu = 1400
|
||||
config.local_auth_enabled = False
|
||||
config.vpn_session_duration = 3600
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.default_client_mtu == 1400
|
||||
assert fetched.local_auth_enabled is False
|
||||
assert fetched.vpn_session_duration == 3600
|
||||
|
||||
|
||||
async def test_configuration_oidc_providers(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
assert config.openid_connect_providers == []
|
||||
|
||||
providers = [
|
||||
{
|
||||
"id": "google",
|
||||
"label": "Sign in with Google",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "google-client-id",
|
||||
"client_secret": "google-secret",
|
||||
"discovery_document_uri": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"auto_create_users": True,
|
||||
},
|
||||
{
|
||||
"id": "okta",
|
||||
"label": "Okta SSO",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "okta-client-id",
|
||||
"client_secret": "okta-secret",
|
||||
"discovery_document_uri": "https://dev-123.okta.com/.well-known/openid-configuration",
|
||||
"auto_create_users": False,
|
||||
},
|
||||
]
|
||||
config.openid_connect_providers = providers
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert len(fetched.openid_connect_providers) == 2
|
||||
assert fetched.openid_connect_providers[0]["id"] == "google"
|
||||
assert fetched.openid_connect_providers[1]["auto_create_users"] is False
|
||||
|
||||
|
||||
async def test_configuration_update_client_defaults(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
config.default_client_endpoint = "vpn.example.com"
|
||||
config.default_client_dns = ["8.8.8.8", "8.8.4.4"]
|
||||
config.default_client_allowed_ips = ["10.0.0.0/8"]
|
||||
config.default_client_persistent_keepalive = 30
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.default_client_endpoint == "vpn.example.com"
|
||||
assert fetched.default_client_dns == ["8.8.8.8", "8.8.4.4"]
|
||||
assert fetched.default_client_allowed_ips == ["10.0.0.0/8"]
|
||||
assert fetched.default_client_persistent_keepalive == 30
|
||||
|
||||
|
||||
async def test_configuration_security_toggles(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
config.allow_unprivileged_device_management = False
|
||||
config.allow_unprivileged_device_configuration = False
|
||||
config.disable_vpn_on_oidc_error = True
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Configuration, config.id)
|
||||
assert fetched.allow_unprivileged_device_management is False
|
||||
assert fetched.allow_unprivileged_device_configuration is False
|
||||
assert fetched.disable_vpn_on_oidc_error is True
|
||||
|
||||
|
||||
# --- Device config overrides ---
|
||||
|
||||
|
||||
async def test_device_with_custom_config(session):
|
||||
user = User(email="config-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(
|
||||
name="custom-config",
|
||||
public_key="pk-custom-config",
|
||||
user_id=user.id,
|
||||
use_default_dns=False,
|
||||
use_default_endpoint=False,
|
||||
use_default_mtu=False,
|
||||
use_default_persistent_keepalive=False,
|
||||
use_default_allowed_ips=False,
|
||||
dns=["8.8.8.8"],
|
||||
endpoint="custom-vpn.example.com",
|
||||
mtu=1400,
|
||||
persistent_keepalive=15,
|
||||
allowed_ips=["10.0.0.0/8", "172.16.0.0/12"],
|
||||
)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Device, device.id)
|
||||
assert fetched.use_default_dns is False
|
||||
assert fetched.dns == ["8.8.8.8"]
|
||||
assert fetched.endpoint == "custom-vpn.example.com"
|
||||
assert fetched.mtu == 1400
|
||||
assert fetched.persistent_keepalive == 15
|
||||
assert fetched.allowed_ips == ["10.0.0.0/8", "172.16.0.0/12"]
|
||||
|
||||
|
||||
async def test_device_default_flags_are_true(session):
|
||||
user = User(email="defaults@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="defaults", public_key="pk-defaults", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(Device, device.id)
|
||||
assert fetched.use_default_allowed_ips is True
|
||||
assert fetched.use_default_dns is True
|
||||
assert fetched.use_default_endpoint is True
|
||||
assert fetched.use_default_mtu is True
|
||||
assert fetched.use_default_persistent_keepalive is True
|
||||
|
||||
|
||||
# --- User device count ---
|
||||
|
||||
|
||||
async def test_user_device_count_query(session):
|
||||
user = User(email="count-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for i in range(3):
|
||||
session.add(Device(name=f"d{i}", public_key=f"pk-count-{i}", user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(Device).where(Device.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 3
|
||||
86
tests/test_api.py
Normal file
86
tests/test_api.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""Tests for REST API endpoints and token auth."""
|
||||
|
||||
import hashlib
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token, resolve_bearer_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- Token generation ---
|
||||
|
||||
|
||||
def test_generate_api_token():
|
||||
plaintext, token_hash = generate_api_token()
|
||||
assert len(plaintext) > 20
|
||||
assert token_hash == hashlib.sha256(plaintext.encode()).hexdigest()
|
||||
|
||||
|
||||
def test_generate_api_token_unique():
|
||||
t1, h1 = generate_api_token()
|
||||
t2, h2 = generate_api_token()
|
||||
assert t1 != t2
|
||||
assert h1 != h2
|
||||
|
||||
|
||||
# --- Token resolution ---
|
||||
|
||||
|
||||
async def test_resolve_valid_token(session):
|
||||
user = User(email="api-user@example.com", password_hash=hash_password("x"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is not None
|
||||
assert resolved.id == user.id
|
||||
|
||||
|
||||
async def test_resolve_invalid_token(session):
|
||||
resolved = await resolve_bearer_token(session, "bogus-token")
|
||||
assert resolved is None
|
||||
|
||||
|
||||
async def test_resolve_expired_token(session):
|
||||
from datetime import timedelta
|
||||
|
||||
user = User(email="expired-api@example.com", password_hash=hash_password("x"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
token = ApiToken(
|
||||
token_hash=token_hash,
|
||||
user_id=user.id,
|
||||
expires_at=utcnow() - timedelta(hours=1),
|
||||
)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
|
||||
|
||||
async def test_resolve_token_disabled_user(session):
|
||||
user = User(
|
||||
email="disabled-api@example.com",
|
||||
password_hash=hash_password("x"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
token = ApiToken(token_hash=token_hash, user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
resolved = await resolve_bearer_token(session, plaintext)
|
||||
assert resolved is None
|
||||
325
tests/test_api_routes.py
Normal file
325
tests/test_api_routes.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
"""Tests for REST API routes via httpx AsyncClient against the FastAPI app."""
|
||||
|
||||
import hashlib
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_current_api_user, get_db, require_admin
|
||||
from wiregui.api.v0 import router as api_router
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
def _build_app(session, admin_user=None, regular_user=None):
|
||||
"""Build a test FastAPI app with overridden dependencies."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(api_router, prefix="/api")
|
||||
|
||||
async def override_get_db():
|
||||
yield session
|
||||
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
if admin_user:
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: admin_user
|
||||
test_app.dependency_overrides[require_admin] = lambda: admin_user
|
||||
|
||||
return test_app
|
||||
|
||||
|
||||
async def _make_admin(session) -> User:
|
||||
user = User(email="api-admin@test.com", password_hash=hash_password("pw"), role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def _make_user(session, email="api-user@test.com") -> User:
|
||||
user = User(email=email, password_hash=hash_password("pw"), role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user
|
||||
|
||||
|
||||
# ========== Users API ==========
|
||||
|
||||
|
||||
async def test_list_users(session):
|
||||
admin = await _make_admin(session)
|
||||
await _make_user(session, "user1@test.com")
|
||||
await _make_user(session, "user2@test.com")
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/users/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 3 # admin + 2 users
|
||||
|
||||
|
||||
async def test_get_user(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/users/{admin.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == "api-admin@test.com"
|
||||
|
||||
|
||||
async def test_get_user_not_found(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/users/{uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_create_user(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.post("/api/v0/users/", json={
|
||||
"email": "new-api-user@test.com",
|
||||
"password": "secret123",
|
||||
"role": "unprivileged",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["email"] == "new-api-user@test.com"
|
||||
assert data["role"] == "unprivileged"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
async def test_update_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/users/{user.id}", json={
|
||||
"role": "admin",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["role"] == "admin"
|
||||
|
||||
|
||||
async def test_update_user_password(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/users/{user.id}", json={
|
||||
"password": "new-password-123",
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
|
||||
from wiregui.auth.passwords import verify_password
|
||||
refreshed = await session.get(User, user.id)
|
||||
assert verify_password("new-password-123", refreshed.password_hash)
|
||||
|
||||
|
||||
async def test_delete_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/users/{user.id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(User, user.id) is None
|
||||
|
||||
|
||||
# ========== Devices API ==========
|
||||
|
||||
|
||||
async def test_list_devices_admin_sees_all(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session)
|
||||
session.add(Device(name="d1", public_key="pk-api-d1", user_id=admin.id))
|
||||
session.add(Device(name="d2", public_key="pk-api-d2", user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/devices/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 2
|
||||
|
||||
|
||||
async def test_list_devices_user_sees_own(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session, "own-devices@test.com")
|
||||
session.add(Device(name="mine", public_key="pk-api-mine", user_id=user.id))
|
||||
session.add(Device(name="not-mine", public_key="pk-api-notmine", user_id=admin.id))
|
||||
await session.flush()
|
||||
|
||||
# Override to be the regular user
|
||||
test_app = _build_app(session)
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: user
|
||||
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/devices/")
|
||||
assert resp.status_code == 200
|
||||
names = [d["name"] for d in resp.json()]
|
||||
assert "mine" in names
|
||||
assert "not-mine" not in names
|
||||
|
||||
|
||||
async def test_get_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="detail", public_key="pk-api-detail", user_id=admin.id, ipv4="10.0.0.5")
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/devices/{device.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "detail"
|
||||
assert resp.json()["ipv4"] == "10.0.0.5"
|
||||
|
||||
|
||||
async def test_get_device_forbidden_for_other_user(session):
|
||||
admin = await _make_admin(session)
|
||||
user = await _make_user(session, "other-dev@test.com")
|
||||
device = Device(name="admin-dev", public_key="pk-api-forbid", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
test_app = _build_app(session)
|
||||
test_app.dependency_overrides[get_current_api_user] = lambda: user
|
||||
async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
|
||||
resp = await client.get(f"/api/v0/devices/{device.id}")
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
async def test_update_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="old-name", public_key="pk-api-update", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/devices/{device.id}", json={"name": "new-name"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "new-name"
|
||||
|
||||
|
||||
async def test_delete_device(session):
|
||||
admin = await _make_admin(session)
|
||||
device = Device(name="to-delete", public_key="pk-api-del", user_id=admin.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
did = device.id
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/devices/{did}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(Device, did) is None
|
||||
|
||||
|
||||
# ========== Rules API ==========
|
||||
|
||||
|
||||
async def test_list_rules(session):
|
||||
admin = await _make_admin(session)
|
||||
session.add(Rule(action="accept", destination="10.0.0.0/8"))
|
||||
session.add(Rule(action="drop", destination="192.168.0.0/16", user_id=admin.id))
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/rules/")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()) >= 2
|
||||
|
||||
|
||||
async def test_create_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.post("/api/v0/rules/", json={
|
||||
"action": "accept",
|
||||
"destination": "172.16.0.0/12",
|
||||
"port_type": "tcp",
|
||||
"port_range": "443",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["action"] == "accept"
|
||||
assert data["destination"] == "172.16.0.0/12"
|
||||
assert data["port_type"] == "tcp"
|
||||
assert data["port_range"] == "443"
|
||||
|
||||
|
||||
async def test_update_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8")
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put(f"/api/v0/rules/{rule.id}", json={"action": "drop"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["action"] == "drop"
|
||||
|
||||
|
||||
async def test_delete_rule(session):
|
||||
admin = await _make_admin(session)
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0")
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
rid = rule.id
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.delete(f"/api/v0/rules/{rid}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
assert await session.get(Rule, rid) is None
|
||||
|
||||
|
||||
# ========== Configuration API ==========
|
||||
|
||||
|
||||
async def test_get_configuration_auto_creates(session):
|
||||
admin = await _make_admin(session)
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/api/v0/configuration/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["default_client_mtu"] == 1280
|
||||
assert data["local_auth_enabled"] is True
|
||||
|
||||
|
||||
async def test_update_configuration(session):
|
||||
admin = await _make_admin(session)
|
||||
# Pre-create config
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
app = _build_app(session, admin_user=admin)
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.put("/api/v0/configuration/", json={
|
||||
"default_client_mtu": 1400,
|
||||
"vpn_session_duration": 3600,
|
||||
"default_client_dns": ["8.8.8.8"],
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["default_client_mtu"] == 1400
|
||||
assert data["vpn_session_duration"] == 3600
|
||||
assert data["default_client_dns"] == ["8.8.8.8"]
|
||||
98
tests/test_auth.py
Normal file
98
tests/test_auth.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Tests for authentication modules."""
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.jwt import create_access_token, decode_access_token
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.auth.seed import seed_admin
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- Password hashing ---
|
||||
|
||||
|
||||
def test_hash_and_verify():
|
||||
hashed = hash_password("my-secret")
|
||||
assert verify_password("my-secret", hashed) is True
|
||||
|
||||
|
||||
def test_verify_wrong_password():
|
||||
hashed = hash_password("correct")
|
||||
assert verify_password("wrong", hashed) is False
|
||||
|
||||
|
||||
def test_hash_is_not_plaintext():
|
||||
hashed = hash_password("plaintext")
|
||||
assert hashed != "plaintext"
|
||||
assert hashed.startswith("$2b$")
|
||||
|
||||
|
||||
# --- JWT ---
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
token = create_access_token(user_id="user-123", role="admin")
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["role"] == "admin"
|
||||
assert "exp" in payload
|
||||
|
||||
|
||||
def test_decode_invalid_token():
|
||||
assert decode_access_token("garbage.token.value") is None
|
||||
|
||||
|
||||
def test_decode_tampered_token():
|
||||
token = create_access_token(user_id="user-123", role="admin")
|
||||
tampered = token[:-4] + "XXXX"
|
||||
assert decode_access_token(tampered) is None
|
||||
|
||||
|
||||
# --- Admin seed ---
|
||||
|
||||
|
||||
async def test_seed_admin_creates_user(session, monkeypatch):
|
||||
"""seed_admin should create an admin when no users exist."""
|
||||
# Patch async_session to use our test session
|
||||
from unittest.mock import AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
|
||||
monkeypatch.setattr("wiregui.auth.seed.get_settings", lambda: type("S", (), {
|
||||
"admin_email": "seed-test@example.com",
|
||||
"admin_password": "seed-pass-123",
|
||||
})())
|
||||
|
||||
await seed_admin()
|
||||
|
||||
result = await session.execute(select(User).where(User.email == "seed-test@example.com"))
|
||||
admin = result.scalar_one()
|
||||
assert admin.role == "admin"
|
||||
assert verify_password("seed-pass-123", admin.password_hash)
|
||||
|
||||
|
||||
async def test_seed_admin_skips_when_users_exist(session, monkeypatch):
|
||||
"""seed_admin should not create a second admin if users already exist."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
existing = User(email="existing@example.com", role="unprivileged")
|
||||
session.add(existing)
|
||||
await session.flush()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.seed.async_session", mock_session)
|
||||
|
||||
await seed_admin()
|
||||
|
||||
result = await session.execute(select(User))
|
||||
users = result.scalars().all()
|
||||
assert len(users) == 1
|
||||
assert users[0].email == "existing@example.com"
|
||||
226
tests/test_auth_extended.py
Normal file
226
tests/test_auth_extended.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Extended auth tests — OIDC registration, WebAuthn options, session edge cases."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# ========== Session / authenticate_user edge cases ==========
|
||||
|
||||
|
||||
async def test_authenticate_user_no_password_hash(session, monkeypatch):
|
||||
"""Users without a password (OIDC-only) should not authenticate via password."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="no-pw@test.com", password_hash=None)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await authenticate_user("no-pw@test.com", "anything")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_authenticate_user_disabled(session, monkeypatch):
|
||||
"""Disabled users should not authenticate."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="disabled-auth@test.com", password_hash=hash_password("pw"), disabled_at=utcnow())
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await authenticate_user("disabled-auth@test.com", "pw")
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_authenticate_user_nonexistent(session, monkeypatch):
|
||||
"""Nonexistent email should return None."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
result = await authenticate_user("ghost@nowhere.com", "pw")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ========== OIDC provider registration ==========
|
||||
|
||||
|
||||
async def test_register_providers_from_config(session, monkeypatch):
|
||||
"""register_providers should register configured OIDC providers with authlib."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
from wiregui.models.configuration import Configuration
|
||||
config = Configuration(openid_connect_providers=[
|
||||
{
|
||||
"id": "test-reg",
|
||||
"label": "Test",
|
||||
"scope": "openid email",
|
||||
"client_id": "cid",
|
||||
"client_secret": "cs",
|
||||
"discovery_document_uri": "https://idp.test/.well-known/openid-configuration",
|
||||
}
|
||||
])
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.auth.oidc.oauth") as mock_oauth:
|
||||
from wiregui.auth.oidc import register_providers
|
||||
await register_providers()
|
||||
mock_oauth.register.assert_called_once()
|
||||
call_kwargs = mock_oauth.register.call_args[1]
|
||||
assert call_kwargs["name"] == "test-reg"
|
||||
assert call_kwargs["client_id"] == "cid"
|
||||
|
||||
|
||||
async def test_get_client_unknown_provider():
|
||||
"""get_client should raise for unregistered providers."""
|
||||
import pytest
|
||||
from wiregui.auth.oidc import get_client
|
||||
with pytest.raises(ValueError, match="not registered"):
|
||||
get_client("nonexistent-provider-xyz")
|
||||
|
||||
|
||||
# ========== WebAuthn options ==========
|
||||
|
||||
|
||||
def test_webauthn_registration_options(monkeypatch):
|
||||
"""create_registration_options should return valid options and challenge."""
|
||||
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
|
||||
"external_url": "https://vpn.example.com",
|
||||
})())
|
||||
|
||||
from wiregui.auth.webauthn import create_registration_options
|
||||
user_id = uuid4()
|
||||
result = create_registration_options(user_id, "user@example.com")
|
||||
|
||||
assert "options_json" in result
|
||||
assert "challenge" in result
|
||||
assert len(result["challenge"]) > 10
|
||||
assert "user@example.com" in result["options_json"]
|
||||
|
||||
|
||||
def test_webauthn_registration_options_with_excludes(monkeypatch):
|
||||
"""Existing credentials should be excluded from registration options."""
|
||||
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
|
||||
"external_url": "https://vpn.example.com",
|
||||
})())
|
||||
|
||||
from wiregui.auth.webauthn import create_registration_options
|
||||
existing = [{"credential_id": "AQIDBA"}] # base64url of bytes [1,2,3,4]
|
||||
result = create_registration_options(uuid4(), "user@example.com", existing)
|
||||
assert "options_json" in result
|
||||
|
||||
|
||||
def test_webauthn_authentication_options(monkeypatch):
|
||||
"""create_authentication_options should accept credential descriptors."""
|
||||
monkeypatch.setattr("wiregui.auth.webauthn.get_settings", lambda: type("S", (), {
|
||||
"external_url": "https://vpn.example.com",
|
||||
})())
|
||||
|
||||
from wiregui.auth.webauthn import create_authentication_options
|
||||
credentials = [{"credential_id": "AQIDBA"}]
|
||||
result = create_authentication_options(credentials)
|
||||
assert "options_json" in result
|
||||
assert "challenge" in result
|
||||
|
||||
|
||||
# ========== Events — rule update/delete with rebuild ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
async def test_on_rule_updated_triggers_rebuild(mock_fw, mock_settings):
|
||||
"""on_rule_updated should rebuild the user's firewall chain."""
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_fw.rebuild_all_rules = AsyncMock()
|
||||
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_updated
|
||||
|
||||
# Need to mock the DB call inside _rebuild_user_chain
|
||||
with patch("wiregui.services.events.async_session") as mock_session_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the select results
|
||||
mock_rules_result = MagicMock()
|
||||
mock_rules_result.scalars.return_value.all.return_value = []
|
||||
mock_devices_result = MagicMock()
|
||||
mock_devices_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
|
||||
|
||||
mock_session_factory.return_value = mock_session
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id="a1b2c3d4-0000-0000-0000-000000000000")
|
||||
await on_rule_updated(rule)
|
||||
|
||||
mock_fw.rebuild_all_rules.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
async def test_on_rule_deleted_triggers_rebuild(mock_fw, mock_settings):
|
||||
"""on_rule_deleted should rebuild the user's firewall chain."""
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_fw.rebuild_all_rules = AsyncMock()
|
||||
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_deleted
|
||||
|
||||
with patch("wiregui.services.events.async_session") as mock_session_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_rules_result = MagicMock()
|
||||
mock_rules_result.scalars.return_value.all.return_value = []
|
||||
mock_devices_result = MagicMock()
|
||||
mock_devices_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute = AsyncMock(side_effect=[mock_rules_result, mock_devices_result])
|
||||
|
||||
mock_session_factory.return_value = mock_session
|
||||
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
|
||||
await on_rule_deleted(rule)
|
||||
|
||||
mock_fw.rebuild_all_rules.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
async def test_on_rule_deleted_skips_when_disabled(mock_settings):
|
||||
"""Rule events should be no-ops when WG is disabled."""
|
||||
mock_settings.return_value.wg_enabled = False
|
||||
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_rule_deleted, on_rule_updated
|
||||
|
||||
rule = Rule(action="drop", destination="0.0.0.0/0", user_id="a1b2c3d4-0000-0000-0000-000000000000")
|
||||
await on_rule_updated(rule) # Should not raise
|
||||
await on_rule_deleted(rule) # Should not raise
|
||||
40
tests/test_firewall.py
Normal file
40
tests/test_firewall.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Tests for firewall service — rule expression building and chain naming."""
|
||||
|
||||
from wiregui.services.firewall import _build_rule_expr, _user_chain_name
|
||||
|
||||
|
||||
def test_user_chain_name():
|
||||
uid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
name = _user_chain_name(uid)
|
||||
assert name == "user_a1b2c3d4e5f6"
|
||||
assert len(name) <= 30
|
||||
|
||||
|
||||
def test_user_chain_name_deterministic():
|
||||
uid = "12345678-1234-1234-1234-123456789abc"
|
||||
assert _user_chain_name(uid) == _user_chain_name(uid)
|
||||
|
||||
|
||||
def test_build_rule_expr_ipv4_accept():
|
||||
expr = _build_rule_expr("10.0.0.0/8", "accept")
|
||||
assert expr == "ip daddr 10.0.0.0/8 accept"
|
||||
|
||||
|
||||
def test_build_rule_expr_ipv6_drop():
|
||||
expr = _build_rule_expr("fd00::/64", "drop")
|
||||
assert expr == "ip6 daddr fd00::/64 drop"
|
||||
|
||||
|
||||
def test_build_rule_expr_with_port():
|
||||
expr = _build_rule_expr("192.168.0.0/16", "accept", port_type="tcp", port_range="80-443")
|
||||
assert expr == "ip daddr 192.168.0.0/16 tcp dport 80-443 accept"
|
||||
|
||||
|
||||
def test_build_rule_expr_single_port():
|
||||
expr = _build_rule_expr("10.0.0.1/32", "drop", port_type="udp", port_range="53")
|
||||
assert expr == "ip daddr 10.0.0.1/32 udp dport 53 drop"
|
||||
|
||||
|
||||
def test_build_rule_expr_no_port():
|
||||
expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None)
|
||||
assert expr == "ip daddr 0.0.0.0/0 accept"
|
||||
239
tests/test_integration_mfa.py
Normal file
239
tests/test_integration_mfa.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Integration tests for MFA — full registration and authentication flows through the database."""
|
||||
|
||||
import pyotp
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.mfa import generate_totp_secret, verify_totp_code
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
async def test_full_totp_registration_flow(session, monkeypatch):
|
||||
"""End-to-end: create user → generate secret → verify code → store method → re-verify from DB."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
# Create user with password
|
||||
user = User(email="mfa-flow@example.com", password_hash=hash_password("secure123"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Step 1: Generate TOTP secret (happens in account page)
|
||||
secret = generate_totp_secret()
|
||||
|
||||
# Step 2: User scans QR, enters code from their authenticator
|
||||
totp = pyotp.TOTP(secret)
|
||||
code = totp.now()
|
||||
|
||||
# Step 3: Verify the code is correct before saving
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
# Step 4: Save the MFA method to DB
|
||||
method = MFAMethod(
|
||||
name="My Authenticator",
|
||||
type="totp",
|
||||
payload={"secret": secret},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Step 5: Simulate future login — load method from DB and verify a fresh code
|
||||
fetched_methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalars().all()
|
||||
|
||||
assert len(fetched_methods) == 1
|
||||
stored_secret = fetched_methods[0].payload["secret"]
|
||||
fresh_code = pyotp.TOTP(stored_secret).now()
|
||||
assert verify_totp_code(stored_secret, fresh_code) is True
|
||||
|
||||
|
||||
async def test_mfa_blocks_login_without_code(session, monkeypatch):
|
||||
"""User with MFA should not be fully authenticated without completing MFA challenge."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
# Create user with MFA
|
||||
user = User(email="mfa-block@example.com", password_hash=hash_password("password1"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Phone", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Password auth succeeds
|
||||
authed_user = await authenticate_user("mfa-block@example.com", "password1")
|
||||
assert authed_user is not None
|
||||
|
||||
# But MFA methods exist — login page would redirect to /mfa instead of completing login
|
||||
mfa_methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == authed_user.id)
|
||||
)).scalars().all()
|
||||
assert len(mfa_methods) > 0 # Login flow would check this and redirect to /mfa
|
||||
|
||||
|
||||
async def test_mfa_wrong_code_rejected(session):
|
||||
"""Wrong TOTP code should be rejected even if method is valid."""
|
||||
user = User(email="mfa-wrong@example.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# Load from DB and try wrong code
|
||||
fetched = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar_one()
|
||||
|
||||
assert verify_totp_code(fetched.payload["secret"], "000000") is False
|
||||
assert verify_totp_code(fetched.payload["secret"], "123456") is False
|
||||
|
||||
|
||||
async def test_mfa_multiple_methods_any_valid_code_works(session):
|
||||
"""If user has multiple TOTP methods, a valid code from any should work."""
|
||||
user = User(email="mfa-multi@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret1 = generate_totp_secret()
|
||||
secret2 = generate_totp_secret()
|
||||
|
||||
session.add(MFAMethod(name="Phone", type="totp", payload={"secret": secret1}, user_id=user.id))
|
||||
session.add(MFAMethod(name="Backup", type="totp", payload={"secret": secret2}, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
methods = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalars().all()
|
||||
|
||||
# Code from method 1 should verify against method 1's secret
|
||||
code1 = pyotp.TOTP(secret1).now()
|
||||
verified = False
|
||||
for m in methods:
|
||||
if verify_totp_code(m.payload["secret"], code1):
|
||||
verified = True
|
||||
break
|
||||
assert verified is True
|
||||
|
||||
# Code from method 2 should also work
|
||||
code2 = pyotp.TOTP(secret2).now()
|
||||
verified2 = False
|
||||
for m in methods:
|
||||
if verify_totp_code(m.payload["secret"], code2):
|
||||
verified2 = True
|
||||
break
|
||||
assert verified2 is True
|
||||
|
||||
|
||||
async def test_mfa_method_last_used_tracking(session):
|
||||
"""Verifying MFA should update last_used_at timestamp."""
|
||||
user = User(email="mfa-tracking@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
assert method.last_used_at is None
|
||||
|
||||
# Simulate successful verification and update
|
||||
code = pyotp.TOTP(secret).now()
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
method.last_used_at = utcnow()
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(MFAMethod, method.id)
|
||||
assert fetched.last_used_at is not None
|
||||
|
||||
|
||||
async def test_mfa_delete_method_allows_login_without_mfa(session, monkeypatch):
|
||||
"""After removing all MFA methods, user should not be redirected to MFA challenge."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(email="mfa-remove@example.com", password_hash=hash_password("pw"))
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(name="Temp", type="totp", payload={"secret": secret}, user_id=user.id)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
# MFA exists
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 1
|
||||
|
||||
# Delete it
|
||||
await session.delete(method)
|
||||
await session.flush()
|
||||
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 0
|
||||
|
||||
# Password auth still works
|
||||
authed = await authenticate_user("mfa-remove@example.com", "pw")
|
||||
assert authed is not None
|
||||
|
||||
# No MFA methods — login flow would skip MFA challenge
|
||||
mfa_check = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == authed.id)
|
||||
)).scalars().all()
|
||||
assert len(mfa_check) == 0
|
||||
|
||||
|
||||
async def test_disabled_user_with_mfa_cannot_login(session, monkeypatch):
|
||||
"""Disabled user should be rejected at password stage, never reaching MFA."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.session.async_session", mock_session)
|
||||
|
||||
user = User(
|
||||
email="mfa-disabled@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
session.add(MFAMethod(name="Auth", type="totp", payload={"secret": secret}, user_id=user.id))
|
||||
await session.flush()
|
||||
|
||||
# Password auth rejects disabled user before MFA is ever checked
|
||||
result = await authenticate_user("mfa-disabled@example.com", "pw")
|
||||
assert result is None
|
||||
309
tests/test_integration_oidc.py
Normal file
309
tests/test_integration_oidc.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""Integration tests for OIDC — mock provider endpoints, test full auth code flow."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import respx
|
||||
from httpx import Response
|
||||
from jose import jwt
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.oidc import get_provider_config, load_providers, oauth, register_providers
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- Helper to create a fake OIDC provider config in the DB ---
|
||||
|
||||
|
||||
async def _setup_oidc_config(session) -> Configuration:
|
||||
"""Insert a Configuration with a test OIDC provider."""
|
||||
config = Configuration(
|
||||
openid_connect_providers=[
|
||||
{
|
||||
"id": "test-idp",
|
||||
"label": "Test IdP",
|
||||
"scope": "openid email profile",
|
||||
"response_type": "code",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"discovery_document_uri": "https://idp.example.com/.well-known/openid-configuration",
|
||||
"auto_create_users": True,
|
||||
}
|
||||
],
|
||||
)
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
return config
|
||||
|
||||
|
||||
def _mock_discovery():
|
||||
"""Mock OIDC discovery document response."""
|
||||
return {
|
||||
"issuer": "https://idp.example.com",
|
||||
"authorization_endpoint": "https://idp.example.com/authorize",
|
||||
"token_endpoint": "https://idp.example.com/token",
|
||||
"userinfo_endpoint": "https://idp.example.com/userinfo",
|
||||
"jwks_uri": "https://idp.example.com/.well-known/jwks.json",
|
||||
}
|
||||
|
||||
|
||||
def _mock_token_response(email: str = "oidc-user@example.com"):
|
||||
"""Mock OIDC token endpoint response with ID token."""
|
||||
now = int(time.time())
|
||||
id_token_payload = {
|
||||
"iss": "https://idp.example.com",
|
||||
"sub": "oidc-subject-123",
|
||||
"aud": "test-client-id",
|
||||
"email": email,
|
||||
"name": "OIDC User",
|
||||
"iat": now,
|
||||
"exp": now + 3600,
|
||||
"nonce": "test-nonce",
|
||||
}
|
||||
# Sign with a simple secret (in real life this would be RSA)
|
||||
id_token = jwt.encode(id_token_payload, "fake-secret", algorithm="HS256")
|
||||
|
||||
return {
|
||||
"access_token": "mock-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
|
||||
# --- Provider config loading ---
|
||||
|
||||
|
||||
async def test_load_providers_from_config(session, monkeypatch):
|
||||
"""Providers should be loaded from the Configuration table."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
await _setup_oidc_config(session)
|
||||
|
||||
providers = await load_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0]["id"] == "test-idp"
|
||||
assert providers[0]["client_id"] == "test-client-id"
|
||||
|
||||
|
||||
async def test_load_providers_empty_when_no_config(session, monkeypatch):
|
||||
"""Should return empty list when no Configuration exists."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
providers = await load_providers()
|
||||
assert providers == []
|
||||
|
||||
|
||||
async def test_get_provider_config_by_id(session, monkeypatch):
|
||||
"""Should find a specific provider by ID."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.auth.oidc.async_session", mock_session)
|
||||
|
||||
await _setup_oidc_config(session)
|
||||
|
||||
config = await get_provider_config("test-idp")
|
||||
assert config is not None
|
||||
assert config["label"] == "Test IdP"
|
||||
|
||||
config_missing = await get_provider_config("nonexistent")
|
||||
assert config_missing is None
|
||||
|
||||
|
||||
# --- OIDC connection storage ---
|
||||
|
||||
|
||||
async def test_oidc_connection_created_on_login(session):
|
||||
"""Simulates what the callback route does: create user + OIDC connection."""
|
||||
user = User(email="oidc-new@example.com", role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
token_data = _mock_token_response("oidc-new@example.com")
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token=token_data["refresh_token"],
|
||||
refresh_response=token_data,
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Verify it was stored
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.provider == "test-idp"
|
||||
assert fetched.refresh_token == "mock-refresh-token"
|
||||
assert fetched.refresh_response["access_token"] == "mock-access-token"
|
||||
|
||||
|
||||
async def test_oidc_connection_updated_on_re_login(session):
|
||||
"""Re-login should update the existing OIDC connection, not create a duplicate."""
|
||||
user = User(email="oidc-relogin@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# First login
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token="old-refresh-token",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Re-login — update existing connection (as the callback route does)
|
||||
existing = (await session.execute(
|
||||
select(OIDCConnection).where(
|
||||
OIDCConnection.user_id == user.id,
|
||||
OIDCConnection.provider == "test-idp",
|
||||
)
|
||||
)).scalar_one()
|
||||
|
||||
existing.refresh_token = "new-refresh-token"
|
||||
from wiregui.utils.time import utcnow
|
||||
existing.refreshed_at = utcnow()
|
||||
session.add(existing)
|
||||
await session.flush()
|
||||
|
||||
# Should still be one connection
|
||||
from sqlmodel import func
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 1
|
||||
|
||||
fetched = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert fetched.refresh_token == "new-refresh-token"
|
||||
|
||||
|
||||
async def test_oidc_auto_create_user(session):
|
||||
"""When auto_create_users is True, a new user should be created from OIDC email."""
|
||||
email = "auto-created@example.com"
|
||||
|
||||
# Verify user doesn't exist
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
assert existing is None
|
||||
|
||||
# Simulate what callback does with auto_create
|
||||
user = User(email=email, role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
user.last_signed_in_at = utcnow()
|
||||
user.last_signed_in_method = "oidc:test-idp"
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
created = (await session.execute(select(User).where(User.email == email))).scalar_one()
|
||||
assert created.role == "unprivileged"
|
||||
assert created.last_signed_in_method == "oidc:test-idp"
|
||||
|
||||
|
||||
async def test_oidc_disabled_user_rejected(session):
|
||||
"""Disabled users should not be logged in via OIDC."""
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
user = User(email="oidc-disabled@example.com", disabled_at=utcnow())
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# The callback route checks disabled_at before creating session
|
||||
assert user.disabled_at is not None # Would redirect to /login
|
||||
|
||||
|
||||
async def test_oidc_user_without_auto_create_rejected(session):
|
||||
"""When auto_create is False and user doesn't exist, login should fail."""
|
||||
email = "no-auto-create@example.com"
|
||||
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
assert existing is None
|
||||
|
||||
# The callback route checks auto_create_users from provider config
|
||||
# With auto_create=False and no existing user, it would redirect to /login
|
||||
# This verifies the precondition
|
||||
|
||||
|
||||
# --- OIDC refresh token flow ---
|
||||
|
||||
|
||||
async def test_oidc_refresh_stores_new_token(session):
|
||||
"""Simulates a successful token refresh updating the connection."""
|
||||
user = User(email="oidc-refresh-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(
|
||||
provider="test-idp",
|
||||
refresh_token="old-refresh",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
# Simulate refresh result
|
||||
new_token = {
|
||||
"access_token": "new-access",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
conn.refresh_token = new_token.get("refresh_token", conn.refresh_token)
|
||||
conn.refresh_response = new_token
|
||||
from wiregui.utils.time import utcnow
|
||||
conn.refreshed_at = utcnow()
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = await session.get(OIDCConnection, conn.id)
|
||||
assert fetched.refresh_token == "new-refresh"
|
||||
assert fetched.refresh_response["access_token"] == "new-access"
|
||||
assert fetched.refreshed_at is not None
|
||||
|
||||
|
||||
async def test_oidc_multiple_providers_per_user(session):
|
||||
"""User can have connections to multiple OIDC providers."""
|
||||
user = User(email="multi-provider@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
for provider in ["google", "okta", "azure-ad"]:
|
||||
session.add(OIDCConnection(
|
||||
provider=provider,
|
||||
refresh_token=f"token-{provider}",
|
||||
user_id=user.id,
|
||||
))
|
||||
await session.flush()
|
||||
|
||||
conns = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user.id).order_by(OIDCConnection.provider)
|
||||
)).scalars().all()
|
||||
|
||||
assert len(conns) == 3
|
||||
assert [c.provider for c in conns] == ["azure-ad", "google", "okta"]
|
||||
58
tests/test_magic_link.py
Normal file
58
tests/test_magic_link.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Tests for magic link authentication flow."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from wiregui.auth.jwt import create_access_token, decode_access_token
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
def test_magic_link_token_creation():
|
||||
"""Magic link token should be a valid JWT with short expiry."""
|
||||
token = create_access_token(
|
||||
user_id="user-123",
|
||||
role="unprivileged",
|
||||
expires_delta=timedelta(minutes=15),
|
||||
)
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == "user-123"
|
||||
assert payload["role"] == "unprivileged"
|
||||
|
||||
|
||||
def test_magic_link_token_expired():
|
||||
"""Expired magic link token should be rejected."""
|
||||
token = create_access_token(
|
||||
user_id="user-123",
|
||||
role="admin",
|
||||
expires_delta=timedelta(minutes=-1), # Already expired
|
||||
)
|
||||
payload = decode_access_token(token)
|
||||
assert payload is None
|
||||
|
||||
|
||||
def test_magic_link_token_wrong_user():
|
||||
"""Token should only be valid for the intended user."""
|
||||
token = create_access_token(user_id="user-A", role="admin")
|
||||
payload = decode_access_token(token)
|
||||
assert payload["sub"] == "user-A"
|
||||
# Caller is responsible for checking sub matches the URL user_id
|
||||
|
||||
|
||||
async def test_magic_link_disabled_user_rejected(session):
|
||||
"""Disabled users should not be able to use magic links."""
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
user = User(
|
||||
email="disabled-magic@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# The token would be valid but the page handler checks disabled_at
|
||||
token = create_access_token(user_id=str(user.id), role="unprivileged")
|
||||
payload = decode_access_token(token)
|
||||
assert payload is not None # Token itself is valid
|
||||
assert user.disabled_at is not None # But user is disabled — handler would reject
|
||||
127
tests/test_mfa.py
Normal file
127
tests/test_mfa.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Tests for TOTP MFA functionality."""
|
||||
|
||||
import pyotp
|
||||
|
||||
from wiregui.auth.mfa import (
|
||||
generate_totp_qr_svg,
|
||||
generate_totp_secret,
|
||||
get_totp_uri,
|
||||
verify_totp_code,
|
||||
)
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
# --- TOTP secret generation ---
|
||||
|
||||
|
||||
def test_generate_secret():
|
||||
secret = generate_totp_secret()
|
||||
assert len(secret) == 32 # base32 encoded
|
||||
assert secret.isalpha() or any(c.isdigit() for c in secret)
|
||||
|
||||
|
||||
def test_generate_secret_unique():
|
||||
s1 = generate_totp_secret()
|
||||
s2 = generate_totp_secret()
|
||||
assert s1 != s2
|
||||
|
||||
|
||||
# --- TOTP URI ---
|
||||
|
||||
|
||||
def test_get_totp_uri():
|
||||
uri = get_totp_uri("JBSWY3DPEHPK3PXP", "user@example.com")
|
||||
assert uri.startswith("otpauth://totp/")
|
||||
assert "user%40example.com" in uri or "user@example.com" in uri
|
||||
assert "secret=JBSWY3DPEHPK3PXP" in uri
|
||||
assert "issuer=WireGUI" in uri
|
||||
|
||||
|
||||
def test_get_totp_uri_custom_issuer():
|
||||
uri = get_totp_uri("SECRET", "test@test.com", issuer="MyVPN")
|
||||
assert "issuer=MyVPN" in uri
|
||||
|
||||
|
||||
# --- TOTP verification ---
|
||||
|
||||
|
||||
def test_verify_valid_code():
|
||||
secret = generate_totp_secret()
|
||||
totp = pyotp.TOTP(secret)
|
||||
code = totp.now()
|
||||
assert verify_totp_code(secret, code) is True
|
||||
|
||||
|
||||
def test_verify_invalid_code():
|
||||
secret = generate_totp_secret()
|
||||
assert verify_totp_code(secret, "000000") is False
|
||||
|
||||
|
||||
def test_verify_wrong_secret():
|
||||
secret1 = generate_totp_secret()
|
||||
secret2 = generate_totp_secret()
|
||||
code = pyotp.TOTP(secret1).now()
|
||||
assert verify_totp_code(secret2, code) is False
|
||||
|
||||
|
||||
def test_verify_empty_code():
|
||||
secret = generate_totp_secret()
|
||||
assert verify_totp_code(secret, "") is False
|
||||
|
||||
|
||||
# --- QR code generation ---
|
||||
|
||||
|
||||
def test_generate_qr_svg():
|
||||
uri = get_totp_uri("SECRET", "test@test.com")
|
||||
svg = generate_totp_qr_svg(uri)
|
||||
assert "<svg" in svg
|
||||
assert "</svg>" in svg
|
||||
|
||||
|
||||
# --- MFA method model integration ---
|
||||
|
||||
|
||||
async def test_create_totp_method(session):
|
||||
user = User(email="mfa-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
secret = generate_totp_secret()
|
||||
method = MFAMethod(
|
||||
name="My Phone",
|
||||
type="totp",
|
||||
payload={"secret": secret},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.flush()
|
||||
|
||||
from sqlmodel import select
|
||||
fetched = (await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar_one()
|
||||
|
||||
assert fetched.name == "My Phone"
|
||||
assert fetched.type == "totp"
|
||||
stored_secret = fetched.payload["secret"]
|
||||
code = pyotp.TOTP(stored_secret).now()
|
||||
assert verify_totp_code(stored_secret, code) is True
|
||||
|
||||
|
||||
async def test_user_multiple_mfa_methods(session):
|
||||
user = User(email="multi-mfa@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
m1 = MFAMethod(name="Phone", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
|
||||
m2 = MFAMethod(name="Backup", type="totp", payload={"secret": generate_totp_secret()}, user_id=user.id)
|
||||
session.add_all([m1, m2])
|
||||
await session.flush()
|
||||
|
||||
from sqlmodel import select, func
|
||||
count = (await session.execute(
|
||||
select(func.count()).select_from(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)).scalar()
|
||||
assert count == 2
|
||||
168
tests/test_models.py
Normal file
168
tests/test_models.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Tests for SQLModel table definitions."""
|
||||
|
||||
import pytest # noqa: F401 — needed for pytest.raises
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def test_create_user(session):
|
||||
user = User(email="alice@example.com", role="admin")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(User).where(User.email == "alice@example.com"))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.id == user.id
|
||||
assert fetched.role == "admin"
|
||||
assert fetched.disabled_at is None
|
||||
|
||||
|
||||
async def test_create_device_with_user(session):
|
||||
user = User(email="bob@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(
|
||||
name="laptop",
|
||||
public_key="pk-test-device-001",
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(Device).where(Device.public_key == "pk-test-device-001"))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.name == "laptop"
|
||||
assert fetched.user_id == user.id
|
||||
assert fetched.use_default_dns is True
|
||||
assert fetched.use_default_allowed_ips is True
|
||||
assert fetched.rx_bytes is None
|
||||
|
||||
|
||||
async def test_device_unique_public_key(session):
|
||||
user = User(email="carol@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
|
||||
session.add(d1)
|
||||
await session.flush()
|
||||
|
||||
d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
|
||||
session.add(d2)
|
||||
with pytest.raises(Exception): # IntegrityError
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def test_create_rule(session):
|
||||
user = User(email="dave@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
rule = Rule(action="accept", destination="10.0.0.0/8", user_id=user.id)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
result = await session.execute(select(Rule).where(Rule.user_id == user.id))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.action == "accept"
|
||||
assert fetched.destination == "10.0.0.0/8"
|
||||
assert fetched.port_type is None
|
||||
assert fetched.port_range is None
|
||||
|
||||
|
||||
async def test_create_rule_with_port(session):
|
||||
rule = Rule(
|
||||
action="drop",
|
||||
destination="192.168.0.0/16",
|
||||
port_type="tcp",
|
||||
port_range="80-443",
|
||||
)
|
||||
session.add(rule)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(Rule).where(Rule.id == rule.id))).scalar_one()
|
||||
assert fetched.port_type == "tcp"
|
||||
assert fetched.port_range == "80-443"
|
||||
assert fetched.user_id is None # global rule
|
||||
|
||||
|
||||
async def test_create_mfa_method(session):
|
||||
user = User(email="eve@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
mfa = MFAMethod(
|
||||
name="My Authenticator",
|
||||
type="totp",
|
||||
payload={"secret": "JBSWY3DPEHPK3PXP"},
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(mfa)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(MFAMethod).where(MFAMethod.user_id == user.id))).scalar_one()
|
||||
assert fetched.type == "totp"
|
||||
assert fetched.payload["secret"] == "JBSWY3DPEHPK3PXP"
|
||||
|
||||
|
||||
async def test_create_oidc_connection(session):
|
||||
user = User(email="frank@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(OIDCConnection).where(OIDCConnection.user_id == user.id))).scalar_one()
|
||||
assert fetched.provider == "google"
|
||||
assert fetched.refresh_token == "tok_abc"
|
||||
|
||||
|
||||
async def test_create_api_token(session):
|
||||
user = User(email="grace@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
token = ApiToken(token_hash="sha256_fake_hash", user_id=user.id)
|
||||
session.add(token)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(ApiToken).where(ApiToken.user_id == user.id))).scalar_one()
|
||||
assert fetched.token_hash == "sha256_fake_hash"
|
||||
assert fetched.expires_at is None
|
||||
|
||||
|
||||
async def test_create_connectivity_check(session):
|
||||
check = ConnectivityCheck(url="https://example.com", response_code=200)
|
||||
session.add(check)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(ConnectivityCheck).where(ConnectivityCheck.id == check.id))).scalar_one()
|
||||
assert fetched.response_code == 200
|
||||
|
||||
|
||||
async def test_configuration_defaults(session):
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
fetched = (await session.execute(select(Configuration).where(Configuration.id == config.id))).scalar_one()
|
||||
assert fetched.allow_unprivileged_device_management is True
|
||||
assert fetched.local_auth_enabled is True
|
||||
assert fetched.default_client_mtu == 1280
|
||||
assert fetched.default_client_persistent_keepalive == 25
|
||||
assert fetched.default_client_dns == ["1.1.1.1", "1.0.0.1"]
|
||||
assert fetched.default_client_allowed_ips == ["0.0.0.0/0", "::/0"]
|
||||
assert fetched.vpn_session_duration == 0
|
||||
assert fetched.openid_connect_providers == []
|
||||
assert fetched.saml_identity_providers == []
|
||||
89
tests/test_notifications.py
Normal file
89
tests/test_notifications.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Tests for the notification service."""
|
||||
|
||||
from wiregui.services import notifications
|
||||
|
||||
|
||||
def setup_function():
|
||||
"""Clear notifications before each test."""
|
||||
notifications.clear_all()
|
||||
|
||||
|
||||
def test_add_notification():
|
||||
n = notifications.add("info", "Test message")
|
||||
assert n.severity == "info"
|
||||
assert n.message == "Test message"
|
||||
assert n.user is None
|
||||
assert n.id is not None
|
||||
assert n.timestamp is not None
|
||||
|
||||
|
||||
def test_add_notification_with_user():
|
||||
n = notifications.add("error", "Something broke", user="admin@example.com")
|
||||
assert n.user == "admin@example.com"
|
||||
assert n.severity == "error"
|
||||
|
||||
|
||||
def test_current_returns_newest_first():
|
||||
notifications.add("info", "First")
|
||||
notifications.add("warning", "Second")
|
||||
notifications.add("error", "Third")
|
||||
|
||||
current = notifications.current()
|
||||
assert len(current) == 3
|
||||
assert current[0].message == "Third"
|
||||
assert current[1].message == "Second"
|
||||
assert current[2].message == "First"
|
||||
|
||||
|
||||
def test_count():
|
||||
assert notifications.count() == 0
|
||||
notifications.add("info", "One")
|
||||
notifications.add("info", "Two")
|
||||
assert notifications.count() == 2
|
||||
|
||||
|
||||
def test_clear_specific():
|
||||
n1 = notifications.add("info", "Keep this")
|
||||
n2 = notifications.add("error", "Remove this")
|
||||
|
||||
notifications.clear(n2.id)
|
||||
current = notifications.current()
|
||||
assert len(current) == 1
|
||||
assert current[0].id == n1.id
|
||||
|
||||
|
||||
def test_clear_nonexistent_id_is_noop():
|
||||
notifications.add("info", "Test")
|
||||
notifications.clear("nonexistent-id")
|
||||
assert notifications.count() == 1
|
||||
|
||||
|
||||
def test_clear_all():
|
||||
notifications.add("info", "One")
|
||||
notifications.add("info", "Two")
|
||||
notifications.add("info", "Three")
|
||||
assert notifications.count() == 3
|
||||
|
||||
notifications.clear_all()
|
||||
assert notifications.count() == 0
|
||||
assert notifications.current() == []
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
n = notifications.add("warning", "Test dict", user="someone@example.com")
|
||||
d = n.to_dict()
|
||||
assert d["severity"] == "warning"
|
||||
assert d["message"] == "Test dict"
|
||||
assert d["user"] == "someone@example.com"
|
||||
assert "id" in d
|
||||
assert "timestamp" in d
|
||||
|
||||
|
||||
def test_max_notifications():
|
||||
"""Deque should cap at MAX_NOTIFICATIONS."""
|
||||
for i in range(notifications.MAX_NOTIFICATIONS + 10):
|
||||
notifications.add("info", f"Notification {i}")
|
||||
|
||||
assert notifications.count() == notifications.MAX_NOTIFICATIONS
|
||||
# Newest should be the last one added
|
||||
assert notifications.current()[0].message == f"Notification {notifications.MAX_NOTIFICATIONS + 9}"
|
||||
124
tests/test_services.py
Normal file
124
tests/test_services.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Tests for services — WireGuard and events."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created
|
||||
|
||||
|
||||
def _make_device(**kwargs) -> Device:
|
||||
defaults = dict(
|
||||
name="test",
|
||||
public_key="pk-test",
|
||||
preshared_key="psk-test",
|
||||
ipv4="10.3.2.5",
|
||||
ipv6="fd00::3:2:5",
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return Device(**defaults)
|
||||
|
||||
|
||||
# --- Events (with WG enabled) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_created_calls_add_peer(mock_wg, mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_fw.add_device_jump_rule = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_created(device)
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once_with(
|
||||
public_key="pk-test",
|
||||
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128"],
|
||||
preshared_key="psk-test",
|
||||
)
|
||||
mock_fw.add_device_jump_rule.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_deleted_calls_remove_peer(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_deleted(device)
|
||||
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-test")
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_updated_calls_add_peer(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_updated(device)
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once()
|
||||
|
||||
|
||||
# --- Events (WG disabled) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_events_skip_when_wg_disabled(mock_wg, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = False
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
await on_device_created(device)
|
||||
await on_device_deleted(device)
|
||||
await on_device_updated(device)
|
||||
|
||||
mock_wg.add_peer.assert_not_awaited()
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
# --- Events (WG error handling) ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
@patch("wiregui.services.events.wireguard")
|
||||
async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_wg.add_peer = AsyncMock(side_effect=RuntimeError("wg failed"))
|
||||
mock_fw.add_device_jump_rule = AsyncMock()
|
||||
|
||||
device = _make_device()
|
||||
# Should not raise — error is logged
|
||||
await on_device_created(device)
|
||||
|
||||
|
||||
# --- Rule events ---
|
||||
|
||||
|
||||
@patch("wiregui.services.events.get_settings")
|
||||
@patch("wiregui.services.events.firewall")
|
||||
async def test_on_rule_created_calls_apply_rule(mock_fw, mock_settings):
|
||||
mock_settings.return_value.wg_enabled = True
|
||||
mock_fw.apply_rule = AsyncMock()
|
||||
|
||||
rule = Rule(
|
||||
action="accept",
|
||||
destination="10.0.0.0/8",
|
||||
port_type="tcp",
|
||||
port_range="80",
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
await on_rule_created(rule)
|
||||
|
||||
mock_fw.apply_rule.assert_awaited_once_with(
|
||||
"00000000-0000-0000-0000-000000000000", "10.0.0.0/8", "accept", "tcp", "80",
|
||||
)
|
||||
203
tests/test_services_extended.py
Normal file
203
tests/test_services_extended.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""Extended service tests — wireguard subprocess mocking, firewall nft mocking, email."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from wiregui.services.wireguard import PeerInfo, add_peer, get_peers, remove_peer
|
||||
|
||||
|
||||
# ========== WireGuard service (mocked subprocess) ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_add_peer_without_psk(mock_run):
|
||||
mock_run.return_value = ""
|
||||
await add_peer("pubkey123", ["10.0.0.1/32", "fd00::1/128"], iface="wg-test")
|
||||
mock_run.assert_awaited_once()
|
||||
args = mock_run.call_args[0][0]
|
||||
assert "wg" in args
|
||||
assert "set" in args
|
||||
assert "pubkey123" in args
|
||||
assert "10.0.0.1/32,fd00::1/128" in args
|
||||
|
||||
|
||||
@patch("asyncio.create_subprocess_exec")
|
||||
async def test_add_peer_with_psk(mock_exec):
|
||||
"""PSK path uses subprocess directly with stdin."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"", b"")
|
||||
mock_proc.returncode = 0
|
||||
mock_exec.return_value = mock_proc
|
||||
|
||||
await add_peer("pubkey456", ["10.0.0.2/32"], preshared_key="psk-data", iface="wg-test")
|
||||
mock_exec.assert_awaited_once()
|
||||
call_args = mock_exec.call_args[0]
|
||||
assert "preshared-key" in call_args
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_remove_peer(mock_run):
|
||||
mock_run.return_value = ""
|
||||
await remove_peer("pubkey789", iface="wg-test")
|
||||
mock_run.assert_awaited_once()
|
||||
args = mock_run.call_args[0][0]
|
||||
assert "remove" in args
|
||||
assert "pubkey789" in args
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_get_peers_parses_dump(mock_run):
|
||||
dump_output = (
|
||||
"privkey\tpubkey\t51820\toff\n"
|
||||
"peerkey1\t(none)\t1.2.3.4:51820\t10.0.0.1/32\t1700000000\t12345\t67890\t25\n"
|
||||
"peerkey2\t(none)\t(none)\t10.0.0.2/32,fd00::2/128\t0\t0\t0\t0\n"
|
||||
)
|
||||
mock_run.return_value = dump_output
|
||||
|
||||
peers = await get_peers(iface="wg-test")
|
||||
assert len(peers) == 2
|
||||
|
||||
assert peers[0].public_key == "peerkey1"
|
||||
assert peers[0].endpoint == "1.2.3.4:51820"
|
||||
assert peers[0].rx_bytes == 12345
|
||||
assert peers[0].tx_bytes == 67890
|
||||
assert peers[0].latest_handshake is not None
|
||||
|
||||
assert peers[1].public_key == "peerkey2"
|
||||
assert peers[1].endpoint is None
|
||||
assert peers[1].rx_bytes == 0
|
||||
assert peers[1].latest_handshake is None
|
||||
assert len(peers[1].allowed_ips) == 2
|
||||
|
||||
|
||||
@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
|
||||
async def test_get_peers_returns_empty_on_error(mock_run):
|
||||
mock_run.side_effect = RuntimeError("interface not found")
|
||||
peers = await get_peers(iface="wg-test")
|
||||
assert peers == []
|
||||
|
||||
|
||||
# ========== Firewall (mocked nft) ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_base_tables(mock_batch):
|
||||
from wiregui.services.firewall import setup_base_tables
|
||||
await setup_base_tables()
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("add table" in c for c in cmds)
|
||||
assert any("forward" in c for c in cmds)
|
||||
assert any("postrouting" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_user_chain(mock_batch):
|
||||
from wiregui.services.firewall import add_user_chain
|
||||
await add_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("user_a1b2c3d40000" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_remove_user_chain(mock_batch):
|
||||
from wiregui.services.firewall import remove_user_chain
|
||||
await remove_user_chain("a1b2c3d4-0000-0000-0000-000000000000")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("flush" in c for c in cmds)
|
||||
assert any("delete" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_add_device_jump_rule(mock_batch):
|
||||
from wiregui.services.firewall import add_device_jump_rule
|
||||
await add_device_jump_rule("user-id-123", "10.0.0.5", "fd00::5")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("10.0.0.5" in c and "jump" in c for c in cmds)
|
||||
assert any("fd00::5" in c and "jump" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_apply_rule(mock_batch):
|
||||
from wiregui.services.firewall import apply_rule
|
||||
await apply_rule("user-123", "10.0.0.0/8", "accept", "tcp", "80-443")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("10.0.0.0/8" in c and "accept" in c and "tcp dport 80-443" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_rebuild_all_rules(mock_batch):
|
||||
from wiregui.services.firewall import rebuild_all_rules
|
||||
await rebuild_all_rules([
|
||||
{
|
||||
"user_id": "user-1",
|
||||
"devices": [{"ipv4": "10.0.0.1", "ipv6": "fd00::1"}],
|
||||
"rules": [
|
||||
{"destination": "0.0.0.0/0", "action": "accept", "port_type": None, "port_range": None},
|
||||
{"destination": "192.168.0.0/16", "action": "drop", "port_type": "tcp", "port_range": "22"},
|
||||
],
|
||||
}
|
||||
])
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("flush chain" in c and "forward" in c for c in cmds)
|
||||
assert any("0.0.0.0/0" in c and "accept" in c for c in cmds)
|
||||
assert any("192.168.0.0/16" in c and "drop" in c for c in cmds)
|
||||
assert any("10.0.0.1" in c and "jump" in c for c in cmds)
|
||||
|
||||
|
||||
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||
async def test_setup_masquerade(mock_batch):
|
||||
from wiregui.services.firewall import setup_masquerade
|
||||
await setup_masquerade(iface="wg0")
|
||||
mock_batch.assert_awaited_once()
|
||||
cmds = mock_batch.call_args[0][0]
|
||||
assert any("masquerade" in c for c in cmds)
|
||||
|
||||
|
||||
# ========== Email service (mocked smtp) ==========
|
||||
|
||||
|
||||
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
|
||||
async def test_send_email_success(mock_send, monkeypatch):
|
||||
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
|
||||
"smtp_host": "smtp.test.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_user": "user",
|
||||
"smtp_password": "pass",
|
||||
"smtp_from": "test@test.com",
|
||||
})())
|
||||
|
||||
from wiregui.services.email import send_email
|
||||
result = await send_email("to@test.com", "Subject", "Body")
|
||||
assert result is True
|
||||
mock_send.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_send_email_no_smtp_configured(monkeypatch):
|
||||
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
|
||||
"smtp_host": None,
|
||||
})())
|
||||
|
||||
from wiregui.services.email import send_email
|
||||
result = await send_email("to@test.com", "Subject", "Body")
|
||||
assert result is False
|
||||
|
||||
|
||||
@patch("wiregui.services.email.aiosmtplib.send", new_callable=AsyncMock)
|
||||
async def test_send_magic_link(mock_send, monkeypatch):
|
||||
monkeypatch.setattr("wiregui.services.email.get_settings", lambda: type("S", (), {
|
||||
"smtp_host": "smtp.test.com",
|
||||
"smtp_port": 587,
|
||||
"smtp_user": "u",
|
||||
"smtp_password": "p",
|
||||
"smtp_from": "noreply@test.com",
|
||||
})())
|
||||
|
||||
from wiregui.services.email import send_magic_link
|
||||
result = await send_magic_link("user@test.com", "https://app.test/magic/123/token")
|
||||
assert result is True
|
||||
mock_send.assert_awaited_once()
|
||||
231
tests/test_tasks.py
Normal file
231
tests/test_tasks.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""Tests for background tasks — VPN session expiry and connectivity checks."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# --- VPN session expiry ---
|
||||
|
||||
|
||||
async def test_vpn_session_expiry_removes_expired_peers(session, monkeypatch):
|
||||
"""Users whose session expired should have their WG peers removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
# Create config with 1-hour session duration
|
||||
config = Configuration(vpn_session_duration=3600)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
# Create a user who signed in 2 hours ago (expired)
|
||||
expired_user = User(
|
||||
email="expired@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
last_signed_in_at=utcnow() - timedelta(hours=2),
|
||||
)
|
||||
session.add(expired_user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="laptop", public_key="pk-expired", user_id=expired_user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
# Create a user who signed in 30 min ago (still valid)
|
||||
active_user = User(
|
||||
email="active@example.com",
|
||||
password_hash=hash_password("pw"),
|
||||
last_signed_in_at=utcnow() - timedelta(minutes=30),
|
||||
)
|
||||
session.add(active_user)
|
||||
await session.flush()
|
||||
|
||||
active_device = Device(name="phone", public_key="pk-active", user_id=active_user.id)
|
||||
session.add(active_device)
|
||||
await session.flush()
|
||||
|
||||
# Mock WireGuard
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
# Only expired user's peer should be removed
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-expired")
|
||||
|
||||
|
||||
async def test_vpn_session_no_expiry_when_duration_zero(session, monkeypatch):
|
||||
"""When vpn_session_duration is 0 (unlimited), no peers should be removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
config = Configuration(vpn_session_duration=0)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
email="unlimited@example.com",
|
||||
last_signed_in_at=utcnow() - timedelta(days=365),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_vpn_session_no_expiry_when_no_config(session, monkeypatch):
|
||||
"""When no Configuration exists, no peers should be removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
# No Configuration row at all
|
||||
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_vpn_session_skips_disabled_users(session, monkeypatch):
|
||||
"""Disabled users should be skipped even if their session is expired."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.vpn_session.async_session", mock_session)
|
||||
|
||||
config = Configuration(vpn_session_duration=3600)
|
||||
session.add(config)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
email="disabled-session@example.com",
|
||||
last_signed_in_at=utcnow() - timedelta(hours=2),
|
||||
disabled_at=utcnow(),
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="d", public_key="pk-disabled-session", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.tasks.vpn_session.wireguard") as mock_wg:
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.vpn_session import _expire_sessions
|
||||
await _expire_sessions()
|
||||
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
# --- Connectivity checks ---
|
||||
|
||||
|
||||
async def test_connectivity_check_success(session, monkeypatch):
|
||||
"""Successful connectivity check should store result in DB."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
|
||||
|
||||
# Mock httpx to return a successful response
|
||||
import httpx
|
||||
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
headers = {"content-type": "text/plain"}
|
||||
text = "203.0.113.1"
|
||||
|
||||
class MockAsyncClient:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def get(self, url):
|
||||
return MockResponse()
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
|
||||
|
||||
from wiregui.tasks.connectivity import _check_connectivity
|
||||
await _check_connectivity()
|
||||
|
||||
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
|
||||
assert result.response_code == 200
|
||||
assert result.response_body == "203.0.113.1"
|
||||
|
||||
|
||||
async def test_connectivity_check_failure(session, monkeypatch):
|
||||
"""Failed connectivity check should store error and create notification."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.async_session", mock_session)
|
||||
|
||||
class MockAsyncClient:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def get(self, url):
|
||||
raise ConnectionError("Network unreachable")
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.connectivity.httpx.AsyncClient", lambda **kw: MockAsyncClient())
|
||||
|
||||
from wiregui.services import notifications
|
||||
notifications.clear_all()
|
||||
|
||||
from wiregui.tasks.connectivity import _check_connectivity
|
||||
await _check_connectivity()
|
||||
|
||||
result = (await session.execute(select(ConnectivityCheck).limit(1))).scalar_one()
|
||||
assert result.response_code is None
|
||||
assert "Network unreachable" in result.response_body
|
||||
|
||||
assert notifications.count() > 0
|
||||
assert "connectivity" in notifications.current()[0].message.lower()
|
||||
229
tests/test_tasks_extended.py
Normal file
229
tests/test_tasks_extended.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Extended task tests — stats polling, reconciliation, OIDC refresh."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.services.wireguard import PeerInfo
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
# ========== Stats task ==========
|
||||
|
||||
|
||||
async def test_stats_update_from_wg_peers(session, monkeypatch):
|
||||
"""Stats task should update device records from WireGuard peer data."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
|
||||
|
||||
user = User(email="stats-user@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="stats-dev", public_key="pk-stats-test", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
mock_peers = [
|
||||
PeerInfo(
|
||||
public_key="pk-stats-test",
|
||||
endpoint="1.2.3.4:51820",
|
||||
rx_bytes=123456,
|
||||
tx_bytes=789012,
|
||||
latest_handshake=utcnow(),
|
||||
)
|
||||
]
|
||||
|
||||
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
|
||||
from wiregui.tasks.stats import _update_stats
|
||||
await _update_stats()
|
||||
|
||||
refreshed = await session.get(Device, device.id)
|
||||
assert refreshed.rx_bytes == 123456
|
||||
assert refreshed.tx_bytes == 789012
|
||||
assert refreshed.remote_ip == "1.2.3.4"
|
||||
assert refreshed.latest_handshake is not None
|
||||
|
||||
|
||||
async def test_stats_no_peers_is_noop(session, monkeypatch):
|
||||
"""No WG peers should result in no DB changes."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
|
||||
|
||||
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[])
|
||||
from wiregui.tasks.stats import _update_stats
|
||||
await _update_stats() # Should not raise
|
||||
|
||||
|
||||
async def test_stats_unmatched_peer_ignored(session, monkeypatch):
|
||||
"""Peers not matching any device should be ignored."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.stats.async_session", mock_session)
|
||||
|
||||
mock_peers = [
|
||||
PeerInfo(public_key="unknown-peer-key", rx_bytes=100, tx_bytes=200)
|
||||
]
|
||||
|
||||
with patch("wiregui.tasks.stats.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=mock_peers)
|
||||
from wiregui.tasks.stats import _update_stats
|
||||
await _update_stats() # Should not raise
|
||||
|
||||
|
||||
# ========== Reconciliation task ==========
|
||||
|
||||
|
||||
async def test_reconcile_adds_missing_peers(session, monkeypatch):
|
||||
"""Devices in DB but not in WG should be added."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
|
||||
|
||||
user = User(email="reconcile@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="missing", public_key="pk-missing", ipv4="10.0.0.5", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[]) # WG has no peers
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
await reconcile()
|
||||
|
||||
mock_wg.add_peer.assert_awaited_once()
|
||||
call_kwargs = mock_wg.add_peer.call_args[1]
|
||||
assert call_kwargs["public_key"] == "pk-missing"
|
||||
assert "10.0.0.5/32" in call_kwargs["allowed_ips"]
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_reconcile_removes_orphaned_peers(session, monkeypatch):
|
||||
"""Peers in WG but not in DB should be removed."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
|
||||
|
||||
# No devices in DB, but WG has a peer
|
||||
orphan = PeerInfo(public_key="pk-orphan", rx_bytes=0, tx_bytes=0)
|
||||
|
||||
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[orphan])
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
await reconcile()
|
||||
|
||||
mock_wg.remove_peer.assert_awaited_once_with(public_key="pk-orphan")
|
||||
mock_wg.add_peer.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_reconcile_in_sync(session, monkeypatch):
|
||||
"""When DB and WG match, nothing should happen."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.reconcile.async_session", mock_session)
|
||||
|
||||
user = User(email="in-sync@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
device = Device(name="synced", public_key="pk-synced", user_id=user.id)
|
||||
session.add(device)
|
||||
await session.flush()
|
||||
|
||||
peer = PeerInfo(public_key="pk-synced", rx_bytes=0, tx_bytes=0)
|
||||
|
||||
with patch("wiregui.tasks.reconcile.wireguard") as mock_wg:
|
||||
mock_wg.get_peers = AsyncMock(return_value=[peer])
|
||||
mock_wg.add_peer = AsyncMock()
|
||||
mock_wg.remove_peer = AsyncMock()
|
||||
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
await reconcile()
|
||||
|
||||
mock_wg.add_peer.assert_not_awaited()
|
||||
mock_wg.remove_peer.assert_not_awaited()
|
||||
|
||||
|
||||
# ========== OIDC refresh task ==========
|
||||
|
||||
|
||||
async def test_oidc_refresh_no_connections_is_noop(session, monkeypatch):
|
||||
"""No OIDC connections should result in no refresh attempts."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
|
||||
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[]))
|
||||
|
||||
from wiregui.tasks.oidc_refresh import _refresh_all
|
||||
await _refresh_all() # Should not raise
|
||||
|
||||
|
||||
async def test_oidc_refresh_skips_unknown_provider(session, monkeypatch):
|
||||
"""Connections for unknown providers should be skipped."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session():
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr("wiregui.tasks.oidc_refresh.async_session", mock_session)
|
||||
monkeypatch.setattr("wiregui.auth.oidc.load_providers", AsyncMock(return_value=[
|
||||
{"id": "known-provider", "client_id": "cid", "client_secret": "cs", "discovery_document_uri": "https://x"}
|
||||
]))
|
||||
|
||||
user = User(email="oidc-skip@test.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
conn = OIDCConnection(provider="unknown-provider", refresh_token="tok", user_id=user.id)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
from wiregui.tasks.oidc_refresh import _refresh_all
|
||||
await _refresh_all() # Should skip gracefully
|
||||
120
tests/test_utils.py
Normal file
120
tests/test_utils.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for utility modules."""
|
||||
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
|
||||
from wiregui.utils.wg_conf import build_client_config
|
||||
|
||||
|
||||
# --- IP allocation ---
|
||||
|
||||
|
||||
async def test_allocate_ipv4_first_device(session):
|
||||
user = User(email="net-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
ip = await allocate_ipv4(session, "10.3.2.0/24")
|
||||
assert ip.startswith("10.3.2.")
|
||||
# Should not be the network (.0) or gateway (.1)
|
||||
last_octet = int(ip.split(".")[-1])
|
||||
assert last_octet >= 2
|
||||
|
||||
|
||||
async def test_allocate_ipv4_skips_used(session):
|
||||
user = User(email="net-skip@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Exhaust a tiny /30 network (4 addresses: .0 network, .1 gateway, .2 usable, .3 broadcast)
|
||||
d1 = Device(name="d1", public_key="pk-net-1", ipv4="10.99.0.2", user_id=user.id)
|
||||
session.add(d1)
|
||||
await session.flush()
|
||||
|
||||
# Only .2 was usable in a /30 — allocation should fail
|
||||
with pytest.raises(ValueError, match="No available"):
|
||||
await allocate_ipv4(session, "10.99.0.0/30")
|
||||
|
||||
|
||||
async def test_allocate_ipv6(session):
|
||||
user = User(email="net6-test@example.com")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
ip = await allocate_ipv6(session, "fd00::3:2:0/120")
|
||||
assert ip.startswith("fd00::3:2:")
|
||||
|
||||
|
||||
# --- WireGuard config builder ---
|
||||
|
||||
|
||||
def test_build_client_config():
|
||||
device = Device(
|
||||
name="test-device",
|
||||
public_key="device-pub-key",
|
||||
preshared_key="device-psk",
|
||||
ipv4="10.3.2.5",
|
||||
ipv6="fd00::3:2:5",
|
||||
use_default_allowed_ips=True,
|
||||
use_default_dns=True,
|
||||
use_default_endpoint=True,
|
||||
use_default_mtu=True,
|
||||
use_default_persistent_keepalive=True,
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
|
||||
config = build_client_config(device, "PRIVATE_KEY_HERE", "SERVER_PUB_KEY")
|
||||
|
||||
assert "[Interface]" in config
|
||||
assert "PrivateKey = PRIVATE_KEY_HERE" in config
|
||||
assert "10.3.2.5/32" in config
|
||||
assert "fd00::3:2:5/128" in config
|
||||
assert "[Peer]" in config
|
||||
assert "PublicKey = SERVER_PUB_KEY" in config
|
||||
assert "PresharedKey = device-psk" in config
|
||||
assert "Endpoint = " in config
|
||||
|
||||
|
||||
def test_build_client_config_no_psk():
|
||||
device = Device(
|
||||
name="no-psk",
|
||||
public_key="pub",
|
||||
preshared_key=None,
|
||||
ipv4="10.3.2.6",
|
||||
ipv6=None,
|
||||
use_default_allowed_ips=True,
|
||||
use_default_dns=True,
|
||||
use_default_endpoint=True,
|
||||
use_default_mtu=True,
|
||||
use_default_persistent_keepalive=True,
|
||||
user_id="00000000-0000-0000-0000-000000000000",
|
||||
)
|
||||
|
||||
config = build_client_config(device, "PRIV", "SERVPUB")
|
||||
assert "PresharedKey" not in config
|
||||
assert "fd00::" not in config # no ipv6
|
||||
|
||||
|
||||
# --- Crypto (only if wg is installed) ---
|
||||
|
||||
|
||||
def test_generate_keypair():
|
||||
"""Test keypair generation — requires `wg` CLI to be installed."""
|
||||
try:
|
||||
subprocess.run(["wg", "--version"], capture_output=True, check=True)
|
||||
except FileNotFoundError:
|
||||
pytest.skip("wg CLI not installed")
|
||||
|
||||
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||
|
||||
priv, pub = generate_keypair()
|
||||
assert len(priv) == 44 # base64-encoded 32 bytes
|
||||
assert len(pub) == 44
|
||||
|
||||
psk = generate_preshared_key()
|
||||
assert len(psk) == 44
|
||||
0
wiregui/__init__.py
Normal file
0
wiregui/__init__.py
Normal file
0
wiregui/api/__init__.py
Normal file
0
wiregui/api/__init__.py
Normal file
38
wiregui/api/deps.py
Normal file
38
wiregui/api/deps.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Shared FastAPI dependencies for the REST API."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from wiregui.auth.api_token import resolve_bearer_token
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession]:
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_current_api_user(
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""Extract Bearer token from Authorization header and resolve the user."""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
|
||||
|
||||
token = auth[7:]
|
||||
user = await resolve_bearer_token(session, token)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired API token")
|
||||
return user
|
||||
|
||||
|
||||
async def require_admin(user: User = Depends(get_current_api_user)) -> User:
|
||||
"""Require the authenticated user to be an admin."""
|
||||
if user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
return user
|
||||
11
wiregui/api/v0/__init__.py
Normal file
11
wiregui/api/v0/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""v0 API router — aggregates all sub-routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from wiregui.api.v0 import configuration, devices, rules, users
|
||||
|
||||
router = APIRouter(prefix="/v0")
|
||||
router.include_router(users.router)
|
||||
router.include_router(devices.router)
|
||||
router.include_router(rules.router)
|
||||
router.include_router(configuration.router)
|
||||
46
wiregui/api/v0/configuration.py
Normal file
46
wiregui/api/v0/configuration.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_db, require_admin
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.user import User
|
||||
from wiregui.schemas.configuration import ConfigurationRead, ConfigurationUpdate
|
||||
|
||||
router = APIRouter(prefix="/configuration", tags=["configuration"])
|
||||
|
||||
|
||||
async def _get_config(session: AsyncSession) -> Configuration:
|
||||
result = await session.execute(select(Configuration).limit(1))
|
||||
config = result.scalar_one_or_none()
|
||||
if not config:
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
|
||||
@router.get("/", response_model=ConfigurationRead)
|
||||
async def get_configuration(
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
return await _get_config(session)
|
||||
|
||||
|
||||
@router.put("/", response_model=ConfigurationRead)
|
||||
async def update_configuration(
|
||||
body: ConfigurationUpdate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
config = await _get_config(session)
|
||||
|
||||
for key, val in body.model_dump(exclude_unset=True).items():
|
||||
setattr(config, key, val)
|
||||
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
119
wiregui/api/v0/devices.py
Normal file
119
wiregui/api/v0/devices.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_current_api_user, get_db
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.schemas.device import DeviceCreate, DeviceRead, DeviceUpdate
|
||||
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated
|
||||
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
|
||||
|
||||
router = APIRouter(prefix="/devices", tags=["devices"])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[DeviceRead])
|
||||
async def list_devices(
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_api_user),
|
||||
):
|
||||
if current_user.role == "admin":
|
||||
result = await session.execute(select(Device).order_by(Device.inserted_at.desc()))
|
||||
else:
|
||||
result = await session.execute(
|
||||
select(Device).where(Device.user_id == current_user.id).order_by(Device.inserted_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{device_id}", response_model=DeviceRead)
|
||||
async def get_device(
|
||||
device_id: UUID,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_api_user),
|
||||
):
|
||||
device = await session.get(Device, device_id)
|
||||
if not device:
|
||||
raise HTTPException(404, "Device not found")
|
||||
if current_user.role != "admin" and device.user_id != current_user.id:
|
||||
raise HTTPException(403, "Access denied")
|
||||
return device
|
||||
|
||||
|
||||
@router.post("/", response_model=DeviceRead, status_code=201)
|
||||
async def create_device(
|
||||
body: DeviceCreate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_api_user),
|
||||
):
|
||||
settings = get_settings()
|
||||
owner_id = body.user_id if (body.user_id and current_user.role == "admin") else current_user.id
|
||||
|
||||
_private_key, public_key = generate_keypair()
|
||||
psk = generate_preshared_key()
|
||||
ipv4 = await allocate_ipv4(session, settings.wg_ipv4_network)
|
||||
ipv6 = await allocate_ipv6(session, settings.wg_ipv6_network)
|
||||
|
||||
device = Device(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
public_key=public_key,
|
||||
preshared_key=psk,
|
||||
ipv4=ipv4,
|
||||
ipv6=ipv6,
|
||||
user_id=owner_id,
|
||||
)
|
||||
session.add(device)
|
||||
await session.commit()
|
||||
await session.refresh(device)
|
||||
|
||||
logger.info("API: device created {} ({})", device.name, device.ipv4)
|
||||
await on_device_created(device)
|
||||
return device
|
||||
|
||||
|
||||
@router.put("/{device_id}", response_model=DeviceRead)
|
||||
async def update_device(
|
||||
device_id: UUID,
|
||||
body: DeviceUpdate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_api_user),
|
||||
):
|
||||
device = await session.get(Device, device_id)
|
||||
if not device:
|
||||
raise HTTPException(404, "Device not found")
|
||||
if current_user.role != "admin" and device.user_id != current_user.id:
|
||||
raise HTTPException(403, "Access denied")
|
||||
|
||||
for key, val in body.model_dump(exclude_unset=True).items():
|
||||
setattr(device, key, val)
|
||||
|
||||
session.add(device)
|
||||
await session.commit()
|
||||
await session.refresh(device)
|
||||
|
||||
await on_device_updated(device)
|
||||
return device
|
||||
|
||||
|
||||
@router.delete("/{device_id}", status_code=204)
|
||||
async def delete_device(
|
||||
device_id: UUID,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_api_user),
|
||||
):
|
||||
device = await session.get(Device, device_id)
|
||||
if not device:
|
||||
raise HTTPException(404, "Device not found")
|
||||
if current_user.role != "admin" and device.user_id != current_user.id:
|
||||
raise HTTPException(403, "Access denied")
|
||||
|
||||
await session.delete(device)
|
||||
await session.commit()
|
||||
logger.info("API: device deleted {}", device.name)
|
||||
await on_device_deleted(device)
|
||||
86
wiregui/api/v0/rules.py
Normal file
86
wiregui/api/v0/rules.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_db, require_admin
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from wiregui.schemas.rule import RuleCreate, RuleRead, RuleUpdate
|
||||
from wiregui.services.events import on_rule_created, on_rule_deleted
|
||||
|
||||
router = APIRouter(prefix="/rules", tags=["rules"])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[RuleRead])
|
||||
async def list_rules(
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
result = await session.execute(select(Rule).order_by(Rule.inserted_at.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{rule_id}", response_model=RuleRead)
|
||||
async def get_rule(
|
||||
rule_id: UUID,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
rule = await session.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(404, "Rule not found")
|
||||
return rule
|
||||
|
||||
|
||||
@router.post("/", response_model=RuleRead, status_code=201)
|
||||
async def create_rule(
|
||||
body: RuleCreate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
rule = Rule(**body.model_dump())
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
await session.refresh(rule)
|
||||
|
||||
logger.info("API: rule created {} -> {}", rule.action, rule.destination)
|
||||
await on_rule_created(rule)
|
||||
return rule
|
||||
|
||||
|
||||
@router.put("/{rule_id}", response_model=RuleRead)
|
||||
async def update_rule(
|
||||
rule_id: UUID,
|
||||
body: RuleUpdate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
rule = await session.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(404, "Rule not found")
|
||||
|
||||
for key, val in body.model_dump(exclude_unset=True).items():
|
||||
setattr(rule, key, val)
|
||||
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
await session.refresh(rule)
|
||||
return rule
|
||||
|
||||
|
||||
@router.delete("/{rule_id}", status_code=204)
|
||||
async def delete_rule(
|
||||
rule_id: UUID,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
rule = await session.get(Rule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(404, "Rule not found")
|
||||
await session.delete(rule)
|
||||
await session.commit()
|
||||
logger.info("API: rule deleted {} {}", rule.action, rule.destination)
|
||||
await on_rule_deleted(rule)
|
||||
86
wiregui/api/v0/users.py
Normal file
86
wiregui/api/v0/users.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.api.deps import get_db, require_admin
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.models.user import User
|
||||
from wiregui.schemas.user import UserCreate, UserRead, UserUpdate
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[UserRead])
|
||||
async def list_users(
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
result = await session.execute(select(User).order_by(User.email))
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserRead)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, "User not found")
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/", response_model=UserRead, status_code=201)
|
||||
async def create_user(
|
||||
body: UserCreate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
user = User(
|
||||
email=body.email,
|
||||
password_hash=hash_password(body.password),
|
||||
role=body.role,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=UserRead)
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
body: UserUpdate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, "User not found")
|
||||
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
if "password" in updates:
|
||||
updates["password_hash"] = hash_password(updates.pop("password"))
|
||||
for key, val in updates.items():
|
||||
setattr(user, key, val)
|
||||
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=204)
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
_admin: User = Depends(require_admin),
|
||||
):
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, "User not found")
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
0
wiregui/auth/__init__.py
Normal file
0
wiregui/auth/__init__.py
Normal file
42
wiregui/auth/api_token.py
Normal file
42
wiregui/auth/api_token.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""API token authentication — Bearer token via Authorization header."""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
def generate_api_token() -> tuple[str, str]:
|
||||
"""Generate a new API token. Returns (plaintext_token, token_hash)."""
|
||||
plaintext = secrets.token_urlsafe(32)
|
||||
token_hash = hashlib.sha256(plaintext.encode()).hexdigest()
|
||||
return plaintext, token_hash
|
||||
|
||||
|
||||
async def resolve_bearer_token(session: AsyncSession, token: str) -> User | None:
|
||||
"""Look up a Bearer token and return the associated user, or None."""
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
result = await session.execute(
|
||||
select(ApiToken).where(ApiToken.token_hash == token_hash)
|
||||
)
|
||||
api_token = result.scalar_one_or_none()
|
||||
if api_token is None:
|
||||
return None
|
||||
|
||||
# Check expiry
|
||||
if api_token.expires_at and api_token.expires_at < utcnow():
|
||||
logger.debug("API token expired for user_id={}", api_token.user_id)
|
||||
return None
|
||||
|
||||
# Resolve user
|
||||
user = await session.get(User, api_token.user_id)
|
||||
if user is None or user.disabled_at is not None:
|
||||
return None
|
||||
|
||||
return user
|
||||
26
wiregui/auth/jwt.py
Normal file
26
wiregui/auth/jwt.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
DEFAULT_EXPIRE_HOURS = 8
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
role: str,
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(hours=DEFAULT_EXPIRE_HOURS))
|
||||
payload = {"sub": user_id, "role": role, "exp": expire}
|
||||
return jwt.encode(payload, get_settings().secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict | None:
|
||||
"""Decode and validate a JWT. Returns the payload dict or None if invalid/expired."""
|
||||
try:
|
||||
return jwt.decode(token, get_settings().secret_key, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
return None
|
||||
31
wiregui/auth/mfa.py
Normal file
31
wiregui/auth/mfa.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""TOTP Multi-Factor Authentication using pyotp."""
|
||||
|
||||
import io
|
||||
from urllib.parse import quote
|
||||
|
||||
import pyotp
|
||||
import qrcode
|
||||
import qrcode.image.svg
|
||||
|
||||
|
||||
def generate_totp_secret() -> str:
|
||||
"""Generate a new random TOTP secret."""
|
||||
return pyotp.random_base32()
|
||||
|
||||
|
||||
def get_totp_uri(secret: str, email: str, issuer: str = "WireGUI") -> str:
|
||||
"""Build an otpauth:// URI for QR code scanning."""
|
||||
return pyotp.TOTP(secret).provisioning_uri(name=email, issuer_name=issuer)
|
||||
|
||||
|
||||
def verify_totp_code(secret: str, code: str) -> bool:
|
||||
"""Verify a TOTP code against a secret. Allows 1 window of clock drift."""
|
||||
return pyotp.TOTP(secret).verify(code, valid_window=1)
|
||||
|
||||
|
||||
def generate_totp_qr_svg(uri: str) -> str:
|
||||
"""Generate an SVG QR code for a TOTP provisioning URI."""
|
||||
qr = qrcode.make(uri, image_factory=qrcode.image.svg.SvgPathImage)
|
||||
buf = io.BytesIO()
|
||||
qr.save(buf)
|
||||
return buf.getvalue().decode()
|
||||
20
wiregui/auth/middleware.py
Normal file
20
wiregui/auth/middleware.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""NiceGUI auth middleware — redirects unauthenticated requests to /login."""
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# Paths that don't require authentication
|
||||
PUBLIC_PREFIXES = ("/login", "/_nicegui", "/api")
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if any(request.url.path.startswith(p) for p in PUBLIC_PREFIXES):
|
||||
return await call_next(request)
|
||||
|
||||
# NiceGUI stores auth state in the Starlette session (cookie-backed)
|
||||
if not request.session.get("authenticated"):
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
return await call_next(request)
|
||||
59
wiregui/auth/oidc.py
Normal file
59
wiregui/auth/oidc.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""OIDC authentication via authlib — provider registry and authorization code flow."""
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
|
||||
# Global OAuth instance — providers are registered dynamically
|
||||
oauth = OAuth()
|
||||
|
||||
|
||||
async def load_providers() -> list[dict]:
|
||||
"""Load OIDC provider configs from the Configuration singleton."""
|
||||
from sqlmodel import select
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Configuration).limit(1))
|
||||
config = result.scalar_one_or_none()
|
||||
if not config:
|
||||
return []
|
||||
return config.openid_connect_providers or []
|
||||
|
||||
|
||||
async def register_providers() -> None:
|
||||
"""Register all configured OIDC providers with authlib. Call on startup."""
|
||||
providers = await load_providers()
|
||||
for p in providers:
|
||||
provider_id = p.get("id")
|
||||
if not provider_id:
|
||||
continue
|
||||
try:
|
||||
oauth.register(
|
||||
name=provider_id,
|
||||
client_id=p.get("client_id"),
|
||||
client_secret=p.get("client_secret"),
|
||||
server_metadata_url=p.get("discovery_document_uri"),
|
||||
client_kwargs={"scope": p.get("scope", "openid email profile")},
|
||||
)
|
||||
logger.info("OIDC provider registered: {}", provider_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to register OIDC provider {}: {}", provider_id, e)
|
||||
|
||||
|
||||
def get_client(provider_id: str):
|
||||
"""Get an authlib OAuth client for a registered provider."""
|
||||
client = oauth.create_client(provider_id)
|
||||
if client is None:
|
||||
raise ValueError(f"OIDC provider '{provider_id}' is not registered")
|
||||
return client
|
||||
|
||||
|
||||
async def get_provider_config(provider_id: str) -> dict | None:
|
||||
"""Get the config dict for a specific provider."""
|
||||
providers = await load_providers()
|
||||
for p in providers:
|
||||
if p.get("id") == provider_id:
|
||||
return p
|
||||
return None
|
||||
9
wiregui/auth/passwords.py
Normal file
9
wiregui/auth/passwords.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(plain: str) -> str:
|
||||
return bcrypt.hashpw(plain.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(plain.encode(), hashed.encode())
|
||||
114
wiregui/auth/saml.py
Normal file
114
wiregui/auth/saml.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""SAML SP-initiated SSO via python3-saml."""
|
||||
|
||||
from loguru import logger
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth
|
||||
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
|
||||
def _build_saml_settings(provider_config: dict) -> dict:
|
||||
"""Build python3-saml settings dict from our provider config."""
|
||||
settings = get_settings()
|
||||
base_url = settings.external_url
|
||||
|
||||
# Parse IdP metadata XML to extract endpoints and certs
|
||||
idp_data = OneLogin_Saml2_IdPMetadataParser.parse(provider_config.get("metadata", ""))
|
||||
idp_settings = idp_data.get("idp", {})
|
||||
|
||||
return {
|
||||
"strict": True,
|
||||
"debug": False,
|
||||
"sp": {
|
||||
"entityId": f"{base_url}/auth/saml/{provider_config['id']}/metadata",
|
||||
"assertionConsumerService": {
|
||||
"url": f"{base_url}/auth/saml/{provider_config['id']}/callback",
|
||||
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
|
||||
},
|
||||
"NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
|
||||
},
|
||||
"idp": idp_settings,
|
||||
"security": {
|
||||
"authnRequestsSigned": provider_config.get("sign_requests", True),
|
||||
"wantAssertionsSigned": provider_config.get("signed_assertion_in_resp", True),
|
||||
"signMetadata": provider_config.get("sign_metadata", True),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def prepare_saml_request(request_data: dict) -> dict:
|
||||
"""Prepare a dict that python3-saml expects from an HTTP request.
|
||||
|
||||
Args:
|
||||
request_data: dict with keys: http_host, script_name, server_port,
|
||||
get_data (dict), post_data (dict), https (str "on"/"off")
|
||||
"""
|
||||
return {
|
||||
"http_host": request_data.get("http_host", "localhost"),
|
||||
"script_name": request_data.get("script_name", ""),
|
||||
"server_port": request_data.get("server_port", 443),
|
||||
"get_data": request_data.get("get_data", {}),
|
||||
"post_data": request_data.get("post_data", {}),
|
||||
"https": request_data.get("https", "on"),
|
||||
}
|
||||
|
||||
|
||||
def create_saml_auth(provider_config: dict, request_data: dict) -> OneLogin_Saml2_Auth:
|
||||
"""Create a python3-saml Auth instance for a provider."""
|
||||
saml_settings = _build_saml_settings(provider_config)
|
||||
req = prepare_saml_request(request_data)
|
||||
return OneLogin_Saml2_Auth(req, saml_settings)
|
||||
|
||||
|
||||
def get_login_url(auth: OneLogin_Saml2_Auth) -> str:
|
||||
"""Get the SSO redirect URL."""
|
||||
return auth.login()
|
||||
|
||||
|
||||
def process_response(auth: OneLogin_Saml2_Auth) -> dict | None:
|
||||
"""Process the SAML response and return user attributes.
|
||||
|
||||
Returns dict with 'email' key, or None on failure.
|
||||
"""
|
||||
auth.process_response()
|
||||
errors = auth.get_errors()
|
||||
if errors:
|
||||
logger.error("SAML response errors: {}", errors)
|
||||
return None
|
||||
|
||||
if not auth.is_authenticated():
|
||||
logger.warning("SAML: user not authenticated")
|
||||
return None
|
||||
|
||||
attrs = auth.get_attributes()
|
||||
name_id = auth.get_nameid()
|
||||
|
||||
# Try to extract email from various attribute names
|
||||
email = (
|
||||
attrs.get("email", [None])[0]
|
||||
or attrs.get("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", [None])[0]
|
||||
or attrs.get("urn:oid:0.9.2342.19200300.100.1.3", [None])[0]
|
||||
or name_id
|
||||
)
|
||||
|
||||
if not email:
|
||||
logger.error("SAML: no email found in attributes or NameID")
|
||||
return None
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"name_id": name_id,
|
||||
"attributes": {k: v for k, v in attrs.items()},
|
||||
}
|
||||
|
||||
|
||||
def get_metadata(provider_config: dict) -> str:
|
||||
"""Generate SP metadata XML."""
|
||||
settings = _build_saml_settings(provider_config)
|
||||
from onelogin.saml2.settings import OneLogin_Saml2_Settings
|
||||
saml_settings = OneLogin_Saml2_Settings(settings, sp_validation_only=True)
|
||||
metadata = saml_settings.get_sp_metadata()
|
||||
errors = saml_settings.validate_metadata(metadata)
|
||||
if errors:
|
||||
logger.error("SP metadata validation errors: {}", errors)
|
||||
return metadata.decode() if isinstance(metadata, bytes) else metadata
|
||||
61
wiregui/auth/seed.py
Normal file
61
wiregui/auth/seed.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""Seed the initial admin user and server keypair on first startup."""
|
||||
|
||||
import secrets
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def seed_admin() -> None:
|
||||
"""Create admin user if no users exist in the database."""
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(User).limit(1))
|
||||
if result.scalar_one_or_none() is not None:
|
||||
return # users already exist
|
||||
|
||||
settings = get_settings()
|
||||
password = settings.admin_password or secrets.token_urlsafe(16)
|
||||
|
||||
admin = User(
|
||||
email=settings.admin_email,
|
||||
password_hash=hash_password(password),
|
||||
role="admin",
|
||||
)
|
||||
session.add(admin)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Admin user created: {}", settings.admin_email)
|
||||
if settings.admin_password is None:
|
||||
logger.warning("Generated admin password: {}", password)
|
||||
|
||||
|
||||
async def ensure_server_keypair() -> None:
|
||||
"""Generate and store the server WireGuard keypair in Configuration if missing."""
|
||||
from wiregui.utils.crypto import generate_keypair
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Configuration).limit(1))
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config is None:
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
|
||||
if config.server_public_key and config.server_private_key:
|
||||
return # already have keys
|
||||
|
||||
try:
|
||||
private_key, public_key = generate_keypair()
|
||||
config.server_private_key = private_key
|
||||
config.server_public_key = public_key
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
logger.info("Server WireGuard keypair generated (pubkey: {}...)", public_key[:20])
|
||||
except Exception as e:
|
||||
logger.warning("Could not generate server keypair (wg CLI not available?): {}", e)
|
||||
22
wiregui/auth/session.py
Normal file
22
wiregui/auth/session.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Authentication helpers for NiceGUI pages."""
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.user import User
|
||||
|
||||
|
||||
async def authenticate_user(email: str, password: str) -> User | None:
|
||||
"""Verify email/password and return the User if valid, else None."""
|
||||
from wiregui.auth.passwords import verify_password
|
||||
|
||||
async with async_session() as session:
|
||||
stmt = select(User).where(User.email == email)
|
||||
user = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if user is None or user.password_hash is None:
|
||||
return None
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
if user.disabled_at is not None:
|
||||
return None
|
||||
return user
|
||||
134
wiregui/auth/webauthn.py
Normal file
134
wiregui/auth/webauthn.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""WebAuthn (FIDO2) MFA via the webauthn library.
|
||||
|
||||
Registration and authentication ceremonies for platform (native) and
|
||||
cross-platform (portable/security key) authenticators.
|
||||
"""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from webauthn import (
|
||||
generate_authentication_options,
|
||||
generate_registration_options,
|
||||
verify_authentication_response,
|
||||
verify_registration_response,
|
||||
)
|
||||
from webauthn.helpers import bytes_to_base64url, base64url_to_bytes
|
||||
from webauthn.helpers.structs import (
|
||||
AuthenticatorSelectionCriteria,
|
||||
PublicKeyCredentialDescriptor,
|
||||
ResidentKeyRequirement,
|
||||
UserVerificationRequirement,
|
||||
)
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
|
||||
def _rp_id() -> str:
|
||||
"""Get the Relying Party ID from the external URL hostname."""
|
||||
from urllib.parse import urlparse
|
||||
return urlparse(get_settings().external_url).hostname
|
||||
|
||||
|
||||
def _rp_name() -> str:
|
||||
return "WireGUI"
|
||||
|
||||
|
||||
def _origin() -> str:
|
||||
return get_settings().external_url
|
||||
|
||||
|
||||
def create_registration_options(user_id: UUID, user_email: str, existing_credentials: list[dict] = None) -> dict:
|
||||
"""Generate WebAuthn registration options to send to the browser.
|
||||
|
||||
Returns a dict with 'options' (serialized for JSON) and 'challenge' (to store in session).
|
||||
"""
|
||||
exclude = []
|
||||
for cred in (existing_credentials or []):
|
||||
cred_id = cred.get("credential_id")
|
||||
if cred_id:
|
||||
exclude.append(PublicKeyCredentialDescriptor(id=base64url_to_bytes(cred_id)))
|
||||
|
||||
options = generate_registration_options(
|
||||
rp_id=_rp_id(),
|
||||
rp_name=_rp_name(),
|
||||
user_id=str(user_id).encode(),
|
||||
user_name=user_email,
|
||||
user_display_name=user_email,
|
||||
exclude_credentials=exclude,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
resident_key=ResidentKeyRequirement.PREFERRED,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
),
|
||||
)
|
||||
|
||||
# Serialize for JSON transport
|
||||
from webauthn.helpers import options_to_json
|
||||
return {
|
||||
"options_json": options_to_json(options),
|
||||
"challenge": bytes_to_base64url(options.challenge),
|
||||
}
|
||||
|
||||
|
||||
def verify_registration(credential_json: str, challenge: str) -> dict:
|
||||
"""Verify a WebAuthn registration response from the browser.
|
||||
|
||||
Returns a dict with credential data to store in MFAMethod.payload.
|
||||
"""
|
||||
verification = verify_registration_response(
|
||||
credential=credential_json,
|
||||
expected_challenge=base64url_to_bytes(challenge),
|
||||
expected_rp_id=_rp_id(),
|
||||
expected_origin=_origin(),
|
||||
)
|
||||
|
||||
return {
|
||||
"credential_id": bytes_to_base64url(verification.credential_id),
|
||||
"public_key": bytes_to_base64url(verification.credential_public_key),
|
||||
"sign_count": verification.sign_count,
|
||||
"attestation_type": verification.fmt if hasattr(verification, 'fmt') else "none",
|
||||
}
|
||||
|
||||
|
||||
def create_authentication_options(credentials: list[dict]) -> dict:
|
||||
"""Generate WebAuthn authentication options for existing credentials.
|
||||
|
||||
Returns a dict with 'options' (serialized) and 'challenge' (to store in session).
|
||||
"""
|
||||
allow = []
|
||||
for cred in credentials:
|
||||
cred_id = cred.get("credential_id")
|
||||
if cred_id:
|
||||
allow.append(PublicKeyCredentialDescriptor(id=base64url_to_bytes(cred_id)))
|
||||
|
||||
options = generate_authentication_options(
|
||||
rp_id=_rp_id(),
|
||||
allow_credentials=allow,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
)
|
||||
|
||||
from webauthn.helpers import options_to_json
|
||||
return {
|
||||
"options_json": options_to_json(options),
|
||||
"challenge": bytes_to_base64url(options.challenge),
|
||||
}
|
||||
|
||||
|
||||
def verify_authentication(credential_json: str, challenge: str, stored_credential: dict) -> dict:
|
||||
"""Verify a WebAuthn authentication response.
|
||||
|
||||
Returns updated credential data (new sign_count).
|
||||
"""
|
||||
verification = verify_authentication_response(
|
||||
credential=credential_json,
|
||||
expected_challenge=base64url_to_bytes(challenge),
|
||||
expected_rp_id=_rp_id(),
|
||||
expected_origin=_origin(),
|
||||
credential_public_key=base64url_to_bytes(stored_credential["public_key"]),
|
||||
credential_current_sign_count=stored_credential.get("sign_count", 0),
|
||||
)
|
||||
|
||||
return {
|
||||
"new_sign_count": verification.new_sign_count,
|
||||
}
|
||||
55
wiregui/config.py
Normal file
55
wiregui/config.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from functools import lru_cache
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="WG_", env_file=".env")
|
||||
|
||||
# Database
|
||||
database_url: str = "postgresql+asyncpg://wiregui:wiregui@localhost/wiregui"
|
||||
|
||||
# Redis / Valkey
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
# Secret key for JWT signing and Fernet encryption
|
||||
secret_key: str = "change-me-in-production"
|
||||
|
||||
# WireGuard
|
||||
wg_enabled: bool = False # set True in production (requires NET_ADMIN capability)
|
||||
wg_interface: str = "wg0"
|
||||
wg_endpoint_host: str = "localhost"
|
||||
wg_endpoint_port: int = 51820
|
||||
wg_ipv4_network: str = "10.3.2.0/24"
|
||||
wg_ipv6_network: str = "fd00::3:2:0/120"
|
||||
wg_dns: str = "1.1.1.1, 1.0.0.1"
|
||||
wg_mtu: int = 1280
|
||||
wg_persistent_keepalive: int = 25
|
||||
wg_allowed_ips: str = "0.0.0.0/0, ::/0"
|
||||
|
||||
# Auth
|
||||
admin_email: str = "admin@localhost"
|
||||
admin_password: str | None = None
|
||||
local_auth_enabled: bool = True
|
||||
magic_link_enabled: bool = True
|
||||
vpn_session_duration: int = 0 # seconds, 0 = unlimited
|
||||
|
||||
# SMTP
|
||||
smtp_host: str | None = None
|
||||
smtp_port: int = 587
|
||||
smtp_user: str | None = None
|
||||
smtp_password: str | None = None
|
||||
smtp_from: str = "wiregui@localhost"
|
||||
|
||||
# Logging
|
||||
log_to_file: bool = True # write timestamped log file to logs/ directory
|
||||
|
||||
# App
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 13000
|
||||
external_url: str = "http://localhost:13000"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
22
wiregui/db.py
Normal file
22
wiregui/db.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
engine = create_async_engine(get_settings().database_url)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession]:
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Test database connectivity."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.info("Database connection OK")
|
||||
28
wiregui/logging.py
Normal file
28
wiregui/logging.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Loguru configuration for WireGUI."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def setup_logging(log_to_file: bool = False) -> None:
|
||||
"""Configure loguru sinks. Call once at startup."""
|
||||
# Remove default stderr sink and re-add with our format
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
format="<green>{time:HH:mm:ss}</green> | <level>{level:<7}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan> - <level>{message}</level>",
|
||||
level="DEBUG",
|
||||
)
|
||||
|
||||
if log_to_file:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
logger.add(
|
||||
f"logs/wiregui_{timestamp}.log",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level:<7} | {name}:{function}:{line} - {message}",
|
||||
level="DEBUG",
|
||||
rotation="10 MB",
|
||||
retention="7 days",
|
||||
)
|
||||
logger.info("File logging enabled: logs/wiregui_{}.log", timestamp)
|
||||
95
wiregui/main.py
Normal file
95
wiregui/main.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
|
||||
from wiregui.api.v0 import router as api_router
|
||||
from wiregui.auth.seed import ensure_server_keypair, seed_admin
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import init_db
|
||||
from wiregui.logging import setup_logging
|
||||
|
||||
# Mount REST API
|
||||
app.include_router(api_router, prefix="/api")
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# Import pages so their @ui.page decorators register routes
|
||||
import wiregui.pages.account # noqa: F401
|
||||
import wiregui.pages.admin.devices # noqa: F401
|
||||
import wiregui.pages.admin.diagnostics # noqa: F401
|
||||
import wiregui.pages.admin.rules # noqa: F401
|
||||
import wiregui.pages.admin.settings # noqa: F401
|
||||
import wiregui.pages.admin.users # noqa: F401
|
||||
import wiregui.pages.auth_magic # noqa: F401
|
||||
import wiregui.pages.auth_oidc # noqa: F401
|
||||
import wiregui.pages.auth_saml # noqa: F401
|
||||
import wiregui.pages.devices # noqa: F401
|
||||
import wiregui.pages.home # noqa: F401
|
||||
import wiregui.pages.login # noqa: F401
|
||||
import wiregui.pages.mfa_challenge # noqa: F401
|
||||
|
||||
|
||||
async def startup() -> None:
|
||||
settings = get_settings()
|
||||
setup_logging(log_to_file=settings.log_to_file)
|
||||
await init_db()
|
||||
await seed_admin()
|
||||
await ensure_server_keypair()
|
||||
|
||||
# Register OIDC providers from config
|
||||
from wiregui.auth.oidc import register_providers
|
||||
await register_providers()
|
||||
|
||||
from wiregui.tasks import register_task
|
||||
from wiregui.tasks.oidc_refresh import oidc_refresh_loop
|
||||
|
||||
from wiregui.tasks.connectivity import connectivity_loop
|
||||
from wiregui.tasks.vpn_session import vpn_session_loop
|
||||
|
||||
# Always run these tasks (even without WG for OIDC refresh and connectivity)
|
||||
register_task(oidc_refresh_loop(), name="oidc-refresh")
|
||||
register_task(connectivity_loop(), name="connectivity-check")
|
||||
|
||||
if settings.wg_enabled:
|
||||
from wiregui.services.firewall import setup_base_tables, setup_masquerade
|
||||
from wiregui.services.wireguard import configure_interface, ensure_interface
|
||||
from wiregui.tasks.reconcile import reconcile
|
||||
from wiregui.tasks.stats import stats_loop
|
||||
|
||||
await ensure_interface()
|
||||
await configure_interface()
|
||||
await setup_base_tables()
|
||||
await setup_masquerade()
|
||||
await reconcile()
|
||||
register_task(stats_loop(), name="wg-stats")
|
||||
register_task(vpn_session_loop(), name="vpn-session-expiry")
|
||||
else:
|
||||
logger.info("WireGuard disabled (WG_WG_ENABLED=false) — running in UI-only mode")
|
||||
|
||||
logger.info("WireGUI ready")
|
||||
|
||||
|
||||
async def shutdown() -> None:
|
||||
from wiregui.tasks import cancel_all
|
||||
await cancel_all()
|
||||
|
||||
|
||||
app.on_startup(startup)
|
||||
app.on_shutdown(shutdown)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
settings = get_settings()
|
||||
ui.run(
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
title="WireGUI",
|
||||
storage_secret=settings.secret_key,
|
||||
reload=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ in {"__main__", "__mp_main__"}:
|
||||
main()
|
||||
21
wiregui/models/__init__.py
Normal file
21
wiregui/models/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""All SQLModel table models — imported here so Alembic autogenerate can discover them."""
|
||||
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
|
||||
__all__ = [
|
||||
"ApiToken",
|
||||
"Configuration",
|
||||
"ConnectivityCheck",
|
||||
"Device",
|
||||
"MFAMethod",
|
||||
"OIDCConnection",
|
||||
"Rule",
|
||||
"User",
|
||||
]
|
||||
24
wiregui/models/api_token.py
Normal file
24
wiregui/models/api_token.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
|
||||
class ApiToken(SQLModel, table=True):
|
||||
__tablename__ = "api_tokens"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
token_hash: str = Field(unique=True, index=True)
|
||||
expires_at: datetime | None = None
|
||||
|
||||
user_id: UUID = Field(foreign_key="users.id", index=True)
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
# Relationships
|
||||
user: "User" = Relationship(back_populates="api_tokens")
|
||||
|
||||
|
||||
from wiregui.models.user import User # noqa: E402, F401
|
||||
61
wiregui/models/configuration.py
Normal file
61
wiregui/models/configuration.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, JSON, Column, SQLModel
|
||||
|
||||
|
||||
class Configuration(SQLModel, table=True):
|
||||
__tablename__ = "configurations"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
|
||||
# Device management permissions
|
||||
allow_unprivileged_device_management: bool = Field(default=True)
|
||||
allow_unprivileged_device_configuration: bool = Field(default=True)
|
||||
|
||||
# Auth
|
||||
local_auth_enabled: bool = Field(default=True)
|
||||
disable_vpn_on_oidc_error: bool = Field(default=False)
|
||||
|
||||
# Client defaults
|
||||
default_client_persistent_keepalive: int = Field(default=25)
|
||||
default_client_mtu: int = Field(default=1280)
|
||||
default_client_endpoint: str | None = None
|
||||
default_client_dns: list[str] = Field(
|
||||
default_factory=lambda: ["1.1.1.1", "1.0.0.1"],
|
||||
sa_column=Column(JSON, default=["1.1.1.1", "1.0.0.1"]),
|
||||
)
|
||||
default_client_allowed_ips: list[str] = Field(
|
||||
default_factory=lambda: ["0.0.0.0/0", "::/0"],
|
||||
sa_column=Column(JSON, default=["0.0.0.0/0", "::/0"]),
|
||||
)
|
||||
|
||||
# Server WireGuard keypair (generated on first startup)
|
||||
server_private_key: str | None = None
|
||||
server_public_key: str | None = None
|
||||
|
||||
# VPN session
|
||||
vpn_session_duration: int = Field(default=0) # seconds, 0 = unlimited
|
||||
|
||||
# Logo
|
||||
logo_url: str | None = None
|
||||
logo_type: str | None = None # "url" | "file" | "upload" | None (default)
|
||||
|
||||
# OIDC providers (list of dicts)
|
||||
# Each: {id, label, scope, response_type, client_id, client_secret,
|
||||
# discovery_document_uri, redirect_uri, auto_create_users}
|
||||
openid_connect_providers: list[dict] = Field(
|
||||
default_factory=list, sa_column=Column(JSON, default=[])
|
||||
)
|
||||
|
||||
# SAML identity providers (list of dicts)
|
||||
# Each: {id, label, base_url, metadata, sign_requests, sign_metadata,
|
||||
# signed_assertion_in_resp, signed_envelopes_in_resp, auto_create_users}
|
||||
saml_identity_providers: list[dict] = Field(
|
||||
default_factory=list, sa_column=Column(JSON, default=[])
|
||||
)
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
18
wiregui/models/connectivity_check.py
Normal file
18
wiregui/models/connectivity_check.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, JSON, Column, SQLModel
|
||||
|
||||
|
||||
class ConnectivityCheck(SQLModel, table=True):
|
||||
__tablename__ = "connectivity_checks"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
url: str
|
||||
response_code: int | None = None
|
||||
response_headers: dict | None = Field(default=None, sa_column=Column(JSON))
|
||||
response_body: str | None = None
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
52
wiregui/models/device.py
Normal file
52
wiregui/models/device.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, JSON, Column, Relationship, SQLModel
|
||||
|
||||
|
||||
class Device(SQLModel, table=True):
|
||||
__tablename__ = "devices"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
public_key: str = Field(unique=True, index=True)
|
||||
preshared_key: str | None = None # encrypted at application level
|
||||
|
||||
# Client config: use server defaults or per-device overrides
|
||||
use_default_allowed_ips: bool = Field(default=True)
|
||||
use_default_dns: bool = Field(default=True)
|
||||
use_default_endpoint: bool = Field(default=True)
|
||||
use_default_mtu: bool = Field(default=True)
|
||||
use_default_persistent_keepalive: bool = Field(default=True)
|
||||
|
||||
# Per-device overrides (used when use_default_* is False)
|
||||
endpoint: str | None = None
|
||||
mtu: int | None = None
|
||||
persistent_keepalive: int | None = None
|
||||
allowed_ips: list[str] = Field(default_factory=list, sa_column=Column(JSON, default=[]))
|
||||
dns: list[str] = Field(default_factory=list, sa_column=Column(JSON, default=[]))
|
||||
|
||||
# Assigned tunnel addresses
|
||||
ipv4: str | None = Field(default=None, unique=True)
|
||||
ipv6: str | None = Field(default=None, unique=True)
|
||||
|
||||
# Peer stats (updated periodically from WireGuard)
|
||||
remote_ip: str | None = None
|
||||
rx_bytes: int | None = None
|
||||
tx_bytes: int | None = None
|
||||
latest_handshake: datetime | None = None
|
||||
|
||||
user_id: UUID = Field(foreign_key="users.id", index=True)
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
# Relationships
|
||||
user: "User" = Relationship(back_populates="devices")
|
||||
|
||||
|
||||
from wiregui.models.user import User # noqa: E402, F401
|
||||
27
wiregui/models/mfa_method.py
Normal file
27
wiregui/models/mfa_method.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, JSON, Column, Relationship, SQLModel
|
||||
|
||||
|
||||
class MFAMethod(SQLModel, table=True):
|
||||
__tablename__ = "mfa_methods"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
name: str
|
||||
type: str # "totp" | "native" | "portable"
|
||||
payload: dict = Field(default_factory=dict, sa_column=Column(JSON)) # encrypted at app level
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
user_id: UUID = Field(foreign_key="users.id", index=True)
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
# Relationships
|
||||
user: "User" = Relationship(back_populates="mfa_methods")
|
||||
|
||||
|
||||
from wiregui.models.user import User # noqa: E402, F401
|
||||
27
wiregui/models/oidc_connection.py
Normal file
27
wiregui/models/oidc_connection.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, JSON, Column, Relationship, SQLModel
|
||||
|
||||
|
||||
class OIDCConnection(SQLModel, table=True):
|
||||
__tablename__ = "oidc_connections"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
provider: str
|
||||
refresh_token: str | None = None # encrypted at application level
|
||||
refresh_response: dict | None = Field(default=None, sa_column=Column(JSON))
|
||||
refreshed_at: datetime | None = None
|
||||
|
||||
user_id: UUID = Field(foreign_key="users.id", index=True)
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
# Relationships
|
||||
user: "User" = Relationship(back_populates="oidc_connections")
|
||||
|
||||
|
||||
from wiregui.models.user import User # noqa: E402, F401
|
||||
27
wiregui/models/rule.py
Normal file
27
wiregui/models/rule.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
|
||||
class Rule(SQLModel, table=True):
|
||||
__tablename__ = "rules"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
action: str = Field(default="drop") # "drop" | "accept"
|
||||
destination: str # CIDR notation, e.g. "10.0.0.0/8" or "0.0.0.0/0"
|
||||
port_type: str | None = None # "tcp" | "udp" | None (any)
|
||||
port_range: str | None = None # e.g. "80-443" or "22" or None (any)
|
||||
|
||||
user_id: UUID | None = Field(default=None, foreign_key="users.id", index=True)
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
# Relationships
|
||||
user: "User" = Relationship(back_populates="rules")
|
||||
|
||||
|
||||
from wiregui.models.user import User # noqa: E402, F401
|
||||
41
wiregui/models/user.py
Normal file
41
wiregui/models/user.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
from datetime import datetime
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
||||
email: str = Field(unique=True, index=True)
|
||||
password_hash: str | None = None
|
||||
role: str = Field(default="unprivileged") # "admin" | "unprivileged"
|
||||
|
||||
last_signed_in_at: datetime | None = None
|
||||
last_signed_in_method: str | None = None
|
||||
|
||||
sign_in_token_hash: str | None = None
|
||||
sign_in_token_created_at: datetime | None = None
|
||||
|
||||
disabled_at: datetime | None = None
|
||||
|
||||
inserted_at: datetime = Field(default_factory=utcnow)
|
||||
updated_at: datetime = Field(default_factory=utcnow)
|
||||
|
||||
# Relationships
|
||||
devices: list["Device"] = Relationship(back_populates="user")
|
||||
oidc_connections: list["OIDCConnection"] = Relationship(back_populates="user")
|
||||
api_tokens: list["ApiToken"] = Relationship(back_populates="user")
|
||||
mfa_methods: list["MFAMethod"] = Relationship(back_populates="user")
|
||||
rules: list["Rule"] = Relationship(back_populates="user")
|
||||
|
||||
|
||||
# Avoid circular imports — these are resolved at runtime by SQLModel
|
||||
from wiregui.models.api_token import ApiToken # noqa: E402, F401
|
||||
from wiregui.models.device import Device # noqa: E402, F401
|
||||
from wiregui.models.mfa_method import MFAMethod # noqa: E402, F401
|
||||
from wiregui.models.oidc_connection import OIDCConnection # noqa: E402, F401
|
||||
from wiregui.models.rule import Rule # noqa: E402, F401
|
||||
0
wiregui/pages/__init__.py
Normal file
0
wiregui/pages/__init__.py
Normal file
388
wiregui/pages/account.py
Normal file
388
wiregui/pages/account.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
"""User account page — password change, MFA management, API tokens."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
import json
|
||||
|
||||
from wiregui.auth.api_token import generate_api_token
|
||||
from wiregui.auth.mfa import generate_totp_qr_svg, generate_totp_secret, get_totp_uri, verify_totp_code
|
||||
from wiregui.auth.passwords import hash_password, verify_password
|
||||
from wiregui.auth.webauthn import create_registration_options, verify_registration
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.api_token import ApiToken
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
@ui.page("/account")
|
||||
async def account_page():
|
||||
if not app.storage.user.get("authenticated"):
|
||||
return ui.navigate.to("/login")
|
||||
|
||||
layout()
|
||||
user_id = UUID(app.storage.user["user_id"])
|
||||
|
||||
async with async_session() as session:
|
||||
user = await session.get(User, user_id)
|
||||
|
||||
with ui.column().classes("w-full p-4"):
|
||||
ui.label("Account Settings").classes("text-h5 q-mb-md")
|
||||
|
||||
with ui.tabs().classes("w-full") as tabs:
|
||||
profile_tab = ui.tab("Profile")
|
||||
mfa_tab = ui.tab("Two-Factor Auth")
|
||||
tokens_tab = ui.tab("API Tokens")
|
||||
|
||||
with ui.tab_panels(tabs, value=profile_tab).classes("w-full"):
|
||||
|
||||
# === Profile ===
|
||||
with ui.tab_panel(profile_tab):
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("Account Details").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
|
||||
ui.label("Email:").classes("text-bold")
|
||||
ui.label(user.email)
|
||||
ui.label("Role:").classes("text-bold")
|
||||
ui.label(user.role)
|
||||
ui.label("Last Sign-in:").classes("text-bold")
|
||||
ui.label(str(user.last_signed_in_at)[:19] if user.last_signed_in_at else "-")
|
||||
ui.label("Method:").classes("text-bold")
|
||||
ui.label(user.last_signed_in_method or "-")
|
||||
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Change Password").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
current_pw = ui.input("Current Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
|
||||
new_pw = ui.input("New Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
|
||||
confirm_pw = ui.input("Confirm New Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
|
||||
|
||||
async def change_password():
|
||||
if not current_pw.value or not new_pw.value:
|
||||
ui.notify("Fill in all password fields", type="negative")
|
||||
return
|
||||
if new_pw.value != confirm_pw.value:
|
||||
ui.notify("New passwords do not match", type="negative")
|
||||
return
|
||||
if len(new_pw.value) < 8:
|
||||
ui.notify("Password must be at least 8 characters", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
u = await session.get(User, user_id)
|
||||
if not verify_password(current_pw.value, u.password_hash):
|
||||
ui.notify("Current password is incorrect", type="negative")
|
||||
return
|
||||
u.password_hash = hash_password(new_pw.value)
|
||||
session.add(u)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Password changed for {}", user.email)
|
||||
ui.notify("Password changed", type="positive")
|
||||
current_pw.value = ""
|
||||
new_pw.value = ""
|
||||
confirm_pw.value = ""
|
||||
|
||||
ui.button("Change Password", on_click=change_password).props("color=primary").classes("q-mt-sm")
|
||||
|
||||
# OIDC connections
|
||||
async with async_session() as session:
|
||||
oidc_conns = (await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.user_id == user_id)
|
||||
)).scalars().all()
|
||||
|
||||
if oidc_conns:
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Connected SSO Providers").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
for conn in oidc_conns:
|
||||
with ui.row().classes("w-full items-center justify-between q-pa-xs"):
|
||||
ui.label(f"{conn.provider}").classes("text-bold")
|
||||
ui.label(f"Last refreshed: {str(conn.refreshed_at)[:19] if conn.refreshed_at else 'Never'}")
|
||||
|
||||
# === MFA ===
|
||||
with ui.tab_panel(mfa_tab):
|
||||
await _render_mfa_panel(user_id, user.email)
|
||||
|
||||
# === API Tokens ===
|
||||
with ui.tab_panel(tokens_tab):
|
||||
await _render_tokens_panel(user_id)
|
||||
|
||||
|
||||
async def _render_mfa_panel(user_id: UUID, email: str):
|
||||
"""Render the MFA management tab."""
|
||||
async def load_methods():
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user_id).order_by(MFAMethod.inserted_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def refresh_methods():
|
||||
methods = await load_methods()
|
||||
methods_container.clear()
|
||||
with methods_container:
|
||||
if methods:
|
||||
for m in methods:
|
||||
with ui.row().classes("w-full items-center justify-between q-pa-xs"):
|
||||
with ui.row().classes("items-center gap-2"):
|
||||
ui.icon("security").props("color=primary")
|
||||
ui.label(m.name).classes("text-bold")
|
||||
ui.label(f"({m.type})").classes("text-caption text-grey-7")
|
||||
with ui.row().classes("items-center gap-2"):
|
||||
ui.label(f"Last used: {str(m.last_used_at)[:19] if m.last_used_at else 'Never'}").classes("text-caption")
|
||||
ui.button(icon="delete", on_click=lambda mid=m.id: delete_method(mid)).props("flat dense color=negative")
|
||||
ui.separator()
|
||||
else:
|
||||
ui.label("No MFA methods configured.").classes("text-caption text-grey-7 q-pa-sm")
|
||||
|
||||
async def delete_method(method_id):
|
||||
async with async_session() as session:
|
||||
m = await session.get(MFAMethod, method_id)
|
||||
if m and m.user_id == user_id:
|
||||
await session.delete(m)
|
||||
await session.commit()
|
||||
logger.info("MFA method deleted for user {}", email)
|
||||
ui.notify("MFA method removed")
|
||||
await refresh_methods()
|
||||
|
||||
# Registration state
|
||||
registration = {"secret": None}
|
||||
|
||||
def start_registration():
|
||||
secret = generate_totp_secret()
|
||||
registration["secret"] = secret
|
||||
uri = get_totp_uri(secret, email)
|
||||
svg = generate_totp_qr_svg(uri)
|
||||
|
||||
reg_container.clear()
|
||||
with reg_container:
|
||||
ui.label("Scan this QR code with your authenticator app:").classes("text-body2")
|
||||
ui.html(svg).classes("w-64 q-my-sm")
|
||||
ui.label(f"Or enter this secret manually: {secret}").classes("text-caption font-mono")
|
||||
reg_name_input = ui.input("Method Name", value="Authenticator").props("outlined dense").classes("w-full")
|
||||
reg_code_input = ui.input("Verification Code", placeholder="Enter 6-digit code").props("outlined dense maxlength=6").classes("w-full")
|
||||
|
||||
async def verify_and_save():
|
||||
code = reg_code_input.value.strip()
|
||||
name = reg_name_input.value.strip() or "Authenticator"
|
||||
if not verify_totp_code(registration["secret"], code):
|
||||
ui.notify("Invalid code — check your authenticator", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
method = MFAMethod(
|
||||
name=name,
|
||||
type="totp",
|
||||
payload={"secret": registration["secret"]},
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.commit()
|
||||
|
||||
logger.info("MFA TOTP registered for {}", email)
|
||||
ui.notify("MFA method added!", type="positive")
|
||||
registration["secret"] = None
|
||||
reg_container.clear()
|
||||
await refresh_methods()
|
||||
|
||||
ui.button("Verify & Save", on_click=verify_and_save).props("color=primary").classes("q-mt-sm")
|
||||
ui.button("Cancel", on_click=lambda: reg_container.clear()).props("flat")
|
||||
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("Two-Factor Authentication Methods").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
methods_container = ui.column().classes("w-full")
|
||||
await refresh_methods()
|
||||
|
||||
with ui.row().classes("q-mt-sm gap-2"):
|
||||
ui.button("Add TOTP Method", icon="add", on_click=start_registration).props("outline")
|
||||
ui.button("Add Security Key", icon="key", on_click=lambda: start_webauthn_registration()).props("outline")
|
||||
|
||||
reg_container = ui.column().classes("w-full q-mt-md")
|
||||
webauthn_state = {"challenge": None}
|
||||
|
||||
async def start_webauthn_registration():
|
||||
# Get existing webauthn credentials to exclude
|
||||
existing = []
|
||||
async with async_session() as session:
|
||||
from sqlmodel import select as sel
|
||||
result = await session.execute(
|
||||
sel(MFAMethod).where(MFAMethod.user_id == user_id, MFAMethod.type.in_(["native", "portable"]))
|
||||
)
|
||||
for m in result.scalars().all():
|
||||
existing.append(m.payload)
|
||||
|
||||
try:
|
||||
reg_data = create_registration_options(user_id, email, existing)
|
||||
except Exception as e:
|
||||
ui.notify(f"WebAuthn not available: {e}", type="negative")
|
||||
return
|
||||
|
||||
webauthn_state["challenge"] = reg_data["challenge"]
|
||||
options_json = reg_data["options_json"]
|
||||
|
||||
# Call browser's navigator.credentials.create() via JavaScript
|
||||
js = f"""
|
||||
async function() {{
|
||||
try {{
|
||||
const options = JSON.parse('{options_json}');
|
||||
// Convert base64url strings to ArrayBuffers
|
||||
options.challenge = Uint8Array.from(atob(options.challenge.replace(/-/g,'+').replace(/_/g,'/')), c => c.charCodeAt(0));
|
||||
options.user.id = Uint8Array.from(atob(options.user.id.replace(/-/g,'+').replace(/_/g,'/')), c => c.charCodeAt(0));
|
||||
if (options.excludeCredentials) {{
|
||||
options.excludeCredentials = options.excludeCredentials.map(c => ({{
|
||||
...c,
|
||||
id: Uint8Array.from(atob(c.id.replace(/-/g,'+').replace(/_/g,'/')), ch => ch.charCodeAt(0))
|
||||
}}));
|
||||
}}
|
||||
const credential = await navigator.credentials.create({{publicKey: options}});
|
||||
// Serialize the response
|
||||
const response = {{
|
||||
id: credential.id,
|
||||
rawId: btoa(String.fromCharCode(...new Uint8Array(credential.rawId))).replace(/\\+/g,'-').replace(/\\//g,'_').replace(/=/g,''),
|
||||
type: credential.type,
|
||||
response: {{
|
||||
attestationObject: btoa(String.fromCharCode(...new Uint8Array(credential.response.attestationObject))).replace(/\\+/g,'-').replace(/\\//g,'_').replace(/=/g,''),
|
||||
clientDataJSON: btoa(String.fromCharCode(...new Uint8Array(credential.response.clientDataJSON))).replace(/\\+/g,'-').replace(/\\//g,'_').replace(/=/g,''),
|
||||
}},
|
||||
}};
|
||||
return JSON.stringify(response);
|
||||
}} catch(e) {{
|
||||
return JSON.stringify({{"error": e.message}});
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
result = await ui.run_javascript(f"({js})()")
|
||||
await _handle_webauthn_response(result)
|
||||
|
||||
async def _handle_webauthn_response(result_json: str):
|
||||
try:
|
||||
result = json.loads(result_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
ui.notify("WebAuthn response error", type="negative")
|
||||
return
|
||||
|
||||
if "error" in result:
|
||||
ui.notify(f"WebAuthn failed: {result['error']}", type="negative")
|
||||
return
|
||||
|
||||
challenge = webauthn_state.get("challenge")
|
||||
if not challenge:
|
||||
ui.notify("No pending WebAuthn challenge", type="negative")
|
||||
return
|
||||
|
||||
try:
|
||||
credential_data = verify_registration(result_json, challenge)
|
||||
except Exception as e:
|
||||
ui.notify(f"Verification failed: {e}", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
method = MFAMethod(
|
||||
name="Security Key",
|
||||
type="portable",
|
||||
payload=credential_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(method)
|
||||
await session.commit()
|
||||
|
||||
logger.info("WebAuthn key registered for {}", email)
|
||||
ui.notify("Security key registered!", type="positive")
|
||||
webauthn_state["challenge"] = None
|
||||
await refresh_methods()
|
||||
|
||||
|
||||
async def _render_tokens_panel(user_id: UUID):
|
||||
"""Render the API tokens tab."""
|
||||
async def load_tokens():
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(ApiToken).where(ApiToken.user_id == user_id).order_by(ApiToken.inserted_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def refresh_tokens():
|
||||
tokens = await load_tokens()
|
||||
token_table.rows = [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"created": str(t.inserted_at)[:19],
|
||||
"expires": str(t.expires_at)[:19] if t.expires_at else "Never",
|
||||
"status": "Expired" if t.expires_at and t.expires_at < utcnow() else "Active",
|
||||
}
|
||||
for t in tokens
|
||||
]
|
||||
token_table.update()
|
||||
|
||||
async def create_token():
|
||||
from datetime import timedelta
|
||||
days = int(token_days.value) if token_days.value else 30
|
||||
|
||||
plaintext, token_hash = generate_api_token()
|
||||
expires_at = utcnow() + timedelta(days=days) if days > 0 else None
|
||||
|
||||
async with async_session() as session:
|
||||
token = ApiToken(token_hash=token_hash, expires_at=expires_at, user_id=user_id)
|
||||
session.add(token)
|
||||
await session.commit()
|
||||
|
||||
logger.info("API token created (expires in {} days)", days)
|
||||
|
||||
# Show the token once
|
||||
with ui.dialog(value=True) as token_dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label("API Token Created").classes("text-h6")
|
||||
ui.label("Copy this token now — it won't be shown again.").classes("text-caption text-negative")
|
||||
ui.input(value=plaintext).props("readonly outlined dense").classes("w-full font-mono q-mt-sm")
|
||||
ui.button("Close", on_click=token_dialog.close).props("flat").classes("w-full q-mt-sm")
|
||||
|
||||
await refresh_tokens()
|
||||
|
||||
async def delete_token(token_id: str):
|
||||
async with async_session() as session:
|
||||
t = await session.get(ApiToken, UUID(token_id))
|
||||
if t and t.user_id == user_id:
|
||||
await session.delete(t)
|
||||
await session.commit()
|
||||
ui.notify("Token deleted")
|
||||
await refresh_tokens()
|
||||
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("API Tokens").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
ui.label("Use API tokens for programmatic access to the REST API.").classes("text-caption text-grey-7")
|
||||
|
||||
token_columns = [
|
||||
{"name": "created", "label": "Created", "field": "created", "align": "left"},
|
||||
{"name": "expires", "label": "Expires", "field": "expires", "align": "left"},
|
||||
{"name": "status", "label": "Status", "field": "status", "align": "left"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
token_table = ui.table(columns=token_columns, rows=[], row_key="id").classes("w-full")
|
||||
token_table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
token_table.on("delete", lambda e: delete_token(e.args))
|
||||
|
||||
with ui.row().classes("items-center gap-2 q-mt-sm"):
|
||||
token_days = ui.input("Expires in (days)", value="30").props("outlined dense").classes("w-40")
|
||||
ui.button("Create Token", icon="add", on_click=create_token).props("color=primary")
|
||||
|
||||
await refresh_tokens()
|
||||
0
wiregui/pages/admin/__init__.py
Normal file
0
wiregui/pages/admin/__init__.py
Normal file
350
wiregui/pages/admin/devices.py
Normal file
350
wiregui/pages/admin/devices.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""Admin device management — view and manage all devices across all users."""
|
||||
|
||||
import io
|
||||
from uuid import UUID
|
||||
|
||||
import qrcode
|
||||
import qrcode.image.svg
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.user import User
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated
|
||||
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
|
||||
from wiregui.utils.server_key import get_server_public_key
|
||||
from wiregui.utils.wg_conf import build_client_config
|
||||
|
||||
|
||||
def _guard():
|
||||
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
|
||||
ui.navigate.to("/login")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _format_bytes(b: int | None) -> str:
|
||||
if b is None:
|
||||
return "-"
|
||||
for unit in ("B", "KB", "MB", "GB", "TB"):
|
||||
if b < 1024:
|
||||
return f"{b:.1f} {unit}"
|
||||
b /= 1024
|
||||
return f"{b:.1f} PB"
|
||||
|
||||
|
||||
@ui.page("/admin/devices")
|
||||
async def admin_devices_page():
|
||||
if not _guard():
|
||||
return
|
||||
|
||||
layout()
|
||||
|
||||
# Load users for filter and create form
|
||||
async with async_session() as session:
|
||||
users = (await session.execute(select(User).order_by(User.email))).scalars().all()
|
||||
user_map = {str(u.id): u.email for u in users}
|
||||
|
||||
async def load_devices(user_filter: str | None = None) -> list[dict]:
|
||||
async with async_session() as session:
|
||||
stmt = select(Device).order_by(Device.inserted_at.desc())
|
||||
if user_filter and user_filter != "all":
|
||||
stmt = stmt.where(Device.user_id == UUID(user_filter))
|
||||
result = await session.execute(stmt)
|
||||
return [
|
||||
{
|
||||
"id": str(d.id),
|
||||
"name": d.name,
|
||||
"user": user_map.get(str(d.user_id), "Unknown"),
|
||||
"ipv4": d.ipv4 or "-",
|
||||
"ipv6": d.ipv6 or "-",
|
||||
"public_key": d.public_key[:16] + "...",
|
||||
"rx": _format_bytes(d.rx_bytes),
|
||||
"tx": _format_bytes(d.tx_bytes),
|
||||
"handshake": str(d.latest_handshake)[:19] if d.latest_handshake else "-",
|
||||
}
|
||||
for d in result.scalars().all()
|
||||
]
|
||||
|
||||
async def refresh_table():
|
||||
table.rows = await load_devices(user_filter_select.value)
|
||||
table.update()
|
||||
|
||||
async def on_filter_change():
|
||||
await refresh_table()
|
||||
|
||||
# --- Create device ---
|
||||
async def create_device():
|
||||
name = create_name.value.strip()
|
||||
owner_id = create_user_select.value
|
||||
if not name or not owner_id:
|
||||
ui.notify("Name and user are required", type="negative")
|
||||
return
|
||||
|
||||
try:
|
||||
settings = get_settings()
|
||||
private_key, public_key = generate_keypair()
|
||||
psk = generate_preshared_key()
|
||||
|
||||
async with async_session() as session:
|
||||
ipv4 = await allocate_ipv4(session, settings.wg_ipv4_network)
|
||||
ipv6 = await allocate_ipv6(session, settings.wg_ipv6_network)
|
||||
|
||||
device = Device(
|
||||
name=name,
|
||||
description=create_desc.value.strip() or None,
|
||||
public_key=public_key,
|
||||
preshared_key=psk,
|
||||
ipv4=ipv4,
|
||||
ipv6=ipv6,
|
||||
user_id=UUID(owner_id),
|
||||
# Apply config overrides if not using defaults
|
||||
use_default_allowed_ips=create_use_default_ips.value,
|
||||
use_default_dns=create_use_default_dns.value,
|
||||
use_default_endpoint=create_use_default_endpoint.value,
|
||||
use_default_mtu=create_use_default_mtu.value,
|
||||
use_default_persistent_keepalive=create_use_default_keepalive.value,
|
||||
endpoint=create_endpoint.value.strip() or None if not create_use_default_endpoint.value else None,
|
||||
dns=([s.strip() for s in create_dns.value.split(",") if s.strip()]
|
||||
if not create_use_default_dns.value and create_dns.value else []),
|
||||
mtu=int(create_mtu.value) if not create_use_default_mtu.value and create_mtu.value else None,
|
||||
persistent_keepalive=(int(create_keepalive.value)
|
||||
if not create_use_default_keepalive.value and create_keepalive.value else None),
|
||||
allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()]
|
||||
if not create_use_default_ips.value and create_allowed_ips.value else []),
|
||||
)
|
||||
session.add(device)
|
||||
await session.commit()
|
||||
await session.refresh(device)
|
||||
|
||||
logger.info("Admin created device: {} for {}", device.name, user_map.get(owner_id))
|
||||
await on_device_created(device)
|
||||
|
||||
# Show config
|
||||
server_pubkey = await get_server_public_key()
|
||||
config_text = build_client_config(device, private_key, server_pubkey)
|
||||
_show_config_dialog(device.name, config_text)
|
||||
|
||||
create_dialog.close()
|
||||
_reset_create_form()
|
||||
await refresh_table()
|
||||
except Exception as e:
|
||||
logger.error("Failed to create device: {}", e)
|
||||
ui.notify(f"Error: {e}", type="negative")
|
||||
|
||||
def _reset_create_form():
|
||||
create_name.value = ""
|
||||
create_desc.value = ""
|
||||
create_use_default_ips.value = True
|
||||
create_use_default_dns.value = True
|
||||
create_use_default_endpoint.value = True
|
||||
create_use_default_mtu.value = True
|
||||
create_use_default_keepalive.value = True
|
||||
|
||||
# --- Edit device ---
|
||||
edit_device_id = {"value": None}
|
||||
|
||||
async def open_edit(device_id: str):
|
||||
async with async_session() as session:
|
||||
device = await session.get(Device, UUID(device_id))
|
||||
if not device:
|
||||
return
|
||||
|
||||
edit_device_id["value"] = device_id
|
||||
edit_name.value = device.name
|
||||
edit_desc.value = device.description or ""
|
||||
edit_use_default_ips.value = device.use_default_allowed_ips
|
||||
edit_use_default_dns.value = device.use_default_dns
|
||||
edit_use_default_endpoint.value = device.use_default_endpoint
|
||||
edit_use_default_mtu.value = device.use_default_mtu
|
||||
edit_use_default_keepalive.value = device.use_default_persistent_keepalive
|
||||
edit_endpoint.value = device.endpoint or ""
|
||||
edit_dns.value = ", ".join(device.dns) if device.dns else ""
|
||||
edit_mtu.value = str(device.mtu) if device.mtu else ""
|
||||
edit_keepalive.value = str(device.persistent_keepalive) if device.persistent_keepalive else ""
|
||||
edit_allowed_ips.value = ", ".join(device.allowed_ips) if device.allowed_ips else ""
|
||||
edit_dialog.open()
|
||||
|
||||
async def save_edit():
|
||||
did = edit_device_id["value"]
|
||||
if not did:
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
device = await session.get(Device, UUID(did))
|
||||
if not device:
|
||||
return
|
||||
|
||||
device.name = edit_name.value.strip()
|
||||
device.description = edit_desc.value.strip() or None
|
||||
device.use_default_allowed_ips = edit_use_default_ips.value
|
||||
device.use_default_dns = edit_use_default_dns.value
|
||||
device.use_default_endpoint = edit_use_default_endpoint.value
|
||||
device.use_default_mtu = edit_use_default_mtu.value
|
||||
device.use_default_persistent_keepalive = edit_use_default_keepalive.value
|
||||
|
||||
if not device.use_default_endpoint:
|
||||
device.endpoint = edit_endpoint.value.strip() or None
|
||||
if not device.use_default_dns:
|
||||
device.dns = [s.strip() for s in edit_dns.value.split(",") if s.strip()]
|
||||
if not device.use_default_mtu:
|
||||
device.mtu = int(edit_mtu.value) if edit_mtu.value else None
|
||||
if not device.use_default_persistent_keepalive:
|
||||
device.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None
|
||||
if not device.use_default_allowed_ips:
|
||||
device.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()]
|
||||
|
||||
session.add(device)
|
||||
await session.commit()
|
||||
await session.refresh(device)
|
||||
await on_device_updated(device)
|
||||
|
||||
logger.info("Admin updated device: {}", edit_name.value)
|
||||
ui.notify("Device updated")
|
||||
edit_dialog.close()
|
||||
await refresh_table()
|
||||
|
||||
# --- Delete device ---
|
||||
async def delete_device(device_id: str):
|
||||
async with async_session() as session:
|
||||
device = await session.get(Device, UUID(device_id))
|
||||
if device:
|
||||
await session.delete(device)
|
||||
await session.commit()
|
||||
logger.info("Admin deleted device: {}", device.name)
|
||||
await on_device_deleted(device)
|
||||
ui.notify(f"Deleted {device.name}")
|
||||
await refresh_table()
|
||||
|
||||
# --- Page content ---
|
||||
with ui.column().classes("w-full p-4"):
|
||||
with ui.row().classes("w-full items-center justify-between"):
|
||||
ui.label("All Devices").classes("text-h5")
|
||||
with ui.row().classes("items-center gap-4"):
|
||||
filter_options = {"all": "All Users"}
|
||||
filter_options.update(user_map)
|
||||
user_filter_select = ui.select(
|
||||
filter_options, value="all", label="Filter by User",
|
||||
on_change=lambda: on_filter_change(),
|
||||
).props("outlined dense").classes("w-48")
|
||||
ui.button("Add Device", icon="add", on_click=lambda: create_dialog.open()).props("color=primary")
|
||||
|
||||
columns = [
|
||||
{"name": "name", "label": "Name", "field": "name", "align": "left", "sortable": True},
|
||||
{"name": "user", "label": "User", "field": "user", "align": "left", "sortable": True},
|
||||
{"name": "ipv4", "label": "IPv4", "field": "ipv4", "align": "left"},
|
||||
{"name": "ipv6", "label": "IPv6", "field": "ipv6", "align": "left"},
|
||||
{"name": "public_key", "label": "Public Key", "field": "public_key", "align": "left"},
|
||||
{"name": "rx", "label": "RX", "field": "rx", "align": "right"},
|
||||
{"name": "tx", "label": "TX", "field": "tx", "align": "right"},
|
||||
{"name": "handshake", "label": "Last Handshake", "field": "handshake", "align": "left"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
|
||||
table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="edit" color="primary"
|
||||
@click.stop="() => $parent.$emit('edit', props.row.id)" />
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
table.on("edit", lambda e: open_edit(e.args))
|
||||
table.on("delete", lambda e: delete_device(e.args))
|
||||
|
||||
# --- Create dialog (full form) ---
|
||||
with ui.dialog() as create_dialog:
|
||||
with ui.card().classes("w-[600px]"):
|
||||
ui.label("New Device").classes("text-h6")
|
||||
|
||||
create_user_select = ui.select(
|
||||
user_map, value=list(user_map.keys())[0] if user_map else None,
|
||||
label="Owner",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
create_name = ui.input("Device Name").props("outlined dense").classes("w-full")
|
||||
create_desc = ui.input("Description (optional)").props("outlined dense").classes("w-full")
|
||||
|
||||
ui.separator().classes("q-my-sm")
|
||||
ui.label("Configuration Overrides").classes("text-subtitle2")
|
||||
|
||||
with ui.grid(columns=2).classes("w-full gap-2"):
|
||||
create_use_default_ips = ui.switch("Use default Allowed IPs", value=True)
|
||||
create_allowed_ips = ui.input("Allowed IPs", placeholder="0.0.0.0/0, ::/0").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_dns = ui.switch("Use default DNS", value=True)
|
||||
create_dns = ui.input("DNS Servers", placeholder="1.1.1.1, 1.0.0.1").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_endpoint = ui.switch("Use default Endpoint", value=True)
|
||||
create_endpoint = ui.input("Endpoint", placeholder="vpn.example.com").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_mtu = ui.switch("Use default MTU", value=True)
|
||||
create_mtu = ui.input("MTU", placeholder="1280").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_keepalive = ui.switch("Use default Keepalive", value=True)
|
||||
create_keepalive = ui.input("Persistent Keepalive", placeholder="25").props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v)
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=create_dialog.close).props("flat")
|
||||
ui.button("Create", on_click=create_device).props("color=primary")
|
||||
|
||||
# --- Edit dialog (full form) ---
|
||||
with ui.dialog() as edit_dialog:
|
||||
with ui.card().classes("w-[600px]"):
|
||||
ui.label("Edit Device").classes("text-h6")
|
||||
|
||||
edit_name = ui.input("Device Name").props("outlined dense").classes("w-full")
|
||||
edit_desc = ui.input("Description").props("outlined dense").classes("w-full")
|
||||
|
||||
ui.separator().classes("q-my-sm")
|
||||
ui.label("Configuration Overrides").classes("text-subtitle2")
|
||||
|
||||
with ui.grid(columns=2).classes("w-full gap-2"):
|
||||
edit_use_default_ips = ui.switch("Use default Allowed IPs", value=True)
|
||||
edit_allowed_ips = ui.input("Allowed IPs").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_ips, "value", backward=lambda v: not v)
|
||||
|
||||
edit_use_default_dns = ui.switch("Use default DNS", value=True)
|
||||
edit_dns = ui.input("DNS Servers").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_dns, "value", backward=lambda v: not v)
|
||||
|
||||
edit_use_default_endpoint = ui.switch("Use default Endpoint", value=True)
|
||||
edit_endpoint = ui.input("Endpoint").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_endpoint, "value", backward=lambda v: not v)
|
||||
|
||||
edit_use_default_mtu = ui.switch("Use default MTU", value=True)
|
||||
edit_mtu = ui.input("MTU").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_mtu, "value", backward=lambda v: not v)
|
||||
|
||||
edit_use_default_keepalive = ui.switch("Use default Keepalive", value=True)
|
||||
edit_keepalive = ui.input("Persistent Keepalive").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_keepalive, "value", backward=lambda v: not v)
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=edit_dialog.close).props("flat")
|
||||
ui.button("Save", on_click=save_edit).props("color=primary")
|
||||
|
||||
await refresh_table()
|
||||
|
||||
# Auto-refresh stats every 30 seconds
|
||||
ui.timer(30, refresh_table)
|
||||
|
||||
|
||||
def _show_config_dialog(device_name: str, config_text: str):
|
||||
with ui.dialog(value=True) as dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label(f"Config for {device_name}").classes("text-h6")
|
||||
ui.label("Save this — the private key won't be shown again.").classes("text-caption text-negative")
|
||||
ui.textarea(value=config_text).props("readonly outlined").classes("w-full font-mono text-xs q-mt-sm").style("min-height: 200px")
|
||||
try:
|
||||
qr = qrcode.make(config_text, image_factory=qrcode.image.svg.SvgPathImage)
|
||||
buf = io.BytesIO()
|
||||
qr.save(buf)
|
||||
ui.html(buf.getvalue().decode()).classes("w-full q-mt-sm")
|
||||
except Exception:
|
||||
pass
|
||||
ui.button("Download .conf", on_click=lambda: ui.download(config_text.encode(), f"{device_name}.conf")).props("color=primary outline").classes("w-full q-mt-sm")
|
||||
ui.button("Close", on_click=dialog.close).props("flat").classes("w-full")
|
||||
162
wiregui/pages/admin/diagnostics.py
Normal file
162
wiregui/pages/admin/diagnostics.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Admin diagnostics page — connectivity checks, WG status, peer stats."""
|
||||
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.services import notifications
|
||||
|
||||
|
||||
def _guard():
|
||||
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
|
||||
ui.navigate.to("/login")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _format_bytes(b: int | None) -> str:
|
||||
if b is None:
|
||||
return "-"
|
||||
for unit in ("B", "KB", "MB", "GB", "TB"):
|
||||
if b < 1024:
|
||||
return f"{b:.1f} {unit}"
|
||||
b /= 1024
|
||||
return f"{b:.1f} PB"
|
||||
|
||||
|
||||
@ui.page("/admin/diagnostics")
|
||||
async def diagnostics_page():
|
||||
if not _guard():
|
||||
return
|
||||
|
||||
layout()
|
||||
settings = get_settings()
|
||||
|
||||
with ui.column().classes("w-full p-4"):
|
||||
ui.label("Diagnostics").classes("text-h5 q-mb-md")
|
||||
|
||||
# --- WireGuard Status ---
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("WireGuard Interface").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
|
||||
ui.label("Interface:").classes("text-bold")
|
||||
ui.label(settings.wg_interface)
|
||||
|
||||
ui.label("Status:").classes("text-bold")
|
||||
ui.label("Enabled" if settings.wg_enabled else "Disabled (UI-only mode)").classes(
|
||||
"text-positive" if settings.wg_enabled else "text-warning"
|
||||
)
|
||||
|
||||
ui.label("IPv4 Network:").classes("text-bold")
|
||||
ui.label(settings.wg_ipv4_network)
|
||||
|
||||
ui.label("IPv6 Network:").classes("text-bold")
|
||||
ui.label(settings.wg_ipv6_network)
|
||||
|
||||
ui.label("Endpoint:").classes("text-bold")
|
||||
ui.label(f"{settings.wg_endpoint_host}:{settings.wg_endpoint_port}")
|
||||
|
||||
# --- Active Peers ---
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Active Peers (from DB)").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Device).where(Device.latest_handshake.is_not(None)).order_by(Device.latest_handshake.desc())
|
||||
)
|
||||
active_devices = result.scalars().all()
|
||||
|
||||
if active_devices:
|
||||
peer_columns = [
|
||||
{"name": "name", "label": "Name", "field": "name", "align": "left"},
|
||||
{"name": "public_key", "label": "Public Key", "field": "public_key", "align": "left"},
|
||||
{"name": "ipv4", "label": "IPv4", "field": "ipv4", "align": "left"},
|
||||
{"name": "endpoint", "label": "Remote IP", "field": "endpoint", "align": "left"},
|
||||
{"name": "handshake", "label": "Last Handshake", "field": "handshake", "align": "left"},
|
||||
{"name": "rx", "label": "RX", "field": "rx", "align": "right"},
|
||||
{"name": "tx", "label": "TX", "field": "tx", "align": "right"},
|
||||
]
|
||||
peer_rows = [
|
||||
{
|
||||
"name": d.name,
|
||||
"public_key": d.public_key[:16] + "...",
|
||||
"ipv4": d.ipv4 or "-",
|
||||
"endpoint": d.remote_ip or "-",
|
||||
"handshake": str(d.latest_handshake)[:19] if d.latest_handshake else "-",
|
||||
"rx": _format_bytes(d.rx_bytes),
|
||||
"tx": _format_bytes(d.tx_bytes),
|
||||
}
|
||||
for d in active_devices
|
||||
]
|
||||
ui.table(columns=peer_columns, rows=peer_rows, row_key="name").classes("w-full")
|
||||
else:
|
||||
ui.label("No active peers with recent handshakes.").classes("text-caption text-grey-7 q-pa-sm")
|
||||
|
||||
# --- Connectivity Checks ---
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("WAN Connectivity Checks").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(ConnectivityCheck).order_by(ConnectivityCheck.inserted_at.desc()).limit(20)
|
||||
)
|
||||
checks = result.scalars().all()
|
||||
|
||||
if checks:
|
||||
check_columns = [
|
||||
{"name": "time", "label": "Checked At", "field": "time", "align": "left"},
|
||||
{"name": "url", "label": "URL", "field": "url", "align": "left"},
|
||||
{"name": "status", "label": "Status", "field": "status", "align": "center"},
|
||||
{"name": "body", "label": "Response", "field": "body", "align": "left"},
|
||||
]
|
||||
check_rows = [
|
||||
{
|
||||
"time": str(c.inserted_at)[:19],
|
||||
"url": c.url,
|
||||
"status": str(c.response_code or "Error"),
|
||||
"body": (c.response_body or "")[:50],
|
||||
}
|
||||
for c in checks
|
||||
]
|
||||
ui.table(columns=check_columns, rows=check_rows, row_key="time").classes("w-full")
|
||||
else:
|
||||
ui.label("No connectivity checks recorded yet.").classes("text-caption text-grey-7 q-pa-sm")
|
||||
|
||||
# --- Notifications ---
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("System Notifications").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
notifs = notifications.current()
|
||||
if notifs:
|
||||
for n in notifs:
|
||||
color = {"error": "negative", "warning": "warning", "info": "info"}.get(n.severity, "grey")
|
||||
with ui.row().classes("w-full items-center q-pa-xs"):
|
||||
ui.icon("error" if n.severity == "error" else "warning" if n.severity == "warning" else "info").props(f"color={color}")
|
||||
ui.label(f"{n.timestamp.strftime('%H:%M:%S')} — {n.message}").classes("text-sm")
|
||||
if n.user:
|
||||
ui.label(f"({n.user})").classes("text-caption text-grey-7")
|
||||
ui.button(icon="close", on_click=lambda nid=n.id: _clear_notif(nid)).props("flat dense size=xs")
|
||||
else:
|
||||
ui.label("No notifications.").classes("text-caption text-grey-7 q-pa-sm")
|
||||
|
||||
if notifs:
|
||||
ui.button("Clear All", on_click=lambda: _clear_all_notifs()).props("flat color=negative").classes("q-mt-sm")
|
||||
|
||||
|
||||
def _clear_notif(nid: str):
|
||||
notifications.clear(nid)
|
||||
ui.navigate.to("/admin/diagnostics")
|
||||
|
||||
|
||||
def _clear_all_notifs():
|
||||
notifications.clear_all()
|
||||
ui.navigate.to("/admin/diagnostics")
|
||||
228
wiregui/pages/admin/rules.py
Normal file
228
wiregui/pages/admin/rules.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Admin firewall rules management page."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.services.events import on_rule_created, on_rule_deleted, on_rule_updated
|
||||
|
||||
|
||||
@ui.page("/admin/rules")
|
||||
async def rules_page():
|
||||
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
|
||||
return ui.navigate.to("/login")
|
||||
|
||||
layout()
|
||||
|
||||
# Load users for the dropdown
|
||||
async with async_session() as session:
|
||||
users = (await session.execute(select(User).order_by(User.email))).scalars().all()
|
||||
user_options = {str(u.id): u.email for u in users}
|
||||
|
||||
async def load_rules() -> list[dict]:
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Rule).order_by(Rule.inserted_at.desc()))
|
||||
rules = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"action": r.action,
|
||||
"destination": r.destination,
|
||||
"port_type": r.port_type or "any",
|
||||
"port_range": r.port_range or "any",
|
||||
"user": user_options.get(str(r.user_id), "Global") if r.user_id else "Global",
|
||||
}
|
||||
for r in rules
|
||||
]
|
||||
|
||||
async def refresh_table():
|
||||
table.rows = await load_rules()
|
||||
table.update()
|
||||
|
||||
async def create_rule():
|
||||
dest = dest_input.value.strip()
|
||||
if not dest:
|
||||
ui.notify("Destination is required", type="negative")
|
||||
return
|
||||
|
||||
action_val = action_select.value
|
||||
port_type_val = port_type_select.value if port_type_select.value != "any" else None
|
||||
port_range_val = port_range_input.value.strip() or None
|
||||
user_id_val = user_select.value if user_select.value != "global" else None
|
||||
|
||||
async with async_session() as session:
|
||||
rule = Rule(
|
||||
action=action_val,
|
||||
destination=dest,
|
||||
port_type=port_type_val,
|
||||
port_range=port_range_val,
|
||||
user_id=UUID(user_id_val) if user_id_val else None,
|
||||
)
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
await session.refresh(rule)
|
||||
|
||||
logger.info("Rule created: {} {} -> {}", rule.action, rule.destination, user_id_val or "global")
|
||||
await on_rule_created(rule)
|
||||
|
||||
create_dialog.close()
|
||||
_reset_form()
|
||||
await refresh_table()
|
||||
|
||||
# --- Edit rule ---
|
||||
edit_rule_id = {"value": None}
|
||||
|
||||
async def open_edit(rule_id: str):
|
||||
async with async_session() as session:
|
||||
rule = await session.get(Rule, UUID(rule_id))
|
||||
if not rule:
|
||||
return
|
||||
edit_rule_id["value"] = rule_id
|
||||
edit_action.value = rule.action
|
||||
edit_dest.value = rule.destination
|
||||
edit_port_type.value = rule.port_type or "any"
|
||||
edit_port_range.value = rule.port_range or ""
|
||||
edit_user.value = str(rule.user_id) if rule.user_id else "global"
|
||||
edit_dialog.open()
|
||||
|
||||
async def save_edit():
|
||||
rid = edit_rule_id["value"]
|
||||
if not rid:
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
rule = await session.get(Rule, UUID(rid))
|
||||
if not rule:
|
||||
return
|
||||
|
||||
rule.action = edit_action.value
|
||||
rule.destination = edit_dest.value.strip()
|
||||
rule.port_type = edit_port_type.value if edit_port_type.value != "any" else None
|
||||
rule.port_range = edit_port_range.value.strip() or None
|
||||
rule.user_id = UUID(edit_user.value) if edit_user.value != "global" else None
|
||||
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
await session.refresh(rule)
|
||||
await on_rule_updated(rule)
|
||||
|
||||
logger.info("Rule updated: {} {}", edit_action.value, edit_dest.value)
|
||||
ui.notify("Rule updated")
|
||||
edit_dialog.close()
|
||||
await refresh_table()
|
||||
|
||||
async def delete_rule(rule_id: str):
|
||||
async with async_session() as session:
|
||||
rule = await session.get(Rule, UUID(rule_id))
|
||||
if rule:
|
||||
await session.delete(rule)
|
||||
await session.commit()
|
||||
logger.info("Rule deleted: {} {}", rule.action, rule.destination)
|
||||
await on_rule_deleted(rule)
|
||||
await refresh_table()
|
||||
|
||||
def _reset_form():
|
||||
dest_input.value = ""
|
||||
action_select.value = "accept"
|
||||
port_type_select.value = "any"
|
||||
port_range_input.value = ""
|
||||
user_select.value = "global"
|
||||
|
||||
# Page content
|
||||
with ui.column().classes("w-full p-4"):
|
||||
with ui.row().classes("w-full items-center justify-between"):
|
||||
ui.label("Firewall Rules").classes("text-h5")
|
||||
ui.button("Add Rule", icon="add", on_click=lambda: create_dialog.open()).props("color=primary")
|
||||
|
||||
columns = [
|
||||
{"name": "action", "label": "Action", "field": "action", "align": "left", "sortable": True},
|
||||
{"name": "destination", "label": "Destination", "field": "destination", "align": "left", "sortable": True},
|
||||
{"name": "port_type", "label": "Protocol", "field": "port_type", "align": "left"},
|
||||
{"name": "port_range", "label": "Port(s)", "field": "port_range", "align": "left"},
|
||||
{"name": "user", "label": "User", "field": "user", "align": "left"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
|
||||
table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="edit" color="primary"
|
||||
@click.stop="() => $parent.$emit('edit', props.row.id)" />
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
table.on("edit", lambda e: open_edit(e.args))
|
||||
table.on("delete", lambda e: delete_rule(e.args))
|
||||
|
||||
# Create rule dialog
|
||||
with ui.dialog() as create_dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label("New Firewall Rule").classes("text-h6")
|
||||
|
||||
action_select = ui.select(
|
||||
["accept", "drop"], value="accept", label="Action",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
dest_input = ui.input("Destination (CIDR)", placeholder="e.g. 10.0.0.0/8 or 0.0.0.0/0").props(
|
||||
"outlined dense"
|
||||
).classes("w-full")
|
||||
|
||||
port_type_select = ui.select(
|
||||
["any", "tcp", "udp"], value="any", label="Protocol",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
port_range_input = ui.input("Port Range", placeholder="e.g. 80 or 80-443 (optional)").props(
|
||||
"outlined dense"
|
||||
).classes("w-full")
|
||||
|
||||
user_options_list = [{"label": "Global (all users)", "value": "global"}] + [
|
||||
{"label": email, "value": uid} for uid, email in user_options.items()
|
||||
]
|
||||
user_select = ui.select(
|
||||
{item["value"]: item["label"] for item in user_options_list},
|
||||
value="global",
|
||||
label="Applies to",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=create_dialog.close).props("flat")
|
||||
ui.button("Create", on_click=create_rule).props("color=primary")
|
||||
|
||||
# Edit rule dialog
|
||||
user_options_map = {"global": "Global (all users)"}
|
||||
user_options_map.update(user_options)
|
||||
|
||||
with ui.dialog() as edit_dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label("Edit Firewall Rule").classes("text-h6")
|
||||
|
||||
edit_action = ui.select(
|
||||
["accept", "drop"], value="accept", label="Action",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
edit_dest = ui.input("Destination (CIDR)").props("outlined dense").classes("w-full")
|
||||
|
||||
edit_port_type = ui.select(
|
||||
["any", "tcp", "udp"], value="any", label="Protocol",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
edit_port_range = ui.input("Port Range").props("outlined dense").classes("w-full")
|
||||
|
||||
edit_user = ui.select(
|
||||
user_options_map, value="global", label="Applies to",
|
||||
).props("outlined dense").classes("w-full")
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=edit_dialog.close).props("flat")
|
||||
ui.button("Save", on_click=save_edit).props("color=primary")
|
||||
|
||||
await refresh_table()
|
||||
367
wiregui/pages/admin/settings.py
Normal file
367
wiregui/pages/admin/settings.py
Normal file
|
|
@ -0,0 +1,367 @@
|
|||
"""Admin settings pages — tabbed interface for configuration management."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
def _guard():
|
||||
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
|
||||
ui.navigate.to("/login")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _get_or_create_config() -> Configuration:
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Configuration).limit(1))
|
||||
config = result.scalar_one_or_none()
|
||||
if not config:
|
||||
config = Configuration()
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
|
||||
VPN_SESSION_OPTIONS = {
|
||||
0: "Never (unlimited)",
|
||||
3600: "Every Hour",
|
||||
86400: "Every Day",
|
||||
604800: "Every Week",
|
||||
2592000: "Every 30 Days",
|
||||
7776000: "Every 90 Days",
|
||||
}
|
||||
|
||||
|
||||
@ui.page("/admin/settings")
|
||||
async def settings_page():
|
||||
if not _guard():
|
||||
return
|
||||
|
||||
layout()
|
||||
config = await _get_or_create_config()
|
||||
|
||||
# --- Client Defaults tab ---
|
||||
async def save_defaults():
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
c.default_client_endpoint = defaults_endpoint.value.strip() or None
|
||||
c.default_client_dns = [s.strip() for s in defaults_dns.value.split(",") if s.strip()]
|
||||
c.default_client_mtu = int(defaults_mtu.value) if defaults_mtu.value else 1280
|
||||
c.default_client_persistent_keepalive = int(defaults_keepalive.value) if defaults_keepalive.value else 25
|
||||
c.default_client_allowed_ips = [s.strip() for s in defaults_allowed_ips.value.split(",") if s.strip()]
|
||||
c.updated_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
logger.info("Client defaults updated")
|
||||
ui.notify("Client defaults saved", type="positive")
|
||||
|
||||
# --- Security tab ---
|
||||
async def save_security():
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
c.vpn_session_duration = security_vpn_duration.value
|
||||
c.local_auth_enabled = security_local_auth.value
|
||||
c.allow_unprivileged_device_management = security_unpriv_mgmt.value
|
||||
c.allow_unprivileged_device_configuration = security_unpriv_config.value
|
||||
c.disable_vpn_on_oidc_error = security_disable_vpn_oidc.value
|
||||
c.updated_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
logger.info("Security settings updated")
|
||||
ui.notify("Security settings saved", type="positive")
|
||||
|
||||
# --- OIDC provider management ---
|
||||
async def save_oidc_provider():
|
||||
provider = {
|
||||
"id": oidc_id.value.strip(),
|
||||
"label": oidc_label.value.strip(),
|
||||
"scope": oidc_scope.value.strip(),
|
||||
"response_type": "code",
|
||||
"client_id": oidc_client_id.value.strip(),
|
||||
"client_secret": oidc_client_secret.value.strip(),
|
||||
"discovery_document_uri": oidc_discovery.value.strip(),
|
||||
"auto_create_users": oidc_auto_create.value,
|
||||
}
|
||||
if not all([provider["id"], provider["label"], provider["client_id"],
|
||||
provider["client_secret"], provider["discovery_document_uri"]]):
|
||||
ui.notify("All required fields must be filled", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
providers = list(c.openid_connect_providers or [])
|
||||
# Replace existing or add new
|
||||
providers = [p for p in providers if p.get("id") != provider["id"]]
|
||||
providers.append(provider)
|
||||
c.openid_connect_providers = providers
|
||||
c.updated_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
logger.info("OIDC provider saved: {}", provider["id"])
|
||||
ui.notify(f"OIDC provider '{provider['label']}' saved", type="positive")
|
||||
oidc_dialog.close()
|
||||
await refresh_oidc_table()
|
||||
|
||||
async def delete_oidc_provider(provider_id: str):
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
c.openid_connect_providers = [p for p in (c.openid_connect_providers or []) if p.get("id") != provider_id]
|
||||
c.updated_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
logger.info("OIDC provider deleted: {}", provider_id)
|
||||
ui.notify("OIDC provider deleted")
|
||||
await refresh_oidc_table()
|
||||
|
||||
async def refresh_oidc_table():
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
providers = c.openid_connect_providers or []
|
||||
oidc_table.rows = [
|
||||
{
|
||||
"id": p.get("id", ""),
|
||||
"label": p.get("label", ""),
|
||||
"client_id": p.get("client_id", ""),
|
||||
"discovery": p.get("discovery_document_uri", "")[:50] + "...",
|
||||
"auto_create": "Yes" if p.get("auto_create_users") else "No",
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
oidc_table.update()
|
||||
|
||||
# --- Page content ---
|
||||
with ui.column().classes("w-full p-4"):
|
||||
ui.label("Settings").classes("text-h5 q-mb-md")
|
||||
|
||||
with ui.tabs().classes("w-full") as tabs:
|
||||
defaults_tab = ui.tab("Client Defaults")
|
||||
security_tab = ui.tab("Security")
|
||||
auth_tab = ui.tab("Authentication")
|
||||
|
||||
with ui.tab_panels(tabs, value=defaults_tab).classes("w-full"):
|
||||
|
||||
# === Client Defaults ===
|
||||
with ui.tab_panel(defaults_tab):
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("Default Client Configuration").classes("text-subtitle1 text-bold")
|
||||
ui.label("These defaults apply to new devices unless overridden per-device.").classes("text-caption text-grey-7")
|
||||
ui.separator()
|
||||
|
||||
defaults_endpoint = ui.input(
|
||||
"Endpoint", value=config.default_client_endpoint or "",
|
||||
placeholder="vpn.example.com",
|
||||
).props("outlined dense").classes("w-full")
|
||||
ui.label("IPv4/IPv6 address or FQDN clients connect to").classes("text-caption text-grey-7")
|
||||
|
||||
defaults_dns = ui.input(
|
||||
"DNS Servers", value=", ".join(config.default_client_dns),
|
||||
placeholder="1.1.1.1, 1.0.0.1",
|
||||
).props("outlined dense").classes("w-full q-mt-sm")
|
||||
ui.label("Comma-separated. Leave blank to omit.").classes("text-caption text-grey-7")
|
||||
|
||||
defaults_allowed_ips = ui.input(
|
||||
"Allowed IPs", value=", ".join(config.default_client_allowed_ips),
|
||||
placeholder="0.0.0.0/0, ::/0",
|
||||
).props("outlined dense").classes("w-full q-mt-sm")
|
||||
ui.label("CIDR ranges for split or full tunnel.").classes("text-caption text-grey-7")
|
||||
|
||||
with ui.row().classes("w-full gap-4 q-mt-sm"):
|
||||
defaults_mtu = ui.input(
|
||||
"MTU", value=str(config.default_client_mtu),
|
||||
placeholder="1280",
|
||||
).props("outlined dense").classes("w-48")
|
||||
|
||||
defaults_keepalive = ui.input(
|
||||
"Persistent Keepalive", value=str(config.default_client_persistent_keepalive),
|
||||
placeholder="25",
|
||||
).props("outlined dense").classes("w-48")
|
||||
|
||||
ui.button("Save Defaults", on_click=save_defaults).props("color=primary").classes("q-mt-md")
|
||||
|
||||
# === Security ===
|
||||
with ui.tab_panel(security_tab):
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("Authentication & Access").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
security_vpn_duration = ui.select(
|
||||
VPN_SESSION_OPTIONS,
|
||||
value=config.vpn_session_duration,
|
||||
label="VPN Session Duration",
|
||||
).props("outlined dense").classes("w-full")
|
||||
ui.label("How often users must re-authenticate to maintain VPN access.").classes("text-caption text-grey-7")
|
||||
|
||||
ui.separator().classes("q-my-md")
|
||||
|
||||
security_local_auth = ui.switch("Local Authentication (email/password)", value=config.local_auth_enabled)
|
||||
security_unpriv_mgmt = ui.switch("Allow Unprivileged Device Management", value=config.allow_unprivileged_device_management)
|
||||
security_unpriv_config = ui.switch("Allow Unprivileged Device Configuration", value=config.allow_unprivileged_device_configuration)
|
||||
|
||||
ui.separator().classes("q-my-md")
|
||||
ui.label("SSO Behavior").classes("text-subtitle2")
|
||||
security_disable_vpn_oidc = ui.switch("Auto-disable VPN on OIDC refresh error", value=config.disable_vpn_on_oidc_error)
|
||||
|
||||
ui.button("Save Security Settings", on_click=save_security).props("color=primary").classes("q-mt-md")
|
||||
|
||||
# === Authentication (OIDC/SAML) ===
|
||||
with ui.tab_panel(auth_tab):
|
||||
with ui.card().classes("w-full"):
|
||||
ui.label("OpenID Connect Providers").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
oidc_columns = [
|
||||
{"name": "id", "label": "Config ID", "field": "id", "align": "left"},
|
||||
{"name": "label", "label": "Label", "field": "label", "align": "left"},
|
||||
{"name": "client_id", "label": "Client ID", "field": "client_id", "align": "left"},
|
||||
{"name": "discovery", "label": "Discovery URI", "field": "discovery", "align": "left"},
|
||||
{"name": "auto_create", "label": "Auto-create", "field": "auto_create", "align": "center"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
oidc_table = ui.table(columns=oidc_columns, rows=[], row_key="id").classes("w-full")
|
||||
oidc_table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
oidc_table.on("delete", lambda e: delete_oidc_provider(e.args))
|
||||
|
||||
ui.button("Add OIDC Provider", icon="add", on_click=lambda: oidc_dialog.open()).props("outline").classes("q-mt-sm")
|
||||
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("SAML Identity Providers").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
saml_columns = [
|
||||
{"name": "id", "label": "Config ID", "field": "id", "align": "left"},
|
||||
{"name": "label", "label": "Label", "field": "label", "align": "left"},
|
||||
{"name": "metadata", "label": "Metadata", "field": "metadata", "align": "left"},
|
||||
{"name": "auto_create", "label": "Auto-create", "field": "auto_create", "align": "center"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
saml_table = ui.table(columns=saml_columns, rows=[], row_key="id").classes("w-full")
|
||||
saml_table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
saml_table.on("delete", lambda e: delete_saml_provider(e.args))
|
||||
|
||||
ui.button("Add SAML Provider", icon="add", on_click=lambda: saml_dialog.open()).props("outline").classes("q-mt-sm")
|
||||
|
||||
# --- SAML provider management ---
|
||||
async def save_saml_provider():
|
||||
provider = {
|
||||
"id": saml_id.value.strip(),
|
||||
"label": saml_label.value.strip(),
|
||||
"metadata": saml_metadata_input.value.strip(),
|
||||
"base_url": f"{get_settings().external_url}/auth/saml",
|
||||
"sign_requests": saml_sign_requests.value,
|
||||
"sign_metadata": saml_sign_metadata.value,
|
||||
"signed_assertion_in_resp": saml_signed_assertion.value,
|
||||
"signed_envelopes_in_resp": saml_signed_envelopes.value,
|
||||
"auto_create_users": saml_auto_create.value,
|
||||
}
|
||||
if not all([provider["id"], provider["label"], provider["metadata"]]):
|
||||
ui.notify("Config ID, Label, and Metadata are required", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
providers = list(c.saml_identity_providers or [])
|
||||
providers = [p for p in providers if p.get("id") != provider["id"]]
|
||||
providers.append(provider)
|
||||
c.saml_identity_providers = providers
|
||||
c.updated_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
logger.info("SAML provider saved: {}", provider["id"])
|
||||
ui.notify(f"SAML provider '{provider['label']}' saved", type="positive")
|
||||
saml_dialog.close()
|
||||
await refresh_saml_table()
|
||||
|
||||
async def delete_saml_provider(provider_id: str):
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
c.saml_identity_providers = [p for p in (c.saml_identity_providers or []) if p.get("id") != provider_id]
|
||||
c.updated_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
logger.info("SAML provider deleted: {}", provider_id)
|
||||
ui.notify("SAML provider deleted")
|
||||
await refresh_saml_table()
|
||||
|
||||
async def refresh_saml_table():
|
||||
async with async_session() as session:
|
||||
c = await session.get(Configuration, config.id)
|
||||
providers = c.saml_identity_providers or []
|
||||
saml_table.rows = [
|
||||
{
|
||||
"id": p.get("id", ""),
|
||||
"label": p.get("label", ""),
|
||||
"metadata": (p.get("metadata", ""))[:40] + "..." if len(p.get("metadata", "")) > 40 else p.get("metadata", ""),
|
||||
"auto_create": "Yes" if p.get("auto_create_users") else "No",
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
saml_table.update()
|
||||
|
||||
# --- OIDC provider dialog ---
|
||||
with ui.dialog() as oidc_dialog:
|
||||
with ui.card().classes("w-[500px]"):
|
||||
ui.label("OIDC Provider").classes("text-h6")
|
||||
oidc_id = ui.input("Config ID", placeholder="google").props("outlined dense").classes("w-full")
|
||||
oidc_label = ui.input("Label", placeholder="Sign in with Google").props("outlined dense").classes("w-full")
|
||||
oidc_scope = ui.input("Scope", value="openid email profile").props("outlined dense").classes("w-full")
|
||||
oidc_client_id = ui.input("Client ID").props("outlined dense").classes("w-full")
|
||||
oidc_client_secret = ui.input("Client Secret", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
|
||||
oidc_discovery = ui.input("Discovery Document URI", placeholder="https://accounts.google.com/.well-known/openid-configuration").props("outlined dense").classes("w-full")
|
||||
oidc_auto_create = ui.switch("Auto-create users", value=False)
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=oidc_dialog.close).props("flat")
|
||||
ui.button("Save", on_click=save_oidc_provider).props("color=primary")
|
||||
|
||||
# --- SAML provider dialog ---
|
||||
with ui.dialog() as saml_dialog:
|
||||
with ui.card().classes("w-[500px]"):
|
||||
ui.label("SAML Identity Provider").classes("text-h6")
|
||||
saml_id = ui.input("Config ID", placeholder="okta-saml").props("outlined dense").classes("w-full")
|
||||
saml_label = ui.input("Label", placeholder="Sign in with Okta").props("outlined dense").classes("w-full")
|
||||
saml_metadata_input = ui.textarea("IdP Metadata (XML)").props("outlined").classes("w-full").style("min-height: 120px")
|
||||
ui.label("Paste the full XML metadata from your identity provider.").classes("text-caption text-grey-7")
|
||||
|
||||
ui.separator().classes("q-my-sm")
|
||||
ui.label("Security Options").classes("text-subtitle2")
|
||||
saml_sign_requests = ui.switch("Sign authentication requests", value=True)
|
||||
saml_sign_metadata = ui.switch("Sign SP metadata", value=True)
|
||||
saml_signed_assertion = ui.switch("Require signed assertions in response", value=True)
|
||||
saml_signed_envelopes = ui.switch("Require signed envelopes in response", value=True)
|
||||
|
||||
ui.separator().classes("q-my-sm")
|
||||
saml_auto_create = ui.switch("Auto-create users", value=False)
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=saml_dialog.close).props("flat")
|
||||
ui.button("Save", on_click=save_saml_provider).props("color=primary")
|
||||
|
||||
await refresh_oidc_table()
|
||||
await refresh_saml_table()
|
||||
236
wiregui/pages/admin/users.py
Normal file
236
wiregui/pages/admin/users.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""Admin user management page."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import func, select
|
||||
|
||||
from wiregui.auth.passwords import hash_password
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.models.user import User
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.services.events import on_device_deleted
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
def _guard():
|
||||
if not app.storage.user.get("authenticated") or app.storage.user.get("role") != "admin":
|
||||
ui.navigate.to("/login")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@ui.page("/admin/users")
|
||||
async def users_page():
|
||||
if not _guard():
|
||||
return
|
||||
|
||||
layout()
|
||||
|
||||
async def load_users() -> list[dict]:
|
||||
async with async_session() as session:
|
||||
# Get users with device counts via subquery
|
||||
device_count_sq = (
|
||||
select(Device.user_id, func.count().label("device_count"))
|
||||
.group_by(Device.user_id)
|
||||
.subquery()
|
||||
)
|
||||
result = await session.execute(
|
||||
select(User).order_by(User.email)
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
# Get device counts separately
|
||||
counts_result = await session.execute(
|
||||
select(Device.user_id, func.count().label("cnt")).group_by(Device.user_id)
|
||||
)
|
||||
counts = {str(row[0]): row[1] for row in counts_result.all()}
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(u.id),
|
||||
"email": u.email,
|
||||
"role": u.role,
|
||||
"devices": counts.get(str(u.id), 0),
|
||||
"last_signed_in": str(u.last_signed_in_at or "-"),
|
||||
"method": u.last_signed_in_method or "-",
|
||||
"status": "Disabled" if u.disabled_at else "Active",
|
||||
"created": str(u.inserted_at)[:19],
|
||||
}
|
||||
for u in users
|
||||
]
|
||||
|
||||
async def refresh_table():
|
||||
table.rows = await load_users()
|
||||
table.update()
|
||||
|
||||
# --- Create user ---
|
||||
async def create_user():
|
||||
email = create_email.value.strip()
|
||||
pwd = create_password.value
|
||||
role = create_role.value
|
||||
|
||||
if not email or not pwd:
|
||||
ui.notify("Email and password are required", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
existing = (await session.execute(select(User).where(User.email == email))).scalar_one_or_none()
|
||||
if existing:
|
||||
ui.notify(f"User {email} already exists", type="negative")
|
||||
return
|
||||
|
||||
user = User(email=email, password_hash=hash_password(pwd), role=role)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Admin created user: {} ({})", email, role)
|
||||
ui.notify(f"User {email} created")
|
||||
create_dialog.close()
|
||||
create_email.value = ""
|
||||
create_password.value = ""
|
||||
create_role.value = "unprivileged"
|
||||
await refresh_table()
|
||||
|
||||
# --- Edit user ---
|
||||
edit_user_id = {"value": None}
|
||||
|
||||
async def open_edit(user_id: str):
|
||||
async with async_session() as session:
|
||||
user = await session.get(User, UUID(user_id))
|
||||
if not user:
|
||||
return
|
||||
edit_user_id["value"] = user_id
|
||||
edit_email.value = user.email
|
||||
edit_role.value = user.role
|
||||
edit_password.value = ""
|
||||
edit_disabled.value = user.disabled_at is not None
|
||||
edit_dialog.open()
|
||||
|
||||
async def save_edit():
|
||||
uid = edit_user_id["value"]
|
||||
if not uid:
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
user = await session.get(User, UUID(uid))
|
||||
if not user:
|
||||
return
|
||||
|
||||
user.email = edit_email.value.strip()
|
||||
user.role = edit_role.value
|
||||
|
||||
if edit_password.value:
|
||||
user.password_hash = hash_password(edit_password.value)
|
||||
|
||||
if edit_disabled.value and not user.disabled_at:
|
||||
user.disabled_at = utcnow()
|
||||
elif not edit_disabled.value and user.disabled_at:
|
||||
user.disabled_at = None
|
||||
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Admin updated user: {}", edit_email.value)
|
||||
ui.notify("User updated")
|
||||
edit_dialog.close()
|
||||
await refresh_table()
|
||||
|
||||
# --- Delete user ---
|
||||
async def delete_user(user_id: str):
|
||||
current_user_id = app.storage.user.get("user_id")
|
||||
if user_id == current_user_id:
|
||||
ui.notify("Cannot delete your own account", type="negative")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
user = await session.get(User, UUID(user_id))
|
||||
if not user:
|
||||
return
|
||||
|
||||
# Delete user's devices (and fire WG events)
|
||||
devices_result = await session.execute(
|
||||
select(Device).where(Device.user_id == user.id)
|
||||
)
|
||||
for device in devices_result.scalars().all():
|
||||
await session.delete(device)
|
||||
await on_device_deleted(device)
|
||||
|
||||
# Delete user's rules
|
||||
await session.execute(
|
||||
select(Rule).where(Rule.user_id == user.id)
|
||||
)
|
||||
rules_result = await session.execute(select(Rule).where(Rule.user_id == user.id))
|
||||
for rule in rules_result.scalars().all():
|
||||
await session.delete(rule)
|
||||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Admin deleted user: {}", user.email)
|
||||
ui.notify(f"User {user.email} deleted")
|
||||
await refresh_table()
|
||||
|
||||
def on_row_click(e):
|
||||
open_edit(e.args["id"])
|
||||
|
||||
# --- Page content ---
|
||||
with ui.column().classes("w-full p-4"):
|
||||
with ui.row().classes("w-full items-center justify-between"):
|
||||
ui.label("Users").classes("text-h5")
|
||||
ui.button("Add User", icon="person_add", on_click=lambda: create_dialog.open()).props("color=primary")
|
||||
|
||||
columns = [
|
||||
{"name": "email", "label": "Email", "field": "email", "align": "left", "sortable": True},
|
||||
{"name": "role", "label": "Role", "field": "role", "align": "left", "sortable": True},
|
||||
{"name": "devices", "label": "Devices", "field": "devices", "align": "center"},
|
||||
{"name": "status", "label": "Status", "field": "status", "align": "left"},
|
||||
{"name": "last_signed_in", "label": "Last Sign-in", "field": "last_signed_in", "align": "left"},
|
||||
{"name": "method", "label": "Method", "field": "method", "align": "left"},
|
||||
{"name": "created", "label": "Created", "field": "created", "align": "left"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
|
||||
table.on("rowClick", on_row_click)
|
||||
table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="edit" color="primary"
|
||||
@click.stop="() => $parent.$emit('edit', props.row.id)" />
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
table.on("edit", lambda e: open_edit(e.args))
|
||||
table.on("delete", lambda e: delete_user(e.args))
|
||||
|
||||
# --- Create dialog ---
|
||||
with ui.dialog() as create_dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label("New User").classes("text-h6")
|
||||
create_email = ui.input("Email").props("outlined dense").classes("w-full")
|
||||
create_password = ui.input("Password", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
|
||||
create_role = ui.select(["unprivileged", "admin"], value="unprivileged", label="Role").props("outlined dense").classes("w-full")
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=create_dialog.close).props("flat")
|
||||
ui.button("Create", on_click=create_user).props("color=primary")
|
||||
|
||||
# --- Edit dialog ---
|
||||
with ui.dialog() as edit_dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label("Edit User").classes("text-h6")
|
||||
edit_email = ui.input("Email").props("outlined dense").classes("w-full")
|
||||
edit_role = ui.select(["unprivileged", "admin"], value="unprivileged", label="Role").props("outlined dense").classes("w-full")
|
||||
edit_password = ui.input("New Password (leave blank to keep)", password=True, password_toggle_button=True).props("outlined dense").classes("w-full")
|
||||
edit_disabled = ui.switch("Disabled")
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=edit_dialog.close).props("flat")
|
||||
ui.button("Save", on_click=save_edit).props("color=primary")
|
||||
|
||||
await refresh_table()
|
||||
91
wiregui/pages/auth_magic.py
Normal file
91
wiregui/pages/auth_magic.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
"""Magic link authentication — request and verify signed JWT email links."""
|
||||
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.jwt import create_access_token, decode_access_token
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.user import User
|
||||
from wiregui.services.email import send_magic_link
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
@ui.page("/auth/magic-link")
|
||||
async def magic_link_request_page():
|
||||
"""Page to request a magic link email."""
|
||||
if app.storage.user.get("authenticated"):
|
||||
return ui.navigate.to("/")
|
||||
|
||||
async def request_link():
|
||||
email_val = email_input.value.strip()
|
||||
if not email_val:
|
||||
ui.notify("Enter your email", type="negative")
|
||||
return
|
||||
|
||||
# Always show success to avoid user enumeration
|
||||
ui.notify("If an account exists, a sign-in link has been sent.", type="positive")
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(User).where(User.email == email_val))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user and user.disabled_at is None:
|
||||
settings = get_settings()
|
||||
token = create_access_token(
|
||||
user_id=str(user.id),
|
||||
role=user.role,
|
||||
expires_delta=timedelta(minutes=15),
|
||||
)
|
||||
link = f"{settings.external_url}/auth/magic/{user.id}/{token}"
|
||||
await send_magic_link(email_val, link)
|
||||
logger.info("Magic link sent to {}", email_val)
|
||||
|
||||
with ui.column().classes("absolute-center items-center"):
|
||||
ui.label("WireGUI").classes("text-h4 text-bold")
|
||||
ui.label("Sign in with magic link").classes("text-subtitle1 q-mb-md")
|
||||
|
||||
with ui.card().classes("w-80"):
|
||||
email_input = ui.input("Email").props("outlined dense").classes("w-full")
|
||||
ui.button("Send Magic Link", on_click=request_link).classes("w-full q-mt-sm")
|
||||
email_input.on("keydown.enter", request_link)
|
||||
|
||||
ui.button("Back to login", on_click=lambda: ui.navigate.to("/login")).props("flat").classes("q-mt-md")
|
||||
|
||||
|
||||
@ui.page("/auth/magic/{user_id}/{token}")
|
||||
async def magic_link_verify_page(user_id: str, token: str):
|
||||
"""Verify a magic link token and sign the user in."""
|
||||
payload = decode_access_token(token)
|
||||
if not payload or payload.get("sub") != user_id:
|
||||
with ui.column().classes("absolute-center items-center"):
|
||||
ui.label("Invalid or expired link").classes("text-h5 text-negative")
|
||||
ui.button("Back to login", on_click=lambda: ui.navigate.to("/login")).props("flat")
|
||||
return
|
||||
|
||||
async with async_session() as session:
|
||||
user = await session.get(User, UUID(user_id))
|
||||
if not user or user.disabled_at is not None:
|
||||
with ui.column().classes("absolute-center items-center"):
|
||||
ui.label("Account not found or disabled").classes("text-h5 text-negative")
|
||||
ui.button("Back to login", on_click=lambda: ui.navigate.to("/login")).props("flat")
|
||||
return
|
||||
|
||||
user.last_signed_in_at = utcnow()
|
||||
user.last_signed_in_method = "magic_link"
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
logger.info("Magic link login: {}", user.email)
|
||||
|
||||
app.storage.user.update(
|
||||
authenticated=True,
|
||||
user_id=str(user.id),
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
)
|
||||
ui.navigate.to("/")
|
||||
120
wiregui/pages/auth_oidc.py
Normal file
120
wiregui/pages/auth_oidc.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""OIDC authentication routes — redirect to provider and handle callback."""
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from wiregui.auth.oidc import get_client, get_provider_config
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@app.get("/auth/oidc/{provider_id}")
|
||||
async def oidc_redirect(provider_id: str, request: Request):
|
||||
"""Redirect user to the OIDC provider's authorization endpoint."""
|
||||
try:
|
||||
client = get_client(provider_id)
|
||||
except ValueError:
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
settings = get_settings()
|
||||
redirect_uri = f"{settings.external_url}/auth/oidc/{provider_id}/callback"
|
||||
return await client.authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@app.get("/auth/oidc/{provider_id}/callback")
|
||||
async def oidc_callback(provider_id: str, request: Request):
|
||||
"""Handle the OIDC provider callback — exchange code for tokens and create session."""
|
||||
try:
|
||||
client = get_client(provider_id)
|
||||
except ValueError:
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
try:
|
||||
token = await client.authorize_access_token(request)
|
||||
except Exception as e:
|
||||
logger.error("OIDC token exchange failed for {}: {}", provider_id, e)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
userinfo = token.get("userinfo")
|
||||
if not userinfo:
|
||||
try:
|
||||
userinfo = await client.userinfo()
|
||||
except Exception as e:
|
||||
logger.error("OIDC userinfo failed for {}: {}", provider_id, e)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
email = userinfo.get("email")
|
||||
if not email:
|
||||
logger.error("OIDC provider {} did not return email", provider_id)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
provider_config = await get_provider_config(provider_id)
|
||||
auto_create = provider_config.get("auto_create_users", False) if provider_config else False
|
||||
|
||||
async with async_session() as session:
|
||||
# Find or create user
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
if not auto_create:
|
||||
logger.warning("OIDC: user {} not found and auto-create disabled for {}", email, provider_id)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
user = User(email=email, role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
logger.info("OIDC: auto-created user {} via {}", email, provider_id)
|
||||
|
||||
if user.disabled_at is not None:
|
||||
logger.warning("OIDC: disabled user {} attempted login via {}", email, provider_id)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
# Update sign-in tracking
|
||||
user.last_signed_in_at = utcnow()
|
||||
user.last_signed_in_method = f"oidc:{provider_id}"
|
||||
session.add(user)
|
||||
|
||||
# Store/update OIDC connection with refresh token
|
||||
refresh_token = token.get("refresh_token")
|
||||
existing_conn = (await session.execute(
|
||||
select(OIDCConnection).where(
|
||||
OIDCConnection.user_id == user.id,
|
||||
OIDCConnection.provider == provider_id,
|
||||
)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
if existing_conn:
|
||||
existing_conn.refresh_token = refresh_token
|
||||
existing_conn.refreshed_at = utcnow()
|
||||
existing_conn.refresh_response = dict(token)
|
||||
session.add(existing_conn)
|
||||
else:
|
||||
conn = OIDCConnection(
|
||||
provider=provider_id,
|
||||
refresh_token=refresh_token,
|
||||
refresh_response=dict(token),
|
||||
refreshed_at=utcnow(),
|
||||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info("OIDC login: {} via {}", email, provider_id)
|
||||
|
||||
# Set NiceGUI session — store in Starlette session since we're in a plain route
|
||||
request.session["authenticated"] = True
|
||||
request.session["user_id"] = str(user.id)
|
||||
request.session["email"] = user.email
|
||||
request.session["role"] = user.role
|
||||
|
||||
return RedirectResponse(url="/")
|
||||
129
wiregui/pages/auth_saml.py
Normal file
129
wiregui/pages/auth_saml.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""SAML authentication routes — SP-initiated SSO redirect and ACS callback."""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, Response
|
||||
from loguru import logger
|
||||
from nicegui import app
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.saml import create_saml_auth, get_login_url, get_metadata, process_response
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.user import User
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
async def _get_saml_provider(provider_id: str) -> dict | None:
|
||||
async with async_session() as session:
|
||||
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
if not config:
|
||||
return None
|
||||
for p in config.saml_identity_providers or []:
|
||||
if p.get("id") == provider_id:
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
def _request_data_from_fastapi(request: Request) -> dict:
|
||||
settings = get_settings()
|
||||
parsed = urlparse(settings.external_url)
|
||||
return {
|
||||
"http_host": parsed.hostname,
|
||||
"script_name": "",
|
||||
"server_port": parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
"get_data": dict(request.query_params),
|
||||
"post_data": {},
|
||||
"https": "on" if parsed.scheme == "https" else "off",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/auth/saml/{provider_id}")
|
||||
async def saml_redirect(provider_id: str, request: Request):
|
||||
"""Redirect user to the SAML IdP."""
|
||||
provider = await _get_saml_provider(provider_id)
|
||||
if not provider:
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
try:
|
||||
req_data = _request_data_from_fastapi(request)
|
||||
auth = create_saml_auth(provider, req_data)
|
||||
login_url = get_login_url(auth)
|
||||
return RedirectResponse(url=login_url)
|
||||
except Exception as e:
|
||||
logger.error("SAML redirect failed for {}: {}", provider_id, e)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
|
||||
@app.post("/auth/saml/{provider_id}/callback")
|
||||
async def saml_callback(provider_id: str, request: Request):
|
||||
"""Handle the SAML ACS callback (POST with SAMLResponse)."""
|
||||
provider = await _get_saml_provider(provider_id)
|
||||
if not provider:
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
req_data = _request_data_from_fastapi(request)
|
||||
req_data["post_data"] = dict(form_data)
|
||||
|
||||
auth = create_saml_auth(provider, req_data)
|
||||
user_data = process_response(auth)
|
||||
|
||||
if not user_data or not user_data.get("email"):
|
||||
logger.warning("SAML callback: no valid user data from {}", provider_id)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
email = user_data["email"]
|
||||
auto_create = provider.get("auto_create_users", False)
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
if not auto_create:
|
||||
logger.warning("SAML: user {} not found, auto-create disabled for {}", email, provider_id)
|
||||
return RedirectResponse(url="/login")
|
||||
user = User(email=email, role="unprivileged")
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
logger.info("SAML: auto-created user {} via {}", email, provider_id)
|
||||
|
||||
if user.disabled_at is not None:
|
||||
logger.warning("SAML: disabled user {} attempted login via {}", email, provider_id)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
user.last_signed_in_at = utcnow()
|
||||
user.last_signed_in_method = f"saml:{provider_id}"
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["user_id"] = str(user.id)
|
||||
request.session["email"] = user.email
|
||||
request.session["role"] = user.role
|
||||
|
||||
logger.info("SAML login: {} via {}", email, provider_id)
|
||||
return RedirectResponse(url="/", status_code=303)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("SAML callback failed for {}: {}", provider_id, e)
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
|
||||
@app.get("/auth/saml/{provider_id}/metadata")
|
||||
async def saml_metadata(provider_id: str):
|
||||
"""Return SP metadata XML for the SAML provider."""
|
||||
provider = await _get_saml_provider(provider_id)
|
||||
if not provider:
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
metadata_xml = get_metadata(provider)
|
||||
return Response(content=metadata_xml, media_type="application/xml")
|
||||
except Exception as e:
|
||||
logger.error("SAML metadata generation failed for {}: {}", provider_id, e)
|
||||
return Response(status_code=500)
|
||||
463
wiregui/pages/devices.py
Normal file
463
wiregui/pages/devices.py
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
"""User-facing device management pages."""
|
||||
|
||||
import io
|
||||
from uuid import UUID
|
||||
|
||||
import qrcode
|
||||
import qrcode.image.svg
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.pages.layout import layout
|
||||
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated
|
||||
from wiregui.utils.crypto import generate_keypair, generate_preshared_key
|
||||
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
|
||||
from wiregui.utils.server_key import get_server_public_key
|
||||
from wiregui.utils.wg_conf import build_client_config
|
||||
|
||||
|
||||
def _format_bytes(b: int | None) -> str:
|
||||
if b is None:
|
||||
return "-"
|
||||
for unit in ("B", "KB", "MB", "GB", "TB"):
|
||||
if b < 1024:
|
||||
return f"{b:.1f} {unit}"
|
||||
b /= 1024
|
||||
return f"{b:.1f} PB"
|
||||
|
||||
|
||||
@ui.page("/devices")
|
||||
async def devices_page():
|
||||
if not app.storage.user.get("authenticated"):
|
||||
return ui.navigate.to("/login")
|
||||
|
||||
layout()
|
||||
user_id = UUID(app.storage.user["user_id"])
|
||||
|
||||
async def load_devices() -> list[Device]:
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(Device).where(Device.user_id == user_id).order_by(Device.inserted_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def refresh_table():
|
||||
devices = await load_devices()
|
||||
table.rows = [
|
||||
{
|
||||
"id": str(d.id),
|
||||
"name": d.name,
|
||||
"description": d.description or "",
|
||||
"ipv4": d.ipv4 or "-",
|
||||
"ipv6": d.ipv6 or "-",
|
||||
"public_key": d.public_key[:16] + "...",
|
||||
"rx": _format_bytes(d.rx_bytes),
|
||||
"tx": _format_bytes(d.tx_bytes),
|
||||
"handshake": str(d.latest_handshake)[:19] if d.latest_handshake else "-",
|
||||
}
|
||||
for d in devices
|
||||
]
|
||||
table.update()
|
||||
|
||||
# --- Create device ---
|
||||
async def create_device():
|
||||
name = create_name.value.strip()
|
||||
if not name:
|
||||
ui.notify("Device name is required", type="negative")
|
||||
return
|
||||
|
||||
try:
|
||||
settings = get_settings()
|
||||
private_key, public_key = generate_keypair()
|
||||
psk = generate_preshared_key()
|
||||
|
||||
async with async_session() as session:
|
||||
ipv4 = await allocate_ipv4(session, settings.wg_ipv4_network)
|
||||
ipv6 = await allocate_ipv6(session, settings.wg_ipv6_network)
|
||||
|
||||
device = Device(
|
||||
name=name,
|
||||
description=create_desc.value.strip() or None,
|
||||
public_key=public_key,
|
||||
preshared_key=psk,
|
||||
ipv4=ipv4,
|
||||
ipv6=ipv6,
|
||||
user_id=user_id,
|
||||
use_default_allowed_ips=create_use_default_ips.value,
|
||||
use_default_dns=create_use_default_dns.value,
|
||||
use_default_endpoint=create_use_default_endpoint.value,
|
||||
use_default_mtu=create_use_default_mtu.value,
|
||||
use_default_persistent_keepalive=create_use_default_keepalive.value,
|
||||
endpoint=(create_endpoint.value.strip() or None
|
||||
if not create_use_default_endpoint.value else None),
|
||||
dns=([s.strip() for s in create_dns.value.split(",") if s.strip()]
|
||||
if not create_use_default_dns.value and create_dns.value else []),
|
||||
mtu=(int(create_mtu.value)
|
||||
if not create_use_default_mtu.value and create_mtu.value else None),
|
||||
persistent_keepalive=(int(create_keepalive.value)
|
||||
if not create_use_default_keepalive.value and create_keepalive.value else None),
|
||||
allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()]
|
||||
if not create_use_default_ips.value and create_allowed_ips.value else []),
|
||||
)
|
||||
session.add(device)
|
||||
await session.commit()
|
||||
await session.refresh(device)
|
||||
|
||||
logger.info("Device created: {} ({})", device.name, device.ipv4)
|
||||
await on_device_created(device)
|
||||
|
||||
server_pubkey = await get_server_public_key()
|
||||
config_text = build_client_config(device, private_key, server_pubkey)
|
||||
_show_config_dialog(device.name, config_text)
|
||||
|
||||
create_dialog.close()
|
||||
_reset_create_form()
|
||||
await refresh_table()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create device: {}", e)
|
||||
ui.notify(f"Error: {e}", type="negative")
|
||||
|
||||
def _reset_create_form():
|
||||
create_name.value = ""
|
||||
create_desc.value = ""
|
||||
create_use_default_ips.value = True
|
||||
create_use_default_dns.value = True
|
||||
create_use_default_endpoint.value = True
|
||||
create_use_default_mtu.value = True
|
||||
create_use_default_keepalive.value = True
|
||||
create_endpoint.value = ""
|
||||
create_dns.value = ""
|
||||
create_mtu.value = ""
|
||||
create_keepalive.value = ""
|
||||
create_allowed_ips.value = ""
|
||||
|
||||
# --- Delete device ---
|
||||
async def delete_device(device_id: str):
|
||||
async with async_session() as session:
|
||||
device = await session.get(Device, UUID(device_id))
|
||||
if device and device.user_id == user_id:
|
||||
await session.delete(device)
|
||||
await session.commit()
|
||||
logger.info("Device deleted: {}", device.name)
|
||||
await on_device_deleted(device)
|
||||
ui.notify(f"Deleted {device.name}")
|
||||
await refresh_table()
|
||||
|
||||
def on_row_click(e):
|
||||
ui.navigate.to(f"/devices/{e.args['id']}")
|
||||
|
||||
# --- Page content ---
|
||||
with ui.column().classes("w-full p-4"):
|
||||
with ui.row().classes("w-full items-center justify-between"):
|
||||
ui.label("My Devices").classes("text-h5")
|
||||
ui.button("Add Device", icon="add", on_click=lambda: create_dialog.open()).props("color=primary")
|
||||
|
||||
columns = [
|
||||
{"name": "name", "label": "Name", "field": "name", "align": "left", "sortable": True},
|
||||
{"name": "ipv4", "label": "IPv4", "field": "ipv4", "align": "left"},
|
||||
{"name": "ipv6", "label": "IPv6", "field": "ipv6", "align": "left"},
|
||||
{"name": "public_key", "label": "Public Key", "field": "public_key", "align": "left"},
|
||||
{"name": "rx", "label": "RX", "field": "rx", "align": "right"},
|
||||
{"name": "tx", "label": "TX", "field": "tx", "align": "right"},
|
||||
{"name": "handshake", "label": "Last Handshake", "field": "handshake", "align": "left"},
|
||||
{"name": "actions", "label": "", "field": "id", "align": "center"},
|
||||
]
|
||||
table = ui.table(columns=columns, rows=[], row_key="id").classes("w-full")
|
||||
table.on("rowClick", on_row_click)
|
||||
table.add_slot(
|
||||
"body-cell-actions",
|
||||
'''
|
||||
<q-td :props="props">
|
||||
<q-btn flat dense icon="delete" color="negative"
|
||||
@click.stop="() => $parent.$emit('delete', props.row.id)" />
|
||||
</q-td>
|
||||
''',
|
||||
)
|
||||
table.on("delete", lambda e: delete_device(e.args))
|
||||
|
||||
# --- Create device dialog (full form) ---
|
||||
with ui.dialog() as create_dialog:
|
||||
with ui.card().classes("w-[600px]"):
|
||||
ui.label("New Device").classes("text-h6")
|
||||
|
||||
create_name = ui.input("Device Name").props("outlined dense").classes("w-full")
|
||||
create_desc = ui.input("Description (optional)").props("outlined dense").classes("w-full")
|
||||
|
||||
ui.separator().classes("q-my-sm")
|
||||
ui.label("Configuration Overrides").classes("text-subtitle2")
|
||||
ui.label("Toggle off to set custom values instead of server defaults.").classes("text-caption text-grey-7")
|
||||
|
||||
with ui.grid(columns=2).classes("w-full gap-2"):
|
||||
create_use_default_ips = ui.switch("Use default Allowed IPs", value=True)
|
||||
create_allowed_ips = ui.input("Allowed IPs", placeholder="0.0.0.0/0, ::/0").props(
|
||||
"outlined dense"
|
||||
).classes("w-full").bind_enabled_from(create_use_default_ips, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_dns = ui.switch("Use default DNS", value=True)
|
||||
create_dns = ui.input("DNS Servers", placeholder="1.1.1.1, 1.0.0.1").props(
|
||||
"outlined dense"
|
||||
).classes("w-full").bind_enabled_from(create_use_default_dns, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_endpoint = ui.switch("Use default Endpoint", value=True)
|
||||
create_endpoint = ui.input("Endpoint", placeholder="vpn.example.com").props(
|
||||
"outlined dense"
|
||||
).classes("w-full").bind_enabled_from(create_use_default_endpoint, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_mtu = ui.switch("Use default MTU", value=True)
|
||||
create_mtu = ui.input("MTU", placeholder="1280").props(
|
||||
"outlined dense"
|
||||
).classes("w-full").bind_enabled_from(create_use_default_mtu, "value", backward=lambda v: not v)
|
||||
|
||||
create_use_default_keepalive = ui.switch("Use default Keepalive", value=True)
|
||||
create_keepalive = ui.input("Persistent Keepalive", placeholder="25").props(
|
||||
"outlined dense"
|
||||
).classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v)
|
||||
|
||||
with ui.row().classes("w-full justify-end q-mt-md"):
|
||||
ui.button("Cancel", on_click=create_dialog.close).props("flat")
|
||||
ui.button("Create", on_click=create_device).props("color=primary")
|
||||
|
||||
await refresh_table()
|
||||
|
||||
# Auto-refresh stats every 30 seconds
|
||||
ui.timer(30, refresh_table)
|
||||
|
||||
|
||||
@ui.page("/devices/{device_id}")
|
||||
async def device_detail_page(device_id: str):
|
||||
if not app.storage.user.get("authenticated"):
|
||||
return ui.navigate.to("/login")
|
||||
|
||||
layout()
|
||||
user_id = UUID(app.storage.user["user_id"])
|
||||
|
||||
async with async_session() as sess:
|
||||
device = await sess.get(Device, UUID(device_id))
|
||||
if not device or device.user_id != user_id:
|
||||
ui.label("Device not found").classes("text-h5 text-negative p-4")
|
||||
return
|
||||
|
||||
# --- Edit handlers ---
|
||||
async def save_edit():
|
||||
async with async_session() as session:
|
||||
d = await session.get(Device, UUID(device_id))
|
||||
if not d:
|
||||
return
|
||||
|
||||
d.name = edit_name.value.strip()
|
||||
d.description = edit_desc.value.strip() or None
|
||||
d.use_default_allowed_ips = edit_use_default_ips.value
|
||||
d.use_default_dns = edit_use_default_dns.value
|
||||
d.use_default_endpoint = edit_use_default_endpoint.value
|
||||
d.use_default_mtu = edit_use_default_mtu.value
|
||||
d.use_default_persistent_keepalive = edit_use_default_keepalive.value
|
||||
|
||||
if not d.use_default_endpoint:
|
||||
d.endpoint = edit_endpoint.value.strip() or None
|
||||
if not d.use_default_dns:
|
||||
d.dns = [s.strip() for s in edit_dns.value.split(",") if s.strip()]
|
||||
if not d.use_default_mtu:
|
||||
d.mtu = int(edit_mtu.value) if edit_mtu.value else None
|
||||
if not d.use_default_persistent_keepalive:
|
||||
d.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None
|
||||
if not d.use_default_allowed_ips:
|
||||
d.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()]
|
||||
|
||||
session.add(d)
|
||||
await session.commit()
|
||||
await session.refresh(d)
|
||||
await on_device_updated(d)
|
||||
|
||||
logger.info("Device updated: {}", edit_name.value)
|
||||
ui.notify("Device updated", type="positive")
|
||||
ui.navigate.to(f"/devices/{device_id}")
|
||||
|
||||
async def delete_and_redirect():
|
||||
async with async_session() as session:
|
||||
d = await session.get(Device, UUID(device_id))
|
||||
if d:
|
||||
await session.delete(d)
|
||||
await session.commit()
|
||||
logger.info("Device deleted: {}", d.name)
|
||||
await on_device_deleted(d)
|
||||
ui.navigate.to("/devices")
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# --- Page content ---
|
||||
with ui.column().classes("w-full p-4"):
|
||||
with ui.row().classes("items-center gap-2"):
|
||||
ui.button(icon="arrow_back", on_click=lambda: ui.navigate.to("/devices")).props("flat")
|
||||
ui.label(device.name).classes("text-h5")
|
||||
if device.description:
|
||||
ui.label(f"— {device.description}").classes("text-subtitle1 text-grey-7")
|
||||
|
||||
# Device info card
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Device Details").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
|
||||
ui.label("Public Key:").classes("text-bold")
|
||||
ui.label(device.public_key).classes("font-mono text-sm")
|
||||
|
||||
ui.label("IPv4:").classes("text-bold")
|
||||
ui.label(device.ipv4 or "-")
|
||||
|
||||
ui.label("IPv6:").classes("text-bold")
|
||||
ui.label(device.ipv6 or "-")
|
||||
|
||||
ui.label("Created:").classes("text-bold")
|
||||
ui.label(str(device.inserted_at)[:19])
|
||||
|
||||
# Traffic stats (live-updating)
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Traffic Stats").classes("text-subtitle1 text-bold")
|
||||
ui.label("Auto-refreshes every 30s").classes("text-caption text-grey-7")
|
||||
ui.separator()
|
||||
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
|
||||
ui.label("RX:").classes("text-bold")
|
||||
stat_rx = ui.label(_format_bytes(device.rx_bytes))
|
||||
|
||||
ui.label("TX:").classes("text-bold")
|
||||
stat_tx = ui.label(_format_bytes(device.tx_bytes))
|
||||
|
||||
ui.label("Last Handshake:").classes("text-bold")
|
||||
stat_handshake = ui.label(str(device.latest_handshake)[:19] if device.latest_handshake else "-")
|
||||
|
||||
ui.label("Remote IP:").classes("text-bold")
|
||||
stat_remote = ui.label(device.remote_ip or "-")
|
||||
|
||||
async def refresh_stats():
|
||||
async with async_session() as session:
|
||||
d = await session.get(Device, UUID(device_id))
|
||||
if not d:
|
||||
return
|
||||
stat_rx.text = _format_bytes(d.rx_bytes)
|
||||
stat_tx.text = _format_bytes(d.tx_bytes)
|
||||
stat_handshake.text = str(d.latest_handshake)[:19] if d.latest_handshake else "-"
|
||||
stat_remote.text = d.remote_ip or "-"
|
||||
|
||||
ui.timer(30, refresh_stats)
|
||||
|
||||
# Active configuration
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Active Configuration").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
with ui.grid(columns=2).classes("w-full gap-2 q-pa-sm"):
|
||||
_ips = device.allowed_ips if not device.use_default_allowed_ips else settings.wg_allowed_ips
|
||||
ui.label("Allowed IPs:").classes("text-bold")
|
||||
ui.label(str(_ips) if isinstance(_ips, str) else ", ".join(_ips) if _ips else "-")
|
||||
|
||||
_dns = device.dns if not device.use_default_dns else settings.wg_dns
|
||||
ui.label("DNS:").classes("text-bold")
|
||||
ui.label(str(_dns) if isinstance(_dns, str) else ", ".join(_dns) if _dns else "-")
|
||||
|
||||
_ep = device.endpoint if not device.use_default_endpoint else settings.wg_endpoint_host
|
||||
ui.label("Endpoint:").classes("text-bold")
|
||||
ui.label(f"{_ep}:{settings.wg_endpoint_port}" if _ep else "-")
|
||||
|
||||
_mtu = device.mtu if not device.use_default_mtu else settings.wg_mtu
|
||||
ui.label("MTU:").classes("text-bold")
|
||||
ui.label(str(_mtu) if _mtu else "-")
|
||||
|
||||
_ka = device.persistent_keepalive if not device.use_default_persistent_keepalive else settings.wg_persistent_keepalive
|
||||
ui.label("Persistent Keepalive:").classes("text-bold")
|
||||
ui.label(str(_ka) if _ka else "-")
|
||||
|
||||
# Edit form
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Edit Device").classes("text-subtitle1 text-bold")
|
||||
ui.separator()
|
||||
|
||||
edit_name = ui.input("Device Name", value=device.name).props("outlined dense").classes("w-full")
|
||||
edit_desc = ui.input("Description", value=device.description or "").props("outlined dense").classes("w-full")
|
||||
|
||||
ui.separator().classes("q-my-sm")
|
||||
ui.label("Configuration Overrides").classes("text-subtitle2")
|
||||
|
||||
with ui.grid(columns=2).classes("w-full gap-2"):
|
||||
edit_use_default_ips = ui.switch("Use default Allowed IPs", value=device.use_default_allowed_ips)
|
||||
edit_allowed_ips = ui.input(
|
||||
"Allowed IPs", value=", ".join(device.allowed_ips) if device.allowed_ips else "",
|
||||
).props("outlined dense").classes("w-full").bind_enabled_from(
|
||||
edit_use_default_ips, "value", backward=lambda v: not v
|
||||
)
|
||||
|
||||
edit_use_default_dns = ui.switch("Use default DNS", value=device.use_default_dns)
|
||||
edit_dns = ui.input(
|
||||
"DNS Servers", value=", ".join(device.dns) if device.dns else "",
|
||||
).props("outlined dense").classes("w-full").bind_enabled_from(
|
||||
edit_use_default_dns, "value", backward=lambda v: not v
|
||||
)
|
||||
|
||||
edit_use_default_endpoint = ui.switch("Use default Endpoint", value=device.use_default_endpoint)
|
||||
edit_endpoint = ui.input(
|
||||
"Endpoint", value=device.endpoint or "",
|
||||
).props("outlined dense").classes("w-full").bind_enabled_from(
|
||||
edit_use_default_endpoint, "value", backward=lambda v: not v
|
||||
)
|
||||
|
||||
edit_use_default_mtu = ui.switch("Use default MTU", value=device.use_default_mtu)
|
||||
edit_mtu = ui.input(
|
||||
"MTU", value=str(device.mtu) if device.mtu else "",
|
||||
).props("outlined dense").classes("w-full").bind_enabled_from(
|
||||
edit_use_default_mtu, "value", backward=lambda v: not v
|
||||
)
|
||||
|
||||
edit_use_default_keepalive = ui.switch("Use default Keepalive", value=device.use_default_persistent_keepalive)
|
||||
edit_keepalive = ui.input(
|
||||
"Persistent Keepalive", value=str(device.persistent_keepalive) if device.persistent_keepalive else "",
|
||||
).props("outlined dense").classes("w-full").bind_enabled_from(
|
||||
edit_use_default_keepalive, "value", backward=lambda v: not v
|
||||
)
|
||||
|
||||
ui.button("Save Changes", on_click=save_edit).props("color=primary").classes("q-mt-md")
|
||||
|
||||
# Danger zone
|
||||
with ui.card().classes("w-full q-mt-md"):
|
||||
ui.label("Danger Zone").classes("text-subtitle1 text-bold text-negative")
|
||||
ui.separator()
|
||||
ui.button("Delete Device", icon="delete", on_click=lambda: confirm_dialog.open()).props(
|
||||
"color=negative outline"
|
||||
)
|
||||
|
||||
# Confirm delete dialog
|
||||
with ui.dialog() as confirm_dialog:
|
||||
with ui.card().classes("w-80"):
|
||||
ui.label("Delete Device?").classes("text-h6")
|
||||
ui.label(f"This will permanently remove '{device.name}' and its WireGuard peer.").classes("text-body2")
|
||||
with ui.row().classes("w-full justify-end q-mt-sm"):
|
||||
ui.button("Cancel", on_click=confirm_dialog.close).props("flat")
|
||||
ui.button("Delete", on_click=delete_and_redirect).props("color=negative")
|
||||
|
||||
|
||||
def _show_config_dialog(device_name: str, config_text: str):
|
||||
"""Show a dialog with the WireGuard client configuration and QR code."""
|
||||
with ui.dialog(value=True) as dialog:
|
||||
with ui.card().classes("w-96"):
|
||||
ui.label(f"Config for {device_name}").classes("text-h6")
|
||||
ui.label("Save this — the private key won't be shown again.").classes("text-caption text-negative")
|
||||
|
||||
ui.textarea(value=config_text).props("readonly outlined").classes(
|
||||
"w-full font-mono text-xs q-mt-sm"
|
||||
).style("min-height: 200px")
|
||||
|
||||
try:
|
||||
qr = qrcode.make(config_text, image_factory=qrcode.image.svg.SvgPathImage)
|
||||
buf = io.BytesIO()
|
||||
qr.save(buf)
|
||||
ui.html(buf.getvalue().decode()).classes("w-full q-mt-sm")
|
||||
except Exception:
|
||||
ui.label("QR code generation failed").classes("text-caption text-grey")
|
||||
|
||||
ui.button(
|
||||
"Download .conf",
|
||||
on_click=lambda: ui.download(config_text.encode(), f"{device_name}.conf"),
|
||||
).props("color=primary outline").classes("w-full q-mt-sm")
|
||||
|
||||
ui.button("Close", on_click=dialog.close).props("flat").classes("w-full")
|
||||
9
wiregui/pages/home.py
Normal file
9
wiregui/pages/home.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from nicegui import app, ui
|
||||
|
||||
|
||||
@ui.page("/")
|
||||
def home_page():
|
||||
if not app.storage.user.get("authenticated"):
|
||||
return ui.navigate.to("/login")
|
||||
# Redirect to devices as the main landing page
|
||||
return ui.navigate.to("/devices")
|
||||
48
wiregui/pages/layout.py
Normal file
48
wiregui/pages/layout.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Shared layout — sidebar navigation + header."""
|
||||
|
||||
from nicegui import app, ui
|
||||
|
||||
from wiregui.services import notifications
|
||||
|
||||
|
||||
def layout(title: str = "WireGUI"):
|
||||
"""Render the shared app chrome (header + sidebar). Call at the top of each page."""
|
||||
user_email = app.storage.user.get("email", "")
|
||||
role = app.storage.user.get("role", "")
|
||||
|
||||
def logout():
|
||||
app.storage.user.clear()
|
||||
ui.navigate.to("/login")
|
||||
|
||||
# Header
|
||||
with ui.header().classes("items-center justify-between"):
|
||||
with ui.row().classes("items-center"):
|
||||
ui.button(icon="menu", on_click=lambda: drawer.toggle()).props("flat color=white")
|
||||
ui.label("WireGUI").classes("text-h6")
|
||||
with ui.row().classes("items-center"):
|
||||
if role == "admin":
|
||||
notif_count = notifications.count()
|
||||
with ui.button(
|
||||
icon="notifications",
|
||||
on_click=lambda: ui.navigate.to("/admin/diagnostics"),
|
||||
).props("flat color=white"):
|
||||
if notif_count > 0:
|
||||
ui.badge(str(notif_count), color="red").props("floating")
|
||||
ui.label(f"{user_email}").classes("text-subtitle2")
|
||||
ui.button("Logout", on_click=logout).props("flat color=white")
|
||||
|
||||
# Sidebar
|
||||
with ui.left_drawer(value=True, bordered=True).classes("bg-grey-1") as drawer:
|
||||
ui.label("Navigation").classes("text-subtitle2 q-pa-sm text-grey-7")
|
||||
ui.separator()
|
||||
ui.item("Devices", on_click=lambda: ui.navigate.to("/devices")).classes("cursor-pointer")
|
||||
ui.item("Account", on_click=lambda: ui.navigate.to("/account")).classes("cursor-pointer")
|
||||
|
||||
if role == "admin":
|
||||
ui.separator()
|
||||
ui.label("Admin").classes("text-subtitle2 q-pa-sm text-grey-7")
|
||||
ui.item("Users", on_click=lambda: ui.navigate.to("/admin/users")).classes("cursor-pointer")
|
||||
ui.item("All Devices", on_click=lambda: ui.navigate.to("/admin/devices")).classes("cursor-pointer")
|
||||
ui.item("Rules", on_click=lambda: ui.navigate.to("/admin/rules")).classes("cursor-pointer")
|
||||
ui.item("Settings", on_click=lambda: ui.navigate.to("/admin/settings")).classes("cursor-pointer")
|
||||
ui.item("Diagnostics", on_click=lambda: ui.navigate.to("/admin/diagnostics")).classes("cursor-pointer")
|
||||
82
wiregui/pages/login.py
Normal file
82
wiregui/pages/login.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""Login page — email/password, MFA redirect, OIDC provider buttons."""
|
||||
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.oidc import load_providers
|
||||
from wiregui.auth.session import authenticate_user
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
|
||||
@ui.page("/login")
|
||||
async def login_page():
|
||||
if app.storage.user.get("authenticated"):
|
||||
return ui.navigate.to("/")
|
||||
|
||||
# Load OIDC providers for SSO buttons
|
||||
oidc_providers = await load_providers()
|
||||
|
||||
async def try_login():
|
||||
user = await authenticate_user(email.value, password.value)
|
||||
if user is None:
|
||||
ui.notify("Invalid email or password", type="negative")
|
||||
return
|
||||
|
||||
# Check if user has MFA methods
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user.id)
|
||||
)
|
||||
mfa_methods = result.scalars().all()
|
||||
|
||||
# Update sign-in tracking
|
||||
user_record = await session.get(type(user), user.id)
|
||||
user_record.last_signed_in_at = utcnow()
|
||||
user_record.last_signed_in_method = "local"
|
||||
session.add(user_record)
|
||||
await session.commit()
|
||||
|
||||
if mfa_methods:
|
||||
# Store pending auth and redirect to MFA challenge
|
||||
app.storage.user["pending_mfa"] = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
}
|
||||
ui.navigate.to("/mfa")
|
||||
else:
|
||||
# No MFA — complete login directly
|
||||
app.storage.user.update(
|
||||
authenticated=True,
|
||||
user_id=str(user.id),
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
)
|
||||
ui.navigate.to("/")
|
||||
|
||||
with ui.column().classes("absolute-center items-center"):
|
||||
ui.label("WireGUI").classes("text-h4 text-bold")
|
||||
ui.label("Sign in to your account").classes("text-subtitle1 q-mb-md")
|
||||
|
||||
with ui.card().classes("w-80"):
|
||||
email = ui.input("Email").props("outlined dense").classes("w-full")
|
||||
password = ui.input("Password", password=True, password_toggle_button=True).props(
|
||||
"outlined dense"
|
||||
).classes("w-full")
|
||||
ui.button("Sign in", on_click=try_login).classes("w-full q-mt-sm")
|
||||
|
||||
password.on("keydown.enter", try_login)
|
||||
|
||||
# OIDC provider buttons
|
||||
if oidc_providers:
|
||||
ui.separator().classes("q-my-md")
|
||||
ui.label("Or sign in with").classes("text-caption text-center w-full")
|
||||
for provider in oidc_providers:
|
||||
pid = provider.get("id", "")
|
||||
label = provider.get("label", pid)
|
||||
ui.button(
|
||||
label,
|
||||
on_click=lambda p=pid: ui.navigate.to(f"/auth/oidc/{p}"),
|
||||
).props("outline").classes("w-full q-mt-xs")
|
||||
93
wiregui/pages/mfa_challenge.py
Normal file
93
wiregui/pages/mfa_challenge.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""MFA challenge page — presented after password login when user has MFA enabled."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from nicegui import app, ui
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.auth.mfa import verify_totp_code
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.mfa_method import MFAMethod
|
||||
|
||||
|
||||
@ui.page("/mfa")
|
||||
async def mfa_challenge_page():
|
||||
# Must have passed password auth (pending_mfa set by login page)
|
||||
pending = app.storage.user.get("pending_mfa")
|
||||
if not pending:
|
||||
return ui.navigate.to("/login")
|
||||
|
||||
user_id = UUID(pending["user_id"])
|
||||
|
||||
# Load user's MFA methods
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(MFAMethod).where(MFAMethod.user_id == user_id)
|
||||
)
|
||||
methods = result.scalars().all()
|
||||
|
||||
totp_methods = [m for m in methods if m.type == "totp"]
|
||||
|
||||
if not totp_methods:
|
||||
# No MFA methods — shouldn't be here, complete login
|
||||
_complete_login(pending)
|
||||
return
|
||||
|
||||
async def verify_code():
|
||||
code = code_input.value.strip()
|
||||
if not code:
|
||||
ui.notify("Enter your authentication code", type="negative")
|
||||
return
|
||||
|
||||
for method in totp_methods:
|
||||
secret = method.payload.get("secret")
|
||||
if secret and verify_totp_code(secret, code):
|
||||
# Update last used
|
||||
async with async_session() as session:
|
||||
m = await session.get(MFAMethod, method.id)
|
||||
if m:
|
||||
from wiregui.utils.time import utcnow
|
||||
m.last_used_at = utcnow()
|
||||
session.add(m)
|
||||
await session.commit()
|
||||
|
||||
logger.info("MFA verified for user {}", pending["email"])
|
||||
_complete_login(pending)
|
||||
return
|
||||
|
||||
ui.notify("Invalid code", type="negative")
|
||||
|
||||
with ui.column().classes("absolute-center items-center"):
|
||||
ui.label("Two-Factor Authentication").classes("text-h5")
|
||||
ui.label(f"Enter the code from your authenticator app for {pending['email']}").classes(
|
||||
"text-subtitle1 q-mb-md"
|
||||
)
|
||||
|
||||
with ui.card().classes("w-80"):
|
||||
code_input = ui.input("Authentication Code").props(
|
||||
"outlined dense maxlength=6"
|
||||
).classes("w-full text-center font-mono text-lg")
|
||||
ui.button("Verify", on_click=verify_code).classes("w-full q-mt-sm")
|
||||
|
||||
code_input.on("keydown.enter", verify_code)
|
||||
|
||||
ui.button("Cancel", on_click=lambda: _cancel_mfa()).props("flat").classes("q-mt-md")
|
||||
|
||||
|
||||
def _complete_login(pending: dict):
|
||||
"""Complete the login by setting full auth state."""
|
||||
app.storage.user.update(
|
||||
authenticated=True,
|
||||
user_id=pending["user_id"],
|
||||
email=pending["email"],
|
||||
role=pending["role"],
|
||||
)
|
||||
# Clear pending state
|
||||
app.storage.user.pop("pending_mfa", None)
|
||||
ui.navigate.to("/")
|
||||
|
||||
|
||||
def _cancel_mfa():
|
||||
app.storage.user.clear()
|
||||
ui.navigate.to("/login")
|
||||
9
wiregui/redis.py
Normal file
9
wiregui/redis.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import redis.asyncio as redis
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
pool = redis.ConnectionPool.from_url(get_settings().redis_url)
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
return redis.Redis(connection_pool=pool)
|
||||
0
wiregui/schemas/__init__.py
Normal file
0
wiregui/schemas/__init__.py
Normal file
37
wiregui/schemas/configuration.py
Normal file
37
wiregui/schemas/configuration.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfigurationRead(BaseModel):
|
||||
id: UUID
|
||||
allow_unprivileged_device_management: bool
|
||||
allow_unprivileged_device_configuration: bool
|
||||
local_auth_enabled: bool
|
||||
disable_vpn_on_oidc_error: bool
|
||||
default_client_persistent_keepalive: int
|
||||
default_client_mtu: int
|
||||
default_client_endpoint: str | None
|
||||
default_client_dns: list[str]
|
||||
default_client_allowed_ips: list[str]
|
||||
vpn_session_duration: int
|
||||
logo_url: str | None
|
||||
logo_type: str | None
|
||||
inserted_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ConfigurationUpdate(BaseModel):
|
||||
allow_unprivileged_device_management: bool | None = None
|
||||
allow_unprivileged_device_configuration: bool | None = None
|
||||
local_auth_enabled: bool | None = None
|
||||
disable_vpn_on_oidc_error: bool | None = None
|
||||
default_client_persistent_keepalive: int | None = None
|
||||
default_client_mtu: int | None = None
|
||||
default_client_endpoint: str | None = None
|
||||
default_client_dns: list[str] | None = None
|
||||
default_client_allowed_ips: list[str] | None = None
|
||||
vpn_session_duration: int | None = None
|
||||
logo_url: str | None = None
|
||||
logo_type: str | None = None
|
||||
50
wiregui/schemas/device.py
Normal file
50
wiregui/schemas/device.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DeviceRead(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
description: str | None
|
||||
public_key: str
|
||||
ipv4: str | None
|
||||
ipv6: str | None
|
||||
use_default_allowed_ips: bool
|
||||
use_default_dns: bool
|
||||
use_default_endpoint: bool
|
||||
use_default_mtu: bool
|
||||
use_default_persistent_keepalive: bool
|
||||
endpoint: str | None
|
||||
mtu: int | None
|
||||
persistent_keepalive: int | None
|
||||
allowed_ips: list[str]
|
||||
dns: list[str]
|
||||
rx_bytes: int | None
|
||||
tx_bytes: int | None
|
||||
latest_handshake: datetime | None
|
||||
user_id: UUID
|
||||
inserted_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class DeviceCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
user_id: UUID | None = None # admin can assign to another user
|
||||
|
||||
|
||||
class DeviceUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
use_default_allowed_ips: bool | None = None
|
||||
use_default_dns: bool | None = None
|
||||
use_default_endpoint: bool | None = None
|
||||
use_default_mtu: bool | None = None
|
||||
use_default_persistent_keepalive: bool | None = None
|
||||
endpoint: str | None = None
|
||||
mtu: int | None = None
|
||||
persistent_keepalive: int | None = None
|
||||
allowed_ips: list[str] | None = None
|
||||
dns: list[str] | None = None
|
||||
31
wiregui/schemas/rule.py
Normal file
31
wiregui/schemas/rule.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RuleRead(BaseModel):
|
||||
id: UUID
|
||||
action: str
|
||||
destination: str
|
||||
port_type: str | None
|
||||
port_range: str | None
|
||||
user_id: UUID | None
|
||||
inserted_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class RuleCreate(BaseModel):
|
||||
action: str = "drop"
|
||||
destination: str
|
||||
port_type: str | None = None
|
||||
port_range: str | None = None
|
||||
user_id: UUID | None = None
|
||||
|
||||
|
||||
class RuleUpdate(BaseModel):
|
||||
action: str | None = None
|
||||
destination: str | None = None
|
||||
port_type: str | None = None
|
||||
port_range: str | None = None
|
||||
user_id: UUID | None = None
|
||||
28
wiregui/schemas/user.py
Normal file
28
wiregui/schemas/user.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class UserRead(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: str
|
||||
disabled_at: datetime | None
|
||||
last_signed_in_at: datetime | None
|
||||
last_signed_in_method: str | None
|
||||
inserted_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
role: str = "unprivileged"
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
email: str | None = None
|
||||
password: str | None = None
|
||||
role: str | None = None
|
||||
disabled_at: datetime | None = None
|
||||
0
wiregui/services/__init__.py
Normal file
0
wiregui/services/__init__.py
Normal file
51
wiregui/services/email.py
Normal file
51
wiregui/services/email.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""Email sending via aiosmtplib for magic links and notifications."""
|
||||
|
||||
import aiosmtplib
|
||||
from email.message import EmailMessage
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
|
||||
async def send_email(to: str, subject: str, body: str) -> bool:
|
||||
"""Send an email via configured SMTP. Returns True on success."""
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.smtp_host:
|
||||
logger.warning("SMTP not configured — email to {} not sent", to)
|
||||
return False
|
||||
|
||||
msg = EmailMessage()
|
||||
msg["From"] = settings.smtp_from
|
||||
msg["To"] = to
|
||||
msg["Subject"] = subject
|
||||
msg.set_content(body)
|
||||
|
||||
try:
|
||||
await aiosmtplib.send(
|
||||
msg,
|
||||
hostname=settings.smtp_host,
|
||||
port=settings.smtp_port,
|
||||
username=settings.smtp_user,
|
||||
password=settings.smtp_password,
|
||||
start_tls=True,
|
||||
)
|
||||
logger.info("Email sent to {}: {}", to, subject)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to send email to {}: {}", to, e)
|
||||
return False
|
||||
|
||||
|
||||
async def send_magic_link(to: str, link: str) -> bool:
|
||||
"""Send a magic link sign-in email."""
|
||||
subject = "WireGUI — Sign in link"
|
||||
body = f"""You requested a sign-in link for WireGUI.
|
||||
|
||||
Click here to sign in:
|
||||
{link}
|
||||
|
||||
This link expires in 15 minutes. If you didn't request this, you can safely ignore this email.
|
||||
"""
|
||||
return await send_email(to, subject, body)
|
||||
136
wiregui/services/events.py
Normal file
136
wiregui/services/events.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""Event bridge — propagates database changes to WireGuard and firewall."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.config import get_settings
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.device import Device
|
||||
from wiregui.models.rule import Rule
|
||||
from wiregui.services import firewall, wireguard
|
||||
|
||||
|
||||
def _device_allowed_ips(device: Device) -> list[str]:
|
||||
"""Build the allowed-ips list for a device peer (its tunnel addresses)."""
|
||||
ips = []
|
||||
if device.ipv4:
|
||||
ips.append(f"{device.ipv4}/32")
|
||||
if device.ipv6:
|
||||
ips.append(f"{device.ipv6}/128")
|
||||
return ips
|
||||
|
||||
|
||||
# --- Device events ---
|
||||
|
||||
|
||||
async def on_device_created(device: Device) -> None:
|
||||
"""Configure WireGuard peer and firewall after a new device is created."""
|
||||
settings = get_settings()
|
||||
if not settings.wg_enabled:
|
||||
return
|
||||
try:
|
||||
await wireguard.add_peer(
|
||||
public_key=device.public_key,
|
||||
allowed_ips=_device_allowed_ips(device),
|
||||
preshared_key=device.preshared_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to add WG peer for device {}: {}", device.name, e)
|
||||
|
||||
try:
|
||||
# Ensure user chain exists before adding jump rules
|
||||
await firewall.add_user_chain(str(device.user_id))
|
||||
await firewall.add_device_jump_rule(
|
||||
str(device.user_id), device.ipv4, device.ipv6,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to add firewall jump rule for device {}: {}", device.name, e)
|
||||
|
||||
|
||||
async def on_device_deleted(device: Device) -> None:
|
||||
"""Remove WireGuard peer after a device is deleted."""
|
||||
if not get_settings().wg_enabled:
|
||||
return
|
||||
try:
|
||||
await wireguard.remove_peer(public_key=device.public_key)
|
||||
except Exception as e:
|
||||
logger.error("Failed to remove WG peer for device {}: {}", device.name, e)
|
||||
# Firewall jump rules are cleaned up on next rebuild
|
||||
|
||||
|
||||
async def on_device_updated(device: Device) -> None:
|
||||
"""Update WireGuard peer after a device is modified."""
|
||||
if not get_settings().wg_enabled:
|
||||
return
|
||||
try:
|
||||
await wireguard.add_peer(
|
||||
public_key=device.public_key,
|
||||
allowed_ips=_device_allowed_ips(device),
|
||||
preshared_key=device.preshared_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to update WG peer for device {}: {}", device.name, e)
|
||||
|
||||
|
||||
# --- Rule events ---
|
||||
|
||||
|
||||
async def on_rule_created(rule: Rule) -> None:
|
||||
"""Apply a new firewall rule."""
|
||||
if not get_settings().wg_enabled:
|
||||
return
|
||||
if rule.user_id is None:
|
||||
return # global rules handled via rebuild
|
||||
try:
|
||||
await firewall.apply_rule(
|
||||
str(rule.user_id), rule.destination, rule.action,
|
||||
rule.port_type, rule.port_range,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to apply firewall rule {}: {}", rule.id, e)
|
||||
|
||||
|
||||
async def on_rule_updated(rule: Rule) -> None:
|
||||
"""Firewall rule updated — rebuild the user's chain."""
|
||||
if not get_settings().wg_enabled:
|
||||
return
|
||||
if rule.user_id is None:
|
||||
return
|
||||
await _rebuild_user_chain(str(rule.user_id))
|
||||
|
||||
|
||||
async def on_rule_deleted(rule: Rule) -> None:
|
||||
"""Firewall rule removed — rebuild the user's chain."""
|
||||
if not get_settings().wg_enabled:
|
||||
return
|
||||
if rule.user_id is None:
|
||||
return
|
||||
await _rebuild_user_chain(str(rule.user_id))
|
||||
|
||||
|
||||
async def _rebuild_user_chain(user_id: str) -> None:
|
||||
"""Flush and rebuild a single user's firewall chain from current DB rules."""
|
||||
try:
|
||||
from sqlmodel import select as sel
|
||||
|
||||
async with async_session() as session:
|
||||
rules = (await session.execute(
|
||||
sel(Rule).where(Rule.user_id == UUID(user_id))
|
||||
)).scalars().all()
|
||||
|
||||
devices = (await session.execute(
|
||||
sel(Device).where(Device.user_id == UUID(user_id))
|
||||
)).scalars().all()
|
||||
|
||||
await firewall.rebuild_all_rules([{
|
||||
"user_id": user_id,
|
||||
"devices": [{"ipv4": d.ipv4, "ipv6": d.ipv6} for d in devices],
|
||||
"rules": [
|
||||
{"destination": r.destination, "action": r.action,
|
||||
"port_type": r.port_type, "port_range": r.port_range}
|
||||
for r in rules
|
||||
],
|
||||
}])
|
||||
except Exception as e:
|
||||
logger.error("Failed to rebuild firewall chain for user {}: {}", user_id, e)
|
||||
191
wiregui/services/firewall.py
Normal file
191
wiregui/services/firewall.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""nftables firewall management — per-user chains and sets for device traffic filtering."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
TABLE_NAME = "wiregui"
|
||||
|
||||
|
||||
async def _nft(cmd: str) -> str:
|
||||
"""Run an nft command and return stdout."""
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"nft", *cmd.split(),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"nft {cmd} failed: {stderr.decode().strip()}")
|
||||
return stdout.decode().strip()
|
||||
|
||||
|
||||
async def _nft_batch(commands: list[str]) -> None:
|
||||
"""Run multiple nft commands in a single atomic batch."""
|
||||
batch = "\n".join(commands)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"nft", "-f", "-",
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate(batch.encode())
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"nft batch failed: {stderr.decode().strip()}")
|
||||
|
||||
|
||||
async def setup_base_tables() -> None:
|
||||
"""Create the base wiregui table with forward and postrouting chains."""
|
||||
commands = [
|
||||
f"add table inet {TABLE_NAME}",
|
||||
# Forward chain for filtering device traffic
|
||||
f"add chain inet {TABLE_NAME} forward {{ type filter hook forward priority 0; policy accept; }}",
|
||||
# Postrouting for NAT/masquerade
|
||||
f"add chain inet {TABLE_NAME} postrouting {{ type nat hook postrouting priority 100; policy accept; }}",
|
||||
]
|
||||
try:
|
||||
await _nft_batch(commands)
|
||||
logger.info("Base nftables table '{}' created", TABLE_NAME)
|
||||
except RuntimeError as e:
|
||||
# Table may already exist
|
||||
if "File exists" not in str(e):
|
||||
raise
|
||||
logger.debug("Base nftables table '{}' already exists", TABLE_NAME)
|
||||
|
||||
|
||||
async def setup_masquerade(iface: str | None = None) -> None:
|
||||
"""Add masquerade rules for VPN traffic — NAT only traffic originating from WG subnets."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
v4_net = settings.wg_ipv4_network
|
||||
v6_net = settings.wg_ipv6_network
|
||||
commands = [
|
||||
f"flush chain inet {TABLE_NAME} postrouting",
|
||||
f'add rule inet {TABLE_NAME} postrouting ip saddr {v4_net} oifname != "{iface}" masquerade',
|
||||
f'add rule inet {TABLE_NAME} postrouting ip6 saddr {v6_net} oifname != "{iface}" masquerade',
|
||||
]
|
||||
try:
|
||||
await _nft_batch(commands)
|
||||
logger.info("Masquerade rule added for {}", iface)
|
||||
except RuntimeError as e:
|
||||
logger.debug("Masquerade setup: {}", e)
|
||||
|
||||
|
||||
async def add_user_chain(user_id: str) -> None:
|
||||
"""Create a per-user chain for firewall rules."""
|
||||
chain = _user_chain_name(user_id)
|
||||
commands = [
|
||||
f"add chain inet {TABLE_NAME} {chain}",
|
||||
]
|
||||
try:
|
||||
await _nft_batch(commands)
|
||||
logger.debug("User chain created: {}", chain)
|
||||
except RuntimeError as e:
|
||||
if "File exists" not in str(e):
|
||||
raise
|
||||
|
||||
|
||||
async def remove_user_chain(user_id: str) -> None:
|
||||
"""Remove a per-user chain and all its rules."""
|
||||
chain = _user_chain_name(user_id)
|
||||
try:
|
||||
await _nft_batch([
|
||||
f"flush chain inet {TABLE_NAME} {chain}",
|
||||
f"delete chain inet {TABLE_NAME} {chain}",
|
||||
])
|
||||
logger.debug("User chain removed: {}", chain)
|
||||
except RuntimeError as e:
|
||||
logger.debug("Remove user chain {}: {}", chain, e)
|
||||
|
||||
|
||||
async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv6: str | None) -> None:
|
||||
"""Add jump rules in the forward chain to route device traffic to the user chain."""
|
||||
chain = _user_chain_name(user_id)
|
||||
commands = []
|
||||
if device_ipv4:
|
||||
commands.append(
|
||||
f"add rule inet {TABLE_NAME} forward ip saddr {device_ipv4} jump {chain}"
|
||||
)
|
||||
if device_ipv6:
|
||||
commands.append(
|
||||
f"add rule inet {TABLE_NAME} forward ip6 saddr {device_ipv6} jump {chain}"
|
||||
)
|
||||
if commands:
|
||||
await _nft_batch(commands)
|
||||
logger.debug("Jump rules added for device {}/{} -> {}", device_ipv4, device_ipv6, chain)
|
||||
|
||||
|
||||
async def apply_rule(user_id: str, destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> None:
|
||||
"""Add a filter rule to a user's chain."""
|
||||
chain = _user_chain_name(user_id)
|
||||
rule = _build_rule_expr(destination, action, port_type, port_range)
|
||||
await _nft_batch([f"add rule inet {TABLE_NAME} {chain} {rule}"])
|
||||
logger.debug("Rule applied in {}: {} -> {}", chain, destination, action)
|
||||
|
||||
|
||||
async def rebuild_all_rules(users_devices_rules: list[dict]) -> None:
|
||||
"""Full reconciliation: flush and rebuild all per-user chains from DB state.
|
||||
|
||||
Args:
|
||||
users_devices_rules: list of dicts with keys:
|
||||
user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range})
|
||||
"""
|
||||
commands = []
|
||||
|
||||
for entry in users_devices_rules:
|
||||
user_id = entry["user_id"]
|
||||
chain = _user_chain_name(user_id)
|
||||
|
||||
# Create/flush user chain
|
||||
commands.append(f"add chain inet {TABLE_NAME} {chain}")
|
||||
commands.append(f"flush chain inet {TABLE_NAME} {chain}")
|
||||
|
||||
# Add rules
|
||||
for rule in entry.get("rules", []):
|
||||
expr = _build_rule_expr(
|
||||
rule["destination"], rule["action"],
|
||||
rule.get("port_type"), rule.get("port_range"),
|
||||
)
|
||||
commands.append(f"add rule inet {TABLE_NAME} {chain} {expr}")
|
||||
|
||||
# Flush forward chain jump rules and re-add
|
||||
commands.append(f"flush chain inet {TABLE_NAME} forward")
|
||||
for entry in users_devices_rules:
|
||||
user_id = entry["user_id"]
|
||||
chain = _user_chain_name(user_id)
|
||||
for dev in entry.get("devices", []):
|
||||
if dev.get("ipv4"):
|
||||
commands.append(f"add rule inet {TABLE_NAME} forward ip saddr {dev['ipv4']} jump {chain}")
|
||||
if dev.get("ipv6"):
|
||||
commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}")
|
||||
|
||||
if commands:
|
||||
await _nft_batch(commands)
|
||||
logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules))
|
||||
|
||||
|
||||
def _user_chain_name(user_id: str) -> str:
|
||||
"""Generate a deterministic chain name from a user ID."""
|
||||
# Use first 12 chars of UUID (without hyphens) to keep names short
|
||||
short = user_id.replace("-", "")[:12]
|
||||
return f"user_{short}"
|
||||
|
||||
|
||||
def _build_rule_expr(destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> str:
|
||||
"""Build an nftables rule expression string."""
|
||||
# Determine IP version from destination
|
||||
if ":" in destination:
|
||||
addr_match = f"ip6 daddr {destination}"
|
||||
else:
|
||||
addr_match = f"ip daddr {destination}"
|
||||
|
||||
parts = [addr_match]
|
||||
|
||||
if port_type and port_range:
|
||||
parts.append(f"{port_type} dport {port_range}")
|
||||
|
||||
parts.append(action)
|
||||
return " ".join(parts)
|
||||
65
wiregui/services/notifications.py
Normal file
65
wiregui/services/notifications.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""In-memory notification queue with severity levels."""
|
||||
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
MAX_NOTIFICATIONS = 100
|
||||
|
||||
|
||||
class Notification:
|
||||
__slots__ = ("id", "severity", "message", "user", "timestamp")
|
||||
|
||||
def __init__(self, severity: str, message: str, user: str | None = None):
|
||||
self.id = str(uuid4())
|
||||
self.severity = severity # "info" | "warning" | "error"
|
||||
self.message = message
|
||||
self.user = user
|
||||
self.timestamp = utcnow()
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"severity": self.severity,
|
||||
"message": self.message,
|
||||
"user": self.user,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
_notifications: deque[Notification] = deque(maxlen=MAX_NOTIFICATIONS)
|
||||
|
||||
|
||||
def add(severity: str, message: str, user: str | None = None) -> Notification:
|
||||
"""Add a notification to the queue."""
|
||||
n = Notification(severity, message, user)
|
||||
_notifications.appendleft(n)
|
||||
logger.debug("Notification added: [{}] {}", severity, message)
|
||||
return n
|
||||
|
||||
|
||||
def current() -> list[Notification]:
|
||||
"""Return all current notifications (newest first)."""
|
||||
return list(_notifications)
|
||||
|
||||
|
||||
def clear(notification_id: str) -> None:
|
||||
"""Remove a specific notification by ID."""
|
||||
for i, n in enumerate(_notifications):
|
||||
if n.id == notification_id:
|
||||
del _notifications[i]
|
||||
return
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
"""Remove all notifications."""
|
||||
_notifications.clear()
|
||||
|
||||
|
||||
def count() -> int:
|
||||
return len(_notifications)
|
||||
188
wiregui/services/wireguard.py
Normal file
188
wiregui/services/wireguard.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
"""WireGuard interface management via subprocess calls to `wg` and `ip`."""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeerInfo:
|
||||
public_key: str
|
||||
endpoint: str | None = None
|
||||
allowed_ips: list[str] = field(default_factory=list)
|
||||
latest_handshake: datetime | None = None
|
||||
rx_bytes: int = 0
|
||||
tx_bytes: int = 0
|
||||
|
||||
|
||||
async def _run(args: list[str], input_data: str | None = None) -> str:
|
||||
"""Run a subprocess and return stdout. Raises on non-zero exit."""
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE if input_data else None,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate(input_data.encode() if input_data else None)
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"{' '.join(args)} failed (rc={proc.returncode}): {stderr.decode().strip()}")
|
||||
return stdout.decode().strip()
|
||||
|
||||
|
||||
async def ensure_interface(iface: str | None = None) -> None:
|
||||
"""Create WireGuard interface if it doesn't exist, assign server IPs and bring it up."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
|
||||
# Check if interface exists
|
||||
try:
|
||||
await _run(["ip", "link", "show", iface])
|
||||
logger.debug("Interface {} already exists", iface)
|
||||
return
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
logger.info("Creating WireGuard interface {}", iface)
|
||||
await _run(["ip", "link", "add", iface, "type", "wireguard"])
|
||||
|
||||
# Assign server IP (first host in each network)
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
|
||||
v4_net = IPv4Network(settings.wg_ipv4_network, strict=False)
|
||||
v4_server = str(list(v4_net.hosts())[0])
|
||||
await _run(["ip", "address", "add", f"{v4_server}/{v4_net.prefixlen}", "dev", iface])
|
||||
|
||||
v6_net = IPv6Network(settings.wg_ipv6_network, strict=False)
|
||||
v6_server = str(list(v6_net.hosts())[0])
|
||||
await _run(["ip", "address", "add", f"{v6_server}/{v6_net.prefixlen}", "dev", iface])
|
||||
|
||||
await _run(["ip", "link", "set", iface, "up"])
|
||||
logger.info("Interface {} is up with {} and {}", iface, v4_server, v6_server)
|
||||
|
||||
|
||||
async def configure_interface(iface: str | None = None) -> None:
|
||||
"""Set the server private key and listen port on the WireGuard interface from DB config."""
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(select(Configuration).limit(1))
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config or not config.server_private_key:
|
||||
logger.error("No server private key in Configuration — WG interface not configured")
|
||||
return
|
||||
|
||||
# Write private key to a temp file (stdin piping has issues with uvloop)
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
key_fd, key_path = tempfile.mkstemp()
|
||||
try:
|
||||
os.write(key_fd, config.server_private_key.encode())
|
||||
os.close(key_fd)
|
||||
os.chmod(key_path, 0o600)
|
||||
await _run(["wg", "set", iface, "private-key", key_path, "listen-port", str(settings.wg_endpoint_port)])
|
||||
finally:
|
||||
os.unlink(key_path)
|
||||
|
||||
logger.info("WireGuard interface {} configured (listen-port={})", iface, settings.wg_endpoint_port)
|
||||
|
||||
|
||||
async def set_private_key(private_key_path: str, iface: str | None = None) -> None:
|
||||
"""Set the WireGuard private key from a file."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
await _run(["wg", "set", iface, "private-key", private_key_path])
|
||||
|
||||
|
||||
async def set_listen_port(port: int, iface: str | None = None) -> None:
|
||||
"""Set the WireGuard listen port."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
await _run(["wg", "set", iface, "listen-port", str(port)])
|
||||
|
||||
|
||||
async def add_peer(
|
||||
public_key: str,
|
||||
allowed_ips: list[str],
|
||||
preshared_key: str | None = None,
|
||||
iface: str | None = None,
|
||||
) -> None:
|
||||
"""Add or update a WireGuard peer."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
|
||||
args = ["wg", "set", iface, "peer", public_key, "allowed-ips", ",".join(allowed_ips)]
|
||||
if preshared_key:
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
psk_fd, psk_path = tempfile.mkstemp()
|
||||
try:
|
||||
os.write(psk_fd, preshared_key.encode())
|
||||
os.close(psk_fd)
|
||||
os.chmod(psk_path, 0o600)
|
||||
await _run([
|
||||
"wg", "set", iface, "peer", public_key,
|
||||
"allowed-ips", ",".join(allowed_ips),
|
||||
"preshared-key", psk_path,
|
||||
])
|
||||
finally:
|
||||
os.unlink(psk_path)
|
||||
else:
|
||||
await _run(args)
|
||||
|
||||
logger.info("Peer added/updated: {} -> {}", public_key[:20], allowed_ips)
|
||||
|
||||
|
||||
async def remove_peer(public_key: str, iface: str | None = None) -> None:
|
||||
"""Remove a WireGuard peer."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
await _run(["wg", "set", iface, "peer", public_key, "remove"])
|
||||
logger.info("Peer removed: {}", public_key[:20])
|
||||
|
||||
|
||||
async def get_peers(iface: str | None = None) -> list[PeerInfo]:
|
||||
"""Parse `wg show <iface> dump` and return peer information."""
|
||||
settings = get_settings()
|
||||
iface = iface or settings.wg_interface
|
||||
|
||||
try:
|
||||
output = await _run(["wg", "show", iface, "dump"])
|
||||
except RuntimeError:
|
||||
return []
|
||||
|
||||
peers = []
|
||||
for line in output.splitlines()[1:]: # skip the interface line
|
||||
parts = line.split("\t")
|
||||
if len(parts) < 8:
|
||||
continue
|
||||
pub_key = parts[0]
|
||||
# parts: public_key, preshared_key, endpoint, allowed_ips, latest_handshake, rx, tx, keepalive
|
||||
endpoint = parts[2] if parts[2] != "(none)" else None
|
||||
allowed_ips = parts[3].split(",") if parts[3] != "(none)" else []
|
||||
handshake_ts = int(parts[4]) if parts[4] != "0" else None
|
||||
latest_handshake = datetime.utcfromtimestamp(handshake_ts) if handshake_ts else None
|
||||
rx_bytes = int(parts[5])
|
||||
tx_bytes = int(parts[6])
|
||||
|
||||
peers.append(PeerInfo(
|
||||
public_key=pub_key,
|
||||
endpoint=endpoint,
|
||||
allowed_ips=allowed_ips,
|
||||
latest_handshake=latest_handshake,
|
||||
rx_bytes=rx_bytes,
|
||||
tx_bytes=tx_bytes,
|
||||
))
|
||||
return peers
|
||||
28
wiregui/tasks/__init__.py
Normal file
28
wiregui/tasks/__init__.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Background tasks — registered on app startup."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from wiregui.config import get_settings
|
||||
|
||||
_tasks: list[asyncio.Task] = []
|
||||
|
||||
|
||||
def register_task(coro, name: str) -> None:
|
||||
"""Schedule a coroutine as a background task."""
|
||||
task = asyncio.create_task(coro, name=name)
|
||||
_tasks.append(task)
|
||||
logger.info("Background task registered: {}", name)
|
||||
|
||||
|
||||
async def cancel_all() -> None:
|
||||
"""Cancel all registered background tasks (called on shutdown)."""
|
||||
for task in _tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
_tasks.clear()
|
||||
logger.info("All background tasks cancelled")
|
||||
60
wiregui/tasks/connectivity.py
Normal file
60
wiregui/tasks/connectivity.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Periodic WAN connectivity checks — fetch a URL and log the result."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.configuration import Configuration
|
||||
from wiregui.models.connectivity_check import ConnectivityCheck
|
||||
from wiregui.services import notifications
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
DEFAULT_URL = "https://ping-dev.firezone.dev"
|
||||
DEFAULT_INTERVAL = 300 # 5 minutes
|
||||
|
||||
|
||||
async def connectivity_loop() -> None:
|
||||
"""Run forever: perform connectivity checks at a configurable interval."""
|
||||
logger.info("Connectivity check task started")
|
||||
await asyncio.sleep(60) # Initial delay to avoid startup spam
|
||||
while True:
|
||||
try:
|
||||
await _check_connectivity()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Connectivity check failed: {}", e)
|
||||
await asyncio.sleep(DEFAULT_INTERVAL)
|
||||
|
||||
|
||||
async def _check_connectivity() -> None:
|
||||
"""Fetch the connectivity check URL and store the result."""
|
||||
url = DEFAULT_URL
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.get(url)
|
||||
|
||||
check = ConnectivityCheck(
|
||||
url=url,
|
||||
response_code=resp.status_code,
|
||||
response_headers=dict(resp.headers),
|
||||
response_body=resp.text[:500],
|
||||
)
|
||||
logger.debug("Connectivity check: {} -> {}", url, resp.status_code)
|
||||
|
||||
except Exception as e:
|
||||
check = ConnectivityCheck(
|
||||
url=url,
|
||||
response_code=None,
|
||||
response_body=str(e)[:500],
|
||||
)
|
||||
logger.warning("Connectivity check failed: {}", e)
|
||||
notifications.add("warning", f"WAN connectivity check failed: {e}")
|
||||
|
||||
async with async_session() as session:
|
||||
session.add(check)
|
||||
await session.commit()
|
||||
108
wiregui/tasks/oidc_refresh.py
Normal file
108
wiregui/tasks/oidc_refresh.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
"""Periodically refresh OIDC tokens for all active connections."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from wiregui.db import async_session
|
||||
from wiregui.models.oidc_connection import OIDCConnection
|
||||
from wiregui.models.user import User
|
||||
from wiregui.services import notifications
|
||||
from wiregui.utils.time import utcnow
|
||||
|
||||
INTERVAL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
async def oidc_refresh_loop() -> None:
|
||||
"""Run forever: refresh OIDC tokens every INTERVAL_SECONDS."""
|
||||
logger.info("OIDC refresh task started (interval={}s)", INTERVAL_SECONDS)
|
||||
await asyncio.sleep(60) # Initial delay to avoid startup spam
|
||||
while True:
|
||||
try:
|
||||
await _refresh_all()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("OIDC refresh cycle failed: {}", e)
|
||||
await asyncio.sleep(INTERVAL_SECONDS)
|
||||
|
||||
|
||||
async def _refresh_all() -> None:
|
||||
"""Attempt to refresh all stored OIDC tokens."""
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
|
||||
from wiregui.auth.oidc import load_providers
|
||||
|
||||
providers = await load_providers()
|
||||
provider_map = {p["id"]: p for p in providers}
|
||||
|
||||
async with async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OIDCConnection).where(OIDCConnection.refresh_token.is_not(None))
|
||||
)
|
||||
connections = result.scalars().all()
|
||||
|
||||
if not connections:
|
||||
return
|
||||
|
||||
refreshed = 0
|
||||
failed = 0
|
||||
|
||||
for conn in connections:
|
||||
provider_config = provider_map.get(conn.provider)
|
||||
if not provider_config:
|
||||
continue
|
||||
|
||||
try:
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=provider_config["client_id"],
|
||||
client_secret=provider_config["client_secret"],
|
||||
) as client:
|
||||
# Load server metadata to get token endpoint
|
||||
discovery_url = provider_config.get("discovery_document_uri")
|
||||
if discovery_url:
|
||||
import httpx
|
||||
resp = await client.get(discovery_url)
|
||||
metadata = resp.json()
|
||||
token_endpoint = metadata.get("token_endpoint")
|
||||
else:
|
||||
continue
|
||||
|
||||
new_token = await client.refresh_token(
|
||||
url=token_endpoint,
|
||||
refresh_token=conn.refresh_token,
|
||||
)
|
||||
|
||||
# Update connection
|
||||
async with async_session() as session:
|
||||
c = await session.get(OIDCConnection, conn.id)
|
||||
if c:
|
||||
c.refresh_token = new_token.get("refresh_token", c.refresh_token)
|
||||
c.refresh_response = dict(new_token)
|
||||
c.refreshed_at = utcnow()
|
||||
session.add(c)
|
||||
await session.commit()
|
||||
|
||||
refreshed += 1
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.warning("OIDC refresh failed for connection {} (provider={}): {}",
|
||||
conn.id, conn.provider, e)
|
||||
|
||||
# Check if we should disable VPN
|
||||
from wiregui.models.configuration import Configuration
|
||||
async with async_session() as session:
|
||||
config = (await session.execute(select(Configuration).limit(1))).scalar_one_or_none()
|
||||
if config and config.disable_vpn_on_oidc_error:
|
||||
user = await session.get(User, conn.user_id)
|
||||
if user:
|
||||
notifications.add(
|
||||
"error",
|
||||
f"OIDC refresh failed for {user.email} ({conn.provider}). VPN access may be affected.",
|
||||
user=user.email,
|
||||
)
|
||||
|
||||
if refreshed or failed:
|
||||
logger.info("OIDC refresh: {} succeeded, {} failed", refreshed, failed)
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue