169 lines
5.4 KiB
Python
169 lines
5.4 KiB
Python
|
|
"""Tests for SQLModel table definitions."""
|
||
|
|
|
||
|
|
import pytest # noqa: F401 — needed for pytest.raises
|
||
|
|
from sqlmodel import select
|
||
|
|
|
||
|
|
from wiregui.models.api_token import ApiToken
|
||
|
|
from wiregui.models.configuration import Configuration
|
||
|
|
from wiregui.models.connectivity_check import ConnectivityCheck
|
||
|
|
from wiregui.models.device import Device
|
||
|
|
from wiregui.models.mfa_method import MFAMethod
|
||
|
|
from wiregui.models.oidc_connection import OIDCConnection
|
||
|
|
from wiregui.models.rule import Rule
|
||
|
|
from wiregui.models.user import User
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_user(session):
    """An admin user can be inserted and read back by email with its fields intact."""
    admin = User(email="alice@example.com", role="admin")
    session.add(admin)
    await session.flush()

    stmt = select(User).where(User.email == "alice@example.com")
    row = (await session.execute(stmt)).scalar_one()

    # The re-queried row must be the same record we inserted.
    assert row.id == admin.id
    assert row.role == "admin"
    # disabled_at was never set, so it is expected to be empty.
    assert row.disabled_at is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_device_with_user(session):
    """A device linked to a user persists its fields and reports the expected defaults."""
    owner = User(email="bob@example.com")
    session.add(owner)
    # Flush the owner row before inserting a device that references it.
    await session.flush()

    laptop = Device(name="laptop", public_key="pk-test-device-001", user_id=owner.id)
    session.add(laptop)
    await session.flush()

    query = select(Device).where(Device.public_key == "pk-test-device-001")
    stored = (await session.execute(query)).scalar_one()

    assert stored.name == "laptop"
    assert stored.user_id == owner.id
    # These fields were not passed to the constructor; the test expects these defaults.
    assert stored.use_default_dns is True
    assert stored.use_default_allowed_ips is True
    assert stored.rx_bytes is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_device_unique_public_key(session):
    """Inserting a second device with an already-used public key must be rejected.

    The second flush is expected to fail with an IntegrityError from the
    database's uniqueness check on ``Device.public_key``.
    """
    user = User(email="carol@example.com")
    session.add(user)
    await session.flush()

    d1 = Device(name="d1", public_key="duplicate-key", user_id=user.id)
    session.add(d1)
    await session.flush()

    d2 = Device(name="d2", public_key="duplicate-key", user_id=user.id)
    session.add(d2)

    # A bare ``pytest.raises(Exception)`` would also pass on unrelated failures
    # (AttributeError, TypeError, lost connection, ...), so pin the exception type.
    # Matching by name avoids a direct sqlalchemy import; SQLAlchemy wraps the
    # driver's constraint violation in ``sqlalchemy.exc.IntegrityError``.
    with pytest.raises(Exception) as excinfo:
        await session.flush()
    assert excinfo.typename == "IntegrityError", excinfo.value
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_rule(session):
    """A user-scoped accept rule persists with no port filter set."""
    dave = User(email="dave@example.com")
    session.add(dave)
    await session.flush()

    accept_rule = Rule(action="accept", destination="10.0.0.0/8", user_id=dave.id)
    session.add(accept_rule)
    await session.flush()

    by_owner = select(Rule).where(Rule.user_id == dave.id)
    saved = (await session.execute(by_owner)).scalar_one()

    assert saved.action == "accept"
    assert saved.destination == "10.0.0.0/8"
    # No port arguments were given, so both port columns stay empty.
    assert saved.port_type is None
    assert saved.port_range is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_rule_with_port(session):
    """A rule with a port filter and no owner is stored with user_id unset."""
    drop_rule = Rule(
        action="drop",
        destination="192.168.0.0/16",
        port_type="tcp",
        port_range="80-443",
    )
    session.add(drop_rule)
    await session.flush()

    lookup = select(Rule).where(Rule.id == drop_rule.id)
    result = await session.execute(lookup)
    persisted = result.scalar_one()

    assert persisted.port_type == "tcp"
    assert persisted.port_range == "80-443"
    # No user_id was supplied, i.e. a global rule.
    assert persisted.user_id is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_mfa_method(session):
    """A TOTP MFA method persists its type and its dict payload for the owning user."""
    eve = User(email="eve@example.com")
    session.add(eve)
    await session.flush()

    totp = MFAMethod(
        name="My Authenticator",
        type="totp",
        payload={"secret": "JBSWY3DPEHPK3PXP"},
        user_id=eve.id,
    )
    session.add(totp)
    await session.flush()

    stmt = select(MFAMethod).where(MFAMethod.user_id == eve.id)
    method = (await session.execute(stmt)).scalar_one()

    assert method.type == "totp"
    # The payload dict must round-trip so its keys are readable again.
    assert method.payload["secret"] == "JBSWY3DPEHPK3PXP"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_oidc_connection(session):
    """An OIDC connection linked to a user stores its provider and refresh token."""
    frank = User(email="frank@example.com")
    session.add(frank)
    await session.flush()

    google_link = OIDCConnection(provider="google", refresh_token="tok_abc", user_id=frank.id)
    session.add(google_link)
    await session.flush()

    stmt = select(OIDCConnection).where(OIDCConnection.user_id == frank.id)
    found = (await session.execute(stmt)).scalar_one()

    assert found.provider == "google"
    assert found.refresh_token == "tok_abc"
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_api_token(session):
    """An API token persists its hash and has no expiry unless one is given."""
    grace = User(email="grace@example.com")
    session.add(grace)
    await session.flush()

    api_token = ApiToken(token_hash="sha256_fake_hash", user_id=grace.id)
    session.add(api_token)
    await session.flush()

    stmt = select(ApiToken).where(ApiToken.user_id == grace.id)
    loaded = (await session.execute(stmt)).scalar_one()

    assert loaded.token_hash == "sha256_fake_hash"
    # expires_at was not passed to the constructor.
    assert loaded.expires_at is None
|
||
|
|
|
||
|
|
|
||
|
|
async def test_create_connectivity_check(session):
    """A connectivity check record persists the HTTP response code it was created with."""
    probe = ConnectivityCheck(url="https://example.com", response_code=200)
    session.add(probe)
    await session.flush()

    stmt = select(ConnectivityCheck).where(ConnectivityCheck.id == probe.id)
    result = await session.execute(stmt)

    assert result.scalar_one().response_code == 200
|
||
|
|
|
||
|
|
|
||
|
|
async def test_configuration_defaults(session):
    """A Configuration created with no arguments persists the expected default settings."""
    cfg = Configuration()
    session.add(cfg)
    await session.flush()

    stmt = select(Configuration).where(Configuration.id == cfg.id)
    stored = (await session.execute(stmt)).scalar_one()

    # Boolean flags are checked by identity, so a truthy non-bool value would fail.
    assert stored.allow_unprivileged_device_management is True
    assert stored.local_auth_enabled is True

    # Remaining defaults are compared by equality, in the same order as before.
    expected = {
        "default_client_mtu": 1280,
        "default_client_persistent_keepalive": 25,
        "default_client_dns": ["1.1.1.1", "1.0.0.1"],
        "default_client_allowed_ips": ["0.0.0.0/0", "::/0"],
        "vpn_session_duration": 0,
        "openid_connect_providers": [],
        "saml_identity_providers": [],
    }
    for field_name, default in expected.items():
        assert getattr(stored, field_name) == default, field_name
|