wiregui/tests/conftest.py
Stefano Bertelli a012635dff
Some checks failed
Dev / test (push) Failing after 7m41s
Dev / docker (push) Has been skipped
fix: remove unit tests redundant with e2e, fix test DB isolation
Remove 7 test files fully covered by e2e tests (admin, account, models,
API routes, integration MFA/OIDC, notifications). Trim 5 more files to
keep only edge cases not reachable via e2e.

Fix conftest to replace wiregui.db engine/session at import time so all
code uses the test database. Use session-scoped tables with per-test
savepoint isolation to prevent data leaking between tests.
2026-03-31 21:27:46 -05:00

100 lines
3.7 KiB
Python

"""Shared test fixtures — async DB session using a test database.
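
The module-level code below replaces ``wiregui.db.engine`` and
``wiregui.db.async_session`` with instances pointing at the **test** database
*before* test modules are collected. As a result, any module first imported
after that point gets the test-database session maker when it does
``from wiregui.db import async_session``. This covers test files and production
code such as ``wiregui.utils.server_key``.

Caveat: this conftest imports ``wiregui.db``, ``wiregui.config`` and
``wiregui.models`` *before* the swap. Any module pulled in by those imports that
binds these names with ``from wiregui.db import ...`` would keep the originals.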
"""
import os
from collections.abc import AsyncGenerator

import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from sqlmodel import SQLModel

import wiregui.db as _db_module
from wiregui.config import get_settings
# All models must be imported so SQLModel.metadata knows about them
from wiregui.models import *  # noqa: F401, F403
def _test_database_url() -> str:
"""Use a separate test DB locally, but in CI just use the main DB (it's ephemeral)."""
url = get_settings().database_url
if os.environ.get("CI"):
return url # CI: use the service container DB directly
base, _dbname = url.rsplit("/", 1)
return f"{base}/wiregui_test"
# Resolved once at import time; the test engine created further down connects to this URL.
TEST_DATABASE_URL: str = _test_database_url()
def _ensure_test_db_sync() -> None:
    """Create the local test database if it does not exist yet.

    Skipped when the ``CI`` environment variable is set, since CI uses the main
    database directly.

    The database name and server come from ``TEST_DATABASE_URL``, so this
    always creates the database the test engine will actually connect to. It
    connects to the ``postgres`` maintenance database on the same server (same
    credentials, host and query options) with AUTOCOMMIT, because PostgreSQL
    refuses ``CREATE DATABASE`` inside a transaction block. Runs synchronously
    via ``asyncio.run`` because it is invoked at module import time.
    """
    if os.environ.get("CI"):
        return
    import asyncio
    from urllib.parse import urlsplit

    parts = urlsplit(TEST_DATABASE_URL)
    db_name = parts.path.lstrip("/")
    # Keep the query string (e.g. SSL options); the old rsplit-based URL dropped it.
    admin_url = parts._replace(path="/postgres").geturl()

    async def _create() -> None:
        admin_engine = create_async_engine(admin_url, isolation_level="AUTOCOMMIT")
        try:
            async with admin_engine.connect() as conn:
                result = await conn.execute(
                    text("SELECT 1 FROM pg_database WHERE datname = :name"),
                    {"name": db_name},
                )
                if result.scalar() is None:
                    # Identifiers cannot be bound parameters; quote via the dialect.
                    quoted = admin_engine.dialect.identifier_preparer.quote(db_name)
                    await conn.execute(text(f"CREATE DATABASE {quoted}"))
        finally:
            await admin_engine.dispose()

    asyncio.run(_create())
# Make sure the local test database exists (no-op in CI) before any engine targets it.
_ensure_test_db_sync()
# ---------------------------------------------------------------------------
# Replace the production engine/session in wiregui.db at import time so that
# every module that does ``from wiregui.db import async_session`` picks up the
# test database. This MUST happen before test modules are collected (which
# triggers their top-level imports).
#
# NullPool: pooled asyncpg connections are bound to the event loop that opened
# them. The session-scoped table fixture and the per-test fixtures/tests may run
# on different event loops, depending on the pytest-asyncio loop-scope settings
# (not visible here). Reusing a pooled connection across loops can then fail
# with "attached to a different loop" errors. Opening a fresh connection per
# checkout avoids sharing connections between loops.
# ---------------------------------------------------------------------------
_test_engine = create_async_engine(TEST_DATABASE_URL, poolclass=NullPool)
_test_session_factory = async_sessionmaker(_test_engine, expire_on_commit=False)
_db_module.engine = _test_engine
_db_module.async_session = _test_session_factory
@pytest_asyncio.fixture(scope="session", autouse=True)
async def _setup_test_tables():
    """Create a fresh schema once per test session and drop it at the end.

    Tables left over from an earlier run that never reached teardown (for
    example, an interrupted run) are dropped first. ``create_all`` skips tables
    that already exist, so without this step the session would start with stale
    rows and possibly an outdated table layout.

    Teardown always disposes the engine, even if dropping the tables fails.
    """
    async with _test_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)
        await conn.run_sync(SQLModel.metadata.create_all)
    yield
    try:
        async with _test_engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.drop_all)
    finally:
        await _test_engine.dispose()
@pytest_asyncio.fixture
async def session() -> AsyncGenerator[AsyncSession]:
    """Yield an ``AsyncSession`` whose writes are discarded after each test.

    A dedicated connection opens an outer transaction that is never committed.
    The session joins it with ``join_transaction_mode="create_savepoint"``, so a
    ``session.commit()`` in code under test only releases a SAVEPOINT. At
    teardown the session is closed and the outer transaction is rolled back,
    which throws away everything the test wrote through this session.

    Note: code that opens its own sessions via ``wiregui.db.async_session``
    gets a separate connection from ``_test_engine``. Those commits are real
    and are not undone by this rollback.
    """
    async with _test_engine.connect() as connection:
        outer_txn = await connection.begin()
        test_session = AsyncSession(
            bind=connection,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        )
        yield test_session
        # Teardown order: release the session first, then discard its writes.
        await test_session.close()
        await outer_txn.rollback()