diff --git a/tests/test_api.py b/tests/test_api.py index 56d8ec6..2019ad4 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,8 +1,6 @@ """Tests for REST API endpoints and token auth.""" -import hashlib - -from wiregui.auth.api_token import generate_api_token, resolve_bearer_token +from wiregui.auth.api_token import _token_hmac, generate_api_token, resolve_bearer_token from wiregui.auth.passwords import hash_password from wiregui.models.api_token import ApiToken from wiregui.models.user import User @@ -15,7 +13,7 @@ from wiregui.utils.time import utcnow def test_generate_api_token(): plaintext, token_hash = generate_api_token() assert len(plaintext) > 20 - assert token_hash == hashlib.sha512(plaintext.encode()).hexdigest() + assert token_hash == _token_hmac(plaintext) def test_generate_api_token_unique(): diff --git a/wiregui/auth/api_token.py b/wiregui/auth/api_token.py index d930e9f..f2ad837 100644 --- a/wiregui/auth/api_token.py +++ b/wiregui/auth/api_token.py @@ -1,27 +1,33 @@ """API token authentication — Bearer token via Authorization header.""" -import hashlib +import hmac import secrets from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import select +from wiregui.config import get_settings from wiregui.models.api_token import ApiToken from wiregui.models.user import User from wiregui.utils.time import utcnow +def _token_hmac(token: str) -> str: + """Compute a keyed HMAC-SHA256 digest of an API token.""" + key = get_settings().secret_key.encode() + return hmac.new(key, token.encode(), "sha256").hexdigest() + + def generate_api_token() -> tuple[str, str]: """Generate a new API token. Returns (plaintext_token, token_hash).""" plaintext = secrets.token_urlsafe(32) - token_hash = hashlib.sha512(plaintext.encode()).hexdigest() - return plaintext, token_hash + return plaintext, _token_hmac(plaintext) async def resolve_bearer_token(session: AsyncSession, token: str) -> User | None: """Look up a Bearer token and return the associated user, or None.""" - token_hash = hashlib.sha512(token.encode()).hexdigest() + token_hash = _token_hmac(token) result = await session.execute( select(ApiToken).where(ApiToken.token_hash == token_hash) )