"""Tests for sageapi.middleware.dapi_auth""" import hashlib import hmac import time from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from starlette.requests import Request from starlette.responses import JSONResponse from starlette.testclient import TestClient from sageapi.middleware.dapi_auth import ( DEFAULT_TIMESTAMP_TOLERANCE_SEC, DapiAuthError, DapiAuthMiddleware, DapiKeyRecord, authenticate_request, compute_dapi_signature, lookup_api_key, verify_timestamp, ) def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str: """Helper to create a valid DAPI signature.""" return compute_dapi_signature(method, path, timestamp, secret, body if body else None) # --------------------------------------------------------------------------- # compute_dapi_signature tests # --------------------------------------------------------------------------- class TestComputeDapiSignature: def test_basic_signature(self): sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret") assert isinstance(sig, str) assert len(sig) == 64 # SHA-256 hex length def test_signature_with_body(self): body = b'{"key": "value"}' sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", body) assert isinstance(sig, str) assert len(sig) == 64 def test_signature_deterministic(self): sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret") sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret") assert sig1 == sig2 def test_different_body_different_signature(self): sig1 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body1") sig2 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body2") assert sig1 != sig2 def test_empty_body_same_as_no_body(self): sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret") sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret", None) assert sig1 == sig2 def test_manual_verification(self): """Verify the signature matches the expected HMAC-SHA256 output.""" method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret" body = b'{"message":"hello"}' body_hash = hashlib.sha256(body).hexdigest() string_to_sign = f"{method}\n{path}\n{ts}\n{body_hash}" expected = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha256).hexdigest() actual = compute_dapi_signature(method, path, ts, secret, body) assert actual == expected # --------------------------------------------------------------------------- # verify_timestamp tests # --------------------------------------------------------------------------- class TestVerifyTimestamp: def test_valid_timestamp(self): ts = str(time.time()) assert verify_timestamp(ts) is True def test_expired_timestamp(self): ts = str(time.time() - 600) # 10 minutes ago assert verify_timestamp(ts) is False def test_future_timestamp(self): ts = str(time.time() + 600) # 10 minutes in future assert verify_timestamp(ts) is False def test_invalid_timestamp(self): assert verify_timestamp("not-a-number") is False assert verify_timestamp("") is False def test_custom_tolerance(self): ts = str(time.time() - 10) assert verify_timestamp(ts, tolerance_sec=5) is False assert verify_timestamp(ts, tolerance_sec=15) is True # --------------------------------------------------------------------------- # DapiKeyRecord tests # --------------------------------------------------------------------------- class TestDapiKeyRecord: def test_active_key(self): rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None) assert rec.is_active is True def test_inactive_key(self): rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="inactive", expire_date=None) assert rec.is_active is False def test_not_expired_when_no_expiry(self): rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None) assert rec.is_expired is False def test_expired_with_past_date(self): past = datetime(2020, 1, 1, tzinfo=timezone.utc) rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=past) assert rec.is_expired is True def test_not_expired_with_future_date(self): future = datetime(2099, 1, 1, tzinfo=timezone.utc) rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=future) assert rec.is_expired is False def test_expired_with_iso_string_past(self): rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2020-01-01T00:00:00Z") assert rec.is_expired is True def test_not_expired_with_iso_string_future(self): rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2099-01-01T00:00:00Z") assert rec.is_expired is False # --------------------------------------------------------------------------- # lookup_api_key tests # --------------------------------------------------------------------------- class TestLookupApiKey: @pytest.fixture(autouse=True) def mock_sage_modules(self): """Create mock ahserver and sqlor modules for testing.""" import sys import types # Create fake modules ahserver = types.ModuleType("ahserver") ahserver_serverenv = types.ModuleType("ahserver.serverenv") ahserver_serverenv.ServerEnv = MagicMock ahserver.serverenv = ahserver_serverenv sqlor = types.ModuleType("sqlor") sqlor_dbpools = types.ModuleType("sqlor.dbpools") mock_ctx = AsyncMock() mock_sor = AsyncMock() mock_sor.R = AsyncMock(return_value=[]) mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor) mock_ctx.__aexit__ = AsyncMock(return_value=False) sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx) sqlor.dbpools = sqlor_dbpools sys.modules["ahserver"] = ahserver sys.modules["ahserver.serverenv"] = ahserver_serverenv sys.modules["sqlor"] = sqlor sys.modules["sqlor.dbpools"] = sqlor_dbpools yield mock_ctx, mock_sor # Cleanup for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]: if mod in sys.modules: del sys.modules[mod] @pytest.mark.asyncio async def test_lookup_found(self, mock_sage_modules): mock_ctx, mock_sor = mock_sage_modules mock_rec = { "id": 1, "apikey": "test-key", "secret": "test-secret", "status": "active", "expire_date": "2099-01-01T00:00:00Z", "description": "Test key", } mock_sor.R.return_value = [mock_rec] result = await lookup_api_key("test-key") assert result is not None assert result.apikey == "test-key" assert result.secret == "test-secret" assert result.is_active is True @pytest.mark.asyncio async def test_lookup_not_found(self, mock_sage_modules): mock_ctx, mock_sor = mock_sage_modules mock_sor.R.return_value = [] result = await lookup_api_key("nonexistent") assert result is None # --------------------------------------------------------------------------- # authenticate_request tests # --------------------------------------------------------------------------- class TestAuthenticateRequest: def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request: scope = { "type": "http", "method": method, "path": path, "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()], "query_string": b"", } async def receive(): return {"type": "http.request", "body": body, "more_body": False} return Request(scope, receive) @pytest.mark.asyncio async def test_missing_api_key_header(self): req = self._make_request({ "X-DAPI-Timestamp": str(time.time()), "X-DAPI-Signature": "abc", }) with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 401 assert "X-DAPI-Key" in exc.value.detail @pytest.mark.asyncio async def test_missing_timestamp_header(self): req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Signature": "abc", }) with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 401 assert "X-DAPI-Timestamp" in exc.value.detail @pytest.mark.asyncio async def test_missing_signature_header(self): req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Timestamp": str(time.time()), }) with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 401 assert "X-DAPI-Signature" in exc.value.detail @pytest.mark.asyncio async def test_expired_timestamp(self): ts = str(time.time() - 600) req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": "abc", }) with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 401 assert "timestamp" in exc.value.detail.lower() @pytest.mark.asyncio async def test_invalid_api_key(self): ts = str(time.time()) req = self._make_request({ "X-DAPI-Key": "bad-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": "abc", }) with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=None): with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 401 assert "Invalid API key" in exc.value.detail @pytest.mark.asyncio async def test_inactive_key(self): ts = str(time.time()) req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": "abc", }) key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None) with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 403 assert "inactive" in exc.value.detail.lower() @pytest.mark.asyncio async def test_expired_key(self): ts = str(time.time()) req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": "abc", }) key_rec = DapiKeyRecord( id=1, apikey="test-key", secret="secret", status="active", expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc), ) with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 403 assert "expired" in exc.value.detail.lower() @pytest.mark.asyncio async def test_invalid_signature(self): ts = str(time.time()) body = b'{"test": "data"}' req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": "invalid-signature", }, body=body, method="POST") key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None) with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): with pytest.raises(DapiAuthError) as exc: await authenticate_request(req) assert exc.value.status_code == 401 assert "Invalid signature" in exc.value.detail @pytest.mark.asyncio async def test_valid_authentication(self): ts = str(time.time()) secret = "my-secret" body = b'{"test": "data"}' sig = compute_dapi_signature("POST", "/test", ts, secret, body) req = self._make_request({ "X-DAPI-Key": "test-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": sig, }, body=body, method="POST") key_rec = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None) with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): result = await authenticate_request(req) assert result.apikey == "test-key" assert result.secret == secret # --------------------------------------------------------------------------- # DapiAuthMiddleware tests # --------------------------------------------------------------------------- class TestDapiAuthMiddleware: @pytest.fixture(autouse=True) def mock_sage_modules(self): """Create mock ahserver and sqlor modules.""" import sys import types ahserver = types.ModuleType("ahserver") ahserver_serverenv = types.ModuleType("ahserver.serverenv") ahserver_serverenv.ServerEnv = MagicMock ahserver.serverenv = ahserver_serverenv sqlor = types.ModuleType("sqlor") sqlor_dbpools = types.ModuleType("sqlor.dbpools") mock_sor = AsyncMock() mock_sor.R = AsyncMock(return_value=[]) mock_ctx = AsyncMock() mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor) mock_ctx.__aexit__ = AsyncMock(return_value=False) sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx) sqlor.dbpools = sqlor_dbpools sys.modules["ahserver"] = ahserver sys.modules["ahserver.serverenv"] = ahserver_serverenv sys.modules["sqlor"] = sqlor sys.modules["sqlor.dbpools"] = sqlor_dbpools yield for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]: if mod in sys.modules: del sys.modules[mod] def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC): from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.routing import Route async def protected(request: Request): return PlainTextResponse("OK - authenticated") async def health(request: Request): return PlainTextResponse("healthy") app = Starlette( routes=[ Route("/protected", endpoint=protected, methods=["GET"]), Route("/health", endpoint=health, methods=["GET"]), ] ) app.add_middleware( DapiAuthMiddleware, exclude_paths=exclude_paths or ["/health"], tolerance_sec=tolerance_sec, ) return app def test_unauthenticated_request_rejected(self): app = self._build_test_app() client = TestClient(app) resp = client.get("/protected") assert resp.status_code == 401 assert "error" in resp.json() def test_excluded_path_bypasses_auth(self): app = self._build_test_app() client = TestClient(app) resp = client.get("/health") assert resp.status_code == 200 assert resp.text == "healthy" def test_valid_request_accepted(self): ts = str(time.time()) secret = "test-secret" sig = compute_dapi_signature("GET", "/protected", ts, secret) key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None) app = self._build_test_app() client = TestClient(app) with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): resp = client.get( "/protected", headers={ "X-DAPI-Key": "valid-key", "X-DAPI-Timestamp": ts, "X-DAPI-Signature": sig, }, ) assert resp.status_code == 200 assert resp.text == "OK - authenticated"