- Sync engine: BaseSync abstract class + 4 sync modules (users/pricing/uapi/llmage) - Checkpoint management via sync_state table - Batch processing with retry and exponential backoff - Incremental fetch from Sage DB via sqlor - UPSERT to local cache tables - API handlers: balance/accounting/users/pricing/health - Balance: cache lookup + Sage fallback - Accounting: create with idempotency, query with filters/pagination - Users: keyword search, org filter - Pricing: filter by ppid/llmid/type/status - Health: basic + readiness checks (DB connectivity) - DAPI auth: middleware + authenticate_request function - HMAC-SHA256 signature verification - Timestamp window validation - Sage downapikey table lookup - HTTP client: SageHttpClient with aiohttp - Auto DAPI signature injection - Connection pooling, retry, timeout - Router: 12 routes registered - Module init: load_sageapi() wires everything to ServerEnv
459 lines
16 KiB
Python
459 lines
16 KiB
Python
"""Tests for sageapi.middleware.dapi_auth"""
|
|
|
|
import hashlib
|
|
import hmac
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
from starlette.testclient import TestClient
|
|
|
|
from sageapi.middleware.dapi_auth import (
|
|
DEFAULT_TIMESTAMP_TOLERANCE_SEC,
|
|
DapiAuthError,
|
|
DapiAuthMiddleware,
|
|
DapiKeyRecord,
|
|
authenticate_request,
|
|
compute_dapi_signature,
|
|
lookup_api_key,
|
|
verify_timestamp,
|
|
)
|
|
|
|
|
|
def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str:
    """Build a DAPI signature for a test request.

    A falsy (empty) ``body`` is handed to ``compute_dapi_signature`` as
    ``None`` instead of ``b""``; any non-empty body is passed through as is.
    """
    payload = body or None
    return compute_dapi_signature(method, path, timestamp, secret, payload)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_dapi_signature tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestComputeDapiSignature:
    """Unit tests for ``compute_dapi_signature``."""

    # Number of characters in a hex-encoded SHA-256 digest.
    _HEX_DIGEST_LEN = 64

    def test_basic_signature(self):
        """A call without a body returns a 64-character string."""
        signature = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret")

        assert isinstance(signature, str)
        assert len(signature) == self._HEX_DIGEST_LEN  # SHA-256 hex length

    def test_signature_with_body(self):
        """A call with a body also returns a 64-character string."""
        payload = b'{"key": "value"}'
        signature = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", payload)

        assert isinstance(signature, str)
        assert len(signature) == self._HEX_DIGEST_LEN

    def test_signature_deterministic(self):
        """Identical inputs produce identical signatures."""
        first = compute_dapi_signature("GET", "/path", "123.0", "secret")
        second = compute_dapi_signature("GET", "/path", "123.0", "secret")

        assert first == second

    def test_different_body_different_signature(self):
        """Changing only the body changes the signature."""
        with_body1 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body1")
        with_body2 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body2")

        assert with_body1 != with_body2

    def test_empty_body_same_as_no_body(self):
        """Omitting the body and passing ``None`` explicitly are equivalent."""
        omitted = compute_dapi_signature("GET", "/path", "123.0", "secret")
        explicit_none = compute_dapi_signature("GET", "/path", "123.0", "secret", None)

        assert omitted == explicit_none

    def test_manual_verification(self):
        """Verify the signature matches the expected HMAC-SHA256 output."""
        method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret"
        body = b'{"message":"hello"}'

        # Rebuild the canonical string independently: method, path, timestamp
        # and the SHA-256 hex digest of the body, joined by newlines.
        body_hash = hashlib.sha256(body).hexdigest()
        string_to_sign = f"{method}\n{path}\n{ts}\n{body_hash}"
        mac = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha256)

        assert compute_dapi_signature(method, path, ts, secret, body) == mac.hexdigest()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# verify_timestamp tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestVerifyTimestamp:
    """Unit tests for ``verify_timestamp``."""

    def test_valid_timestamp(self):
        """The current time is accepted."""
        now = time.time()
        assert verify_timestamp(str(now)) is True

    def test_expired_timestamp(self):
        """A timestamp 10 minutes in the past is rejected."""
        ten_minutes_ago = time.time() - 600
        assert verify_timestamp(str(ten_minutes_ago)) is False

    def test_future_timestamp(self):
        """A timestamp 10 minutes in the future is rejected."""
        ten_minutes_ahead = time.time() + 600
        assert verify_timestamp(str(ten_minutes_ahead)) is False

    def test_invalid_timestamp(self):
        """Non-numeric and empty strings are rejected."""
        for malformed in ("not-a-number", ""):
            assert verify_timestamp(malformed) is False

    def test_custom_tolerance(self):
        """A 10-second-old timestamp fails a 5s window and passes a 15s one."""
        ten_seconds_old = str(time.time() - 10)

        assert verify_timestamp(ten_seconds_old, tolerance_sec=5) is False
        assert verify_timestamp(ten_seconds_old, tolerance_sec=15) is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DapiKeyRecord tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDapiKeyRecord:
    """Unit tests for the ``is_active`` / ``is_expired`` flags of ``DapiKeyRecord``."""

    @staticmethod
    def _record(status="active", expire_date=None):
        """Build a record with fixed id/apikey/secret and the given status and expiry."""
        return DapiKeyRecord(id=1, apikey="k1", secret="s1", status=status, expire_date=expire_date)

    def test_active_key(self):
        """Status ``"active"`` is reported as active."""
        assert self._record(status="active").is_active is True

    def test_inactive_key(self):
        """Status ``"inactive"`` is reported as not active."""
        assert self._record(status="inactive").is_active is False

    def test_not_expired_when_no_expiry(self):
        """A record without an expiry date is not expired."""
        assert self._record(expire_date=None).is_expired is False

    def test_expired_with_past_date(self):
        """A timezone-aware datetime in the past marks the record expired."""
        long_ago = datetime(2020, 1, 1, tzinfo=timezone.utc)
        assert self._record(expire_date=long_ago).is_expired is True

    def test_not_expired_with_future_date(self):
        """A timezone-aware datetime in the future leaves the record valid."""
        far_ahead = datetime(2099, 1, 1, tzinfo=timezone.utc)
        assert self._record(expire_date=far_ahead).is_expired is False

    def test_expired_with_iso_string_past(self):
        """An ISO-8601 string in the past marks the record expired."""
        assert self._record(expire_date="2020-01-01T00:00:00Z").is_expired is True

    def test_not_expired_with_iso_string_future(self):
        """An ISO-8601 string in the future leaves the record valid."""
        assert self._record(expire_date="2099-01-01T00:00:00Z").is_expired is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# lookup_api_key tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLookupApiKey:
    """Tests for ``lookup_api_key`` running against stand-in Sage modules."""

    @pytest.fixture(autouse=True)
    def mock_sage_modules(self, monkeypatch):
        """Install fake ``ahserver`` and ``sqlor`` modules for one test.

        The fakes are registered in ``sys.modules`` through ``monkeypatch`` so
        that whatever was registered under those names before the test
        (including the real packages, if already imported) is put back
        afterwards, rather than being deleted outright.

        Returns:
            A ``(mock_ctx, mock_sor)`` tuple: the async context manager
            returned by the fake ``get_sor_context`` and the object it yields
            on ``__aenter__``. ``mock_sor.R`` returns ``[]`` by default.
        """
        import sys
        import types

        # Fake ahserver package exposing ahserver.serverenv.ServerEnv.
        ahserver = types.ModuleType("ahserver")
        ahserver_serverenv = types.ModuleType("ahserver.serverenv")
        ahserver_serverenv.ServerEnv = MagicMock
        ahserver.serverenv = ahserver_serverenv

        # Fake sqlor package whose get_sor_context() returns an async context
        # manager yielding mock_sor.
        mock_sor = AsyncMock()
        mock_sor.R = AsyncMock(return_value=[])
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
        mock_ctx.__aexit__ = AsyncMock(return_value=False)

        sqlor = types.ModuleType("sqlor")
        sqlor_dbpools = types.ModuleType("sqlor.dbpools")
        sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
        sqlor.dbpools = sqlor_dbpools

        fake_modules = {
            "ahserver": ahserver,
            "ahserver.serverenv": ahserver_serverenv,
            "sqlor": sqlor,
            "sqlor.dbpools": sqlor_dbpools,
        }
        # monkeypatch records each key's previous state and undoes the change
        # at teardown, so no manual cleanup loop is needed.
        for module_name, module in fake_modules.items():
            monkeypatch.setitem(sys.modules, module_name, module)

        return mock_ctx, mock_sor

    @pytest.mark.asyncio
    async def test_lookup_found(self, mock_sage_modules):
        """A row returned by the fake ``R`` call is surfaced as a key record."""
        _, mock_sor = mock_sage_modules
        mock_rec = {
            "id": 1,
            "apikey": "test-key",
            "secret": "test-secret",
            "status": "active",
            "expire_date": "2099-01-01T00:00:00Z",
            "description": "Test key",
        }
        mock_sor.R.return_value = [mock_rec]

        result = await lookup_api_key("test-key")

        assert result is not None
        assert result.apikey == "test-key"
        assert result.secret == "test-secret"
        assert result.is_active is True

    @pytest.mark.asyncio
    async def test_lookup_not_found(self, mock_sage_modules):
        """An empty result set makes the lookup return ``None``."""
        _, mock_sor = mock_sage_modules
        mock_sor.R.return_value = []

        result = await lookup_api_key("nonexistent")

        assert result is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# authenticate_request tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestAuthenticateRequest:
    """Tests for ``authenticate_request`` driven by hand-built Starlette requests."""

    # Dotted path patched to replace the key lookup with a canned result.
    _LOOKUP = "sageapi.middleware.dapi_auth.lookup_api_key"

    def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request:
        """Build a ``Request`` from a minimal HTTP scope.

        Header names are lower-cased and both names and values are encoded to
        bytes. ``body`` is delivered as a single ``http.request`` message with
        ``more_body`` set to ``False``.
        """
        raw_headers = [(name.lower().encode(), value.encode()) for name, value in headers.items()]

        async def receive():
            return {"type": "http.request", "body": body, "more_body": False}

        scope = {
            "type": "http",
            "method": method,
            "path": path,
            "headers": raw_headers,
            "query_string": b"",
        }
        return Request(scope, receive)

    async def _auth_error(self, request: Request) -> DapiAuthError:
        """Authenticate ``request``, require a ``DapiAuthError``, and return it.

        The exception is handed back so callers inspect it after the
        ``pytest.raises`` block has exited.
        """
        with pytest.raises(DapiAuthError) as caught:
            await authenticate_request(request)
        return caught.value

    @pytest.mark.asyncio
    async def test_missing_api_key_header(self):
        """No ``X-DAPI-Key`` header -> 401 naming the missing header."""
        request = self._make_request({
            "X-DAPI-Timestamp": str(time.time()),
            "X-DAPI-Signature": "abc",
        })

        error = await self._auth_error(request)

        assert error.status_code == 401
        assert "X-DAPI-Key" in error.detail

    @pytest.mark.asyncio
    async def test_missing_timestamp_header(self):
        """No ``X-DAPI-Timestamp`` header -> 401 naming the missing header."""
        request = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Signature": "abc",
        })

        error = await self._auth_error(request)

        assert error.status_code == 401
        assert "X-DAPI-Timestamp" in error.detail

    @pytest.mark.asyncio
    async def test_missing_signature_header(self):
        """No ``X-DAPI-Signature`` header -> 401 naming the missing header."""
        request = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": str(time.time()),
        })

        error = await self._auth_error(request)

        assert error.status_code == 401
        assert "X-DAPI-Signature" in error.detail

    @pytest.mark.asyncio
    async def test_expired_timestamp(self):
        """A timestamp 600 seconds old -> 401 mentioning the timestamp."""
        stale_ts = str(time.time() - 600)
        request = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": stale_ts,
            "X-DAPI-Signature": "abc",
        })

        error = await self._auth_error(request)

        assert error.status_code == 401
        assert "timestamp" in error.detail.lower()

    @pytest.mark.asyncio
    async def test_invalid_api_key(self):
        """Lookup returning ``None`` -> 401 "Invalid API key"."""
        request = self._make_request({
            "X-DAPI-Key": "bad-key",
            "X-DAPI-Timestamp": str(time.time()),
            "X-DAPI-Signature": "abc",
        })

        with patch(self._LOOKUP, return_value=None):
            error = await self._auth_error(request)

        assert error.status_code == 401
        assert "Invalid API key" in error.detail

    @pytest.mark.asyncio
    async def test_inactive_key(self):
        """A key with status ``"inactive"`` -> 403 mentioning "inactive"."""
        request = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": str(time.time()),
            "X-DAPI-Signature": "abc",
        })
        inactive_key = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None)

        with patch(self._LOOKUP, return_value=inactive_key):
            error = await self._auth_error(request)

        assert error.status_code == 403
        assert "inactive" in error.detail.lower()

    @pytest.mark.asyncio
    async def test_expired_key(self):
        """An active key whose expiry is in the past -> 403 mentioning "expired"."""
        request = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": str(time.time()),
            "X-DAPI-Signature": "abc",
        })
        expired_key = DapiKeyRecord(
            id=1,
            apikey="test-key",
            secret="secret",
            status="active",
            expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc),
        )

        with patch(self._LOOKUP, return_value=expired_key):
            error = await self._auth_error(request)

        assert error.status_code == 403
        assert "expired" in error.detail.lower()

    @pytest.mark.asyncio
    async def test_invalid_signature(self):
        """A valid key with a bogus signature -> 401 "Invalid signature"."""
        payload = b'{"test": "data"}'
        request = self._make_request(
            {
                "X-DAPI-Key": "test-key",
                "X-DAPI-Timestamp": str(time.time()),
                "X-DAPI-Signature": "invalid-signature",
            },
            body=payload,
            method="POST",
        )
        active_key = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None)

        with patch(self._LOOKUP, return_value=active_key):
            error = await self._auth_error(request)

        assert error.status_code == 401
        assert "Invalid signature" in error.detail

    @pytest.mark.asyncio
    async def test_valid_authentication(self):
        """A correctly signed POST returns the key record from the lookup."""
        secret = "my-secret"
        payload = b'{"test": "data"}'
        ts = str(time.time())
        # Sign the same method/path/body that the request below carries.
        signature = compute_dapi_signature("POST", "/test", ts, secret, payload)

        request = self._make_request(
            {
                "X-DAPI-Key": "test-key",
                "X-DAPI-Timestamp": ts,
                "X-DAPI-Signature": signature,
            },
            body=payload,
            method="POST",
        )
        active_key = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None)

        with patch(self._LOOKUP, return_value=active_key):
            result = await authenticate_request(request)

        assert result.apikey == "test-key"
        assert result.secret == secret
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DapiAuthMiddleware tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDapiAuthMiddleware:
    """End-to-end tests for ``DapiAuthMiddleware`` mounted on a small Starlette app."""

    @pytest.fixture(autouse=True)
    def mock_sage_modules(self, monkeypatch):
        """Install fake ``ahserver`` and ``sqlor`` modules for one test.

        The fakes are registered in ``sys.modules`` through ``monkeypatch`` so
        that whatever was registered under those names before the test
        (including the real packages, if already imported) is put back
        afterwards, rather than being deleted outright.
        """
        import sys
        import types

        # Fake ahserver package exposing ahserver.serverenv.ServerEnv.
        ahserver = types.ModuleType("ahserver")
        ahserver_serverenv = types.ModuleType("ahserver.serverenv")
        ahserver_serverenv.ServerEnv = MagicMock
        ahserver.serverenv = ahserver_serverenv

        # Fake sqlor package whose get_sor_context() returns an async context
        # manager yielding mock_sor; mock_sor.R returns an empty result set.
        mock_sor = AsyncMock()
        mock_sor.R = AsyncMock(return_value=[])
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
        mock_ctx.__aexit__ = AsyncMock(return_value=False)

        sqlor = types.ModuleType("sqlor")
        sqlor_dbpools = types.ModuleType("sqlor.dbpools")
        sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
        sqlor.dbpools = sqlor_dbpools

        fake_modules = {
            "ahserver": ahserver,
            "ahserver.serverenv": ahserver_serverenv,
            "sqlor": sqlor,
            "sqlor.dbpools": sqlor_dbpools,
        }
        # monkeypatch records each key's previous state and undoes the change
        # at teardown, so no manual cleanup loop is needed.
        for module_name, module in fake_modules.items():
            monkeypatch.setitem(sys.modules, module_name, module)

    def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC):
        """Create a Starlette app with ``/protected`` and ``/health`` behind the middleware.

        Args:
            exclude_paths: Paths passed to the middleware as ``exclude_paths``;
                a falsy value falls back to ``["/health"]``.
            tolerance_sec: Value passed to the middleware as ``tolerance_sec``.

        Returns:
            The configured Starlette application.
        """
        from starlette.applications import Starlette
        from starlette.responses import PlainTextResponse
        from starlette.routing import Route

        async def protected(request: Request):
            return PlainTextResponse("OK - authenticated")

        async def health(request: Request):
            return PlainTextResponse("healthy")

        app = Starlette(
            routes=[
                Route("/protected", endpoint=protected, methods=["GET"]),
                Route("/health", endpoint=health, methods=["GET"]),
            ]
        )
        app.add_middleware(
            DapiAuthMiddleware,
            exclude_paths=exclude_paths or ["/health"],
            tolerance_sec=tolerance_sec,
        )
        return app

    def test_unauthenticated_request_rejected(self):
        """A request without DAPI headers gets a 401 JSON body with an ``error`` key."""
        client = TestClient(self._build_test_app())

        resp = client.get("/protected")

        assert resp.status_code == 401
        assert "error" in resp.json()

    def test_excluded_path_bypasses_auth(self):
        """The excluded ``/health`` path is served without any DAPI headers."""
        client = TestClient(self._build_test_app())

        resp = client.get("/health")

        assert resp.status_code == 200
        assert resp.text == "healthy"

    def test_valid_request_accepted(self):
        """A correctly signed GET reaches the protected endpoint."""
        ts = str(time.time())
        secret = "test-secret"
        # Sign the same method and path that the client requests below.
        sig = compute_dapi_signature("GET", "/protected", ts, secret)
        key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None)

        client = TestClient(self._build_test_app())

        with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
            resp = client.get(
                "/protected",
                headers={
                    "X-DAPI-Key": "valid-key",
                    "X-DAPI-Timestamp": ts,
                    "X-DAPI-Signature": sig,
                },
            )

        assert resp.status_code == 200
        assert resp.text == "OK - authenticated"
|