sageapi/tests/test_dapi_auth.py
Hermes Agent 5936a2f328 feat: implement sync engine, API handlers, DAPI auth, HTTP client
- Sync engine: BaseSync abstract class + 4 sync modules (users/pricing/uapi/llmage)
  - Checkpoint management via sync_state table
  - Batch processing with retry and exponential backoff
  - Incremental fetch from Sage DB via sqlor
  - UPSERT to local cache tables
- API handlers: balance/accounting/users/pricing/health
  - Balance: cache lookup + Sage fallback
  - Accounting: create with idempotency, query with filters/pagination
  - Users: keyword search, org filter
  - Pricing: filter by ppid/llmid/type/status
  - Health: basic + readiness checks (DB connectivity)
- DAPI auth: middleware + authenticate_request function
  - HMAC-SHA256 signature verification
  - Timestamp window validation
  - Sage downapikey table lookup
- HTTP client: SageHttpClient with aiohttp
  - Auto DAPI signature injection
  - Connection pooling, retry, timeout
- Router: 12 routes registered
- Module init: load_sageapi() wires everything to ServerEnv
2026-05-20 18:22:23 +08:00

459 lines
16 KiB
Python

"""Tests for sageapi.middleware.dapi_auth"""
import hashlib
import hmac
import time
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
from sageapi.middleware.dapi_auth import (
DEFAULT_TIMESTAMP_TOLERANCE_SEC,
DapiAuthError,
DapiAuthMiddleware,
DapiKeyRecord,
authenticate_request,
compute_dapi_signature,
lookup_api_key,
verify_timestamp,
)
def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str:
    """Build a DAPI signature for use in tests.

    Thin wrapper over ``compute_dapi_signature``: an empty ``body`` is
    forwarded as ``None`` rather than as ``b""``.
    """
    payload = body or None
    return compute_dapi_signature(method, path, timestamp, secret, payload)
# ---------------------------------------------------------------------------
# compute_dapi_signature tests
# ---------------------------------------------------------------------------
class TestComputeDapiSignature:
    """Unit tests for ``compute_dapi_signature``."""

    def test_basic_signature(self):
        signature = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret")
        assert isinstance(signature, str)
        # A SHA-256 hex digest is 64 characters long.
        assert len(signature) == 64

    def test_signature_with_body(self):
        payload = b'{"key": "value"}'
        signature = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", payload)
        assert isinstance(signature, str)
        assert len(signature) == 64

    def test_signature_deterministic(self):
        call_args = ("GET", "/path", "123.0", "secret")
        first = compute_dapi_signature(*call_args)
        second = compute_dapi_signature(*call_args)
        assert first == second

    def test_different_body_different_signature(self):
        call_args = ("POST", "/path", "123.0", "secret")
        assert compute_dapi_signature(*call_args, b"body1") != compute_dapi_signature(*call_args, b"body2")

    def test_empty_body_same_as_no_body(self):
        # Compares an omitted body argument against an explicit ``None``.
        call_args = ("GET", "/path", "123.0", "secret")
        assert compute_dapi_signature(*call_args) == compute_dapi_signature(*call_args, None)

    def test_manual_verification(self):
        """Cross-check against an independently computed HMAC-SHA256 digest."""
        method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret"
        body = b'{"message":"hello"}'
        # Canonical string: method, path, timestamp and SHA-256 of the body, newline-separated.
        canonical = "\n".join([method, path, ts, hashlib.sha256(body).hexdigest()])
        expected = hmac.new(secret.encode(), canonical.encode(), hashlib.sha256).hexdigest()
        assert compute_dapi_signature(method, path, ts, secret, body) == expected
# ---------------------------------------------------------------------------
# verify_timestamp tests
# ---------------------------------------------------------------------------
class TestVerifyTimestamp:
    """Tests for the replay-window check performed by ``verify_timestamp``.

    The "outside the window" offsets are derived from
    ``DEFAULT_TIMESTAMP_TOLERANCE_SEC`` instead of a fixed number of seconds,
    so the tests keep exercising the rejection path if the default changes.
    """

    # Extra seconds placed beyond the default tolerance for out-of-window cases.
    _MARGIN_SEC = 300

    def test_valid_timestamp(self):
        ts = str(time.time())
        assert verify_timestamp(ts) is True

    def test_expired_timestamp(self):
        ts = str(time.time() - DEFAULT_TIMESTAMP_TOLERANCE_SEC - self._MARGIN_SEC)
        assert verify_timestamp(ts) is False

    def test_future_timestamp(self):
        ts = str(time.time() + DEFAULT_TIMESTAMP_TOLERANCE_SEC + self._MARGIN_SEC)
        assert verify_timestamp(ts) is False

    def test_invalid_timestamp(self):
        assert verify_timestamp("not-a-number") is False
        assert verify_timestamp("") is False

    def test_custom_tolerance(self):
        # 10 seconds old: outside a 5 s window, inside a 15 s window.
        ts = str(time.time() - 10)
        assert verify_timestamp(ts, tolerance_sec=5) is False
        assert verify_timestamp(ts, tolerance_sec=15) is True
# ---------------------------------------------------------------------------
# DapiKeyRecord tests
# ---------------------------------------------------------------------------
class TestDapiKeyRecord:
    """Tests for the ``is_active`` and ``is_expired`` properties of ``DapiKeyRecord``."""

    @staticmethod
    def _record(status="active", expire_date=None):
        """Build a record with fixed id/apikey/secret and the given status and expiry."""
        return DapiKeyRecord(id=1, apikey="k1", secret="s1", status=status, expire_date=expire_date)

    def test_active_key(self):
        assert self._record().is_active is True

    def test_inactive_key(self):
        assert self._record(status="inactive").is_active is False

    def test_not_expired_when_no_expiry(self):
        assert self._record().is_expired is False

    def test_expired_with_past_date(self):
        record = self._record(expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc))
        assert record.is_expired is True

    def test_not_expired_with_future_date(self):
        record = self._record(expire_date=datetime(2099, 1, 1, tzinfo=timezone.utc))
        assert record.is_expired is False

    def test_expired_with_iso_string_past(self):
        # expire_date may also be given as an ISO-8601 string.
        assert self._record(expire_date="2020-01-01T00:00:00Z").is_expired is True

    def test_not_expired_with_iso_string_future(self):
        assert self._record(expire_date="2099-01-01T00:00:00Z").is_expired is False
# ---------------------------------------------------------------------------
# lookup_api_key tests
# ---------------------------------------------------------------------------
class TestLookupApiKey:
    """Tests for ``lookup_api_key`` against stubbed ``ahserver`` / ``sqlor`` modules."""

    @pytest.fixture(autouse=True)
    def mock_sage_modules(self, monkeypatch):
        """Install fake ``ahserver`` and ``sqlor`` modules for the duration of a test.

        The fakes are registered with ``monkeypatch.setitem`` so teardown
        restores whatever ``sys.modules`` held before (including a real,
        already-imported package) instead of unconditionally deleting the
        entries.

        Returns:
            ``(mock_ctx, mock_sor)``: ``mock_ctx`` is the object returned by
            the fake ``sqlor.dbpools.get_sor_context``. ``mock_sor`` is what
            ``mock_ctx.__aenter__`` resolves to; its ``R`` async mock returns
            ``[]`` unless a test overrides ``R.return_value``.
        """
        import sys
        import types

        ahserver = types.ModuleType("ahserver")
        ahserver_serverenv = types.ModuleType("ahserver.serverenv")
        ahserver_serverenv.ServerEnv = MagicMock
        ahserver.serverenv = ahserver_serverenv

        sqlor = types.ModuleType("sqlor")
        sqlor_dbpools = types.ModuleType("sqlor.dbpools")
        mock_ctx = AsyncMock()
        mock_sor = AsyncMock()
        mock_sor.R = AsyncMock(return_value=[])
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
        mock_ctx.__aexit__ = AsyncMock(return_value=False)
        sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
        sqlor.dbpools = sqlor_dbpools

        fake_modules = {
            "ahserver": ahserver,
            "ahserver.serverenv": ahserver_serverenv,
            "sqlor": sqlor,
            "sqlor.dbpools": sqlor_dbpools,
        }
        for name, module in fake_modules.items():
            monkeypatch.setitem(sys.modules, name, module)
        return mock_ctx, mock_sor

    @pytest.mark.asyncio
    async def test_lookup_found(self, mock_sage_modules):
        mock_ctx, mock_sor = mock_sage_modules
        mock_rec = {
            "id": 1,
            "apikey": "test-key",
            "secret": "test-secret",
            "status": "active",
            "expire_date": "2099-01-01T00:00:00Z",
            "description": "Test key",
        }
        mock_sor.R.return_value = [mock_rec]
        result = await lookup_api_key("test-key")
        assert result is not None
        assert result.apikey == "test-key"
        assert result.secret == "test-secret"
        assert result.is_active is True

    @pytest.mark.asyncio
    async def test_lookup_not_found(self, mock_sage_modules):
        mock_ctx, mock_sor = mock_sage_modules
        mock_sor.R.return_value = []
        result = await lookup_api_key("nonexistent")
        assert result is None
# ---------------------------------------------------------------------------
# authenticate_request tests
# ---------------------------------------------------------------------------
class TestAuthenticateRequest:
    """Tests for ``authenticate_request`` on hand-built Starlette requests.

    Tests that reach the key lookup patch ``lookup_api_key`` directly, so no
    database stub modules are installed here.
    """

    # Extra seconds placed beyond the default tolerance for the stale-timestamp case.
    _MARGIN_SEC = 300

    def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request:
        """Build a minimal HTTP ``Request`` whose receive channel yields ``body`` in a single message."""
        scope = {
            "type": "http",
            "method": method,
            "path": path,
            # ASGI headers are (lower-cased name, value) byte pairs; the spec uses latin-1.
            "headers": [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()],
            "query_string": b"",
        }

        async def receive():
            return {"type": "http.request", "body": body, "more_body": False}

        return Request(scope, receive)

    @staticmethod
    def _patch_lookup(result):
        """Patch ``lookup_api_key`` with an awaitable mock that resolves to ``result``.

        ``new_callable=AsyncMock`` is explicit so the replacement is always
        awaitable, rather than relying on ``patch``'s auto-detection of the
        target being an ``async def``.
        """
        return patch(
            "sageapi.middleware.dapi_auth.lookup_api_key",
            new_callable=AsyncMock,
            return_value=result,
        )

    @pytest.mark.asyncio
    async def test_missing_api_key_header(self):
        req = self._make_request({
            "X-DAPI-Timestamp": str(time.time()),
            "X-DAPI-Signature": "abc",
        })
        with pytest.raises(DapiAuthError) as exc:
            await authenticate_request(req)
        assert exc.value.status_code == 401
        assert "X-DAPI-Key" in exc.value.detail

    @pytest.mark.asyncio
    async def test_missing_timestamp_header(self):
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Signature": "abc",
        })
        with pytest.raises(DapiAuthError) as exc:
            await authenticate_request(req)
        assert exc.value.status_code == 401
        assert "X-DAPI-Timestamp" in exc.value.detail

    @pytest.mark.asyncio
    async def test_missing_signature_header(self):
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": str(time.time()),
        })
        with pytest.raises(DapiAuthError) as exc:
            await authenticate_request(req)
        assert exc.value.status_code == 401
        assert "X-DAPI-Signature" in exc.value.detail

    @pytest.mark.asyncio
    async def test_expired_timestamp(self):
        # Derived from the default tolerance so the timestamp is always outside the window.
        ts = str(time.time() - DEFAULT_TIMESTAMP_TOLERANCE_SEC - self._MARGIN_SEC)
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": ts,
            "X-DAPI-Signature": "abc",
        })
        with pytest.raises(DapiAuthError) as exc:
            await authenticate_request(req)
        assert exc.value.status_code == 401
        assert "timestamp" in exc.value.detail.lower()

    @pytest.mark.asyncio
    async def test_invalid_api_key(self):
        ts = str(time.time())
        req = self._make_request({
            "X-DAPI-Key": "bad-key",
            "X-DAPI-Timestamp": ts,
            "X-DAPI-Signature": "abc",
        })
        with self._patch_lookup(None):
            with pytest.raises(DapiAuthError) as exc:
                await authenticate_request(req)
        assert exc.value.status_code == 401
        assert "Invalid API key" in exc.value.detail

    @pytest.mark.asyncio
    async def test_inactive_key(self):
        ts = str(time.time())
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": ts,
            "X-DAPI-Signature": "abc",
        })
        key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None)
        with self._patch_lookup(key_rec):
            with pytest.raises(DapiAuthError) as exc:
                await authenticate_request(req)
        assert exc.value.status_code == 403
        assert "inactive" in exc.value.detail.lower()

    @pytest.mark.asyncio
    async def test_expired_key(self):
        ts = str(time.time())
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": ts,
            "X-DAPI-Signature": "abc",
        })
        key_rec = DapiKeyRecord(
            id=1, apikey="test-key", secret="secret", status="active",
            expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc),
        )
        with self._patch_lookup(key_rec):
            with pytest.raises(DapiAuthError) as exc:
                await authenticate_request(req)
        assert exc.value.status_code == 403
        assert "expired" in exc.value.detail.lower()

    @pytest.mark.asyncio
    async def test_invalid_signature(self):
        ts = str(time.time())
        body = b'{"test": "data"}'
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": ts,
            "X-DAPI-Signature": "invalid-signature",
        }, body=body, method="POST")
        key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None)
        with self._patch_lookup(key_rec):
            with pytest.raises(DapiAuthError) as exc:
                await authenticate_request(req)
        assert exc.value.status_code == 401
        assert "Invalid signature" in exc.value.detail

    @pytest.mark.asyncio
    async def test_valid_authentication(self):
        ts = str(time.time())
        secret = "my-secret"
        body = b'{"test": "data"}'
        sig = compute_dapi_signature("POST", "/test", ts, secret, body)
        req = self._make_request({
            "X-DAPI-Key": "test-key",
            "X-DAPI-Timestamp": ts,
            "X-DAPI-Signature": sig,
        }, body=body, method="POST")
        key_rec = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None)
        with self._patch_lookup(key_rec):
            result = await authenticate_request(req)
        assert result.apikey == "test-key"
        assert result.secret == secret
# ---------------------------------------------------------------------------
# DapiAuthMiddleware tests
# ---------------------------------------------------------------------------
class TestDapiAuthMiddleware:
    """End-to-end tests of ``DapiAuthMiddleware`` mounted on a small Starlette app."""

    @pytest.fixture(autouse=True)
    def mock_sage_modules(self, monkeypatch):
        """Install fake ``ahserver`` and ``sqlor`` modules for the duration of a test.

        ``monkeypatch.setitem`` restores any pre-existing ``sys.modules``
        entries on teardown instead of deleting them unconditionally, so a
        real installed package is not clobbered for later tests.
        """
        import sys
        import types

        ahserver = types.ModuleType("ahserver")
        ahserver_serverenv = types.ModuleType("ahserver.serverenv")
        ahserver_serverenv.ServerEnv = MagicMock
        ahserver.serverenv = ahserver_serverenv

        sqlor = types.ModuleType("sqlor")
        sqlor_dbpools = types.ModuleType("sqlor.dbpools")
        mock_sor = AsyncMock()
        mock_sor.R = AsyncMock(return_value=[])
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
        mock_ctx.__aexit__ = AsyncMock(return_value=False)
        sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
        sqlor.dbpools = sqlor_dbpools

        fake_modules = {
            "ahserver": ahserver,
            "ahserver.serverenv": ahserver_serverenv,
            "sqlor": sqlor,
            "sqlor.dbpools": sqlor_dbpools,
        }
        for name, module in fake_modules.items():
            monkeypatch.setitem(sys.modules, name, module)

    def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC):
        """Build a Starlette app with ``/protected`` and ``/health`` behind ``DapiAuthMiddleware``.

        Args:
            exclude_paths: Paths passed to the middleware as ``exclude_paths``.
                ``None`` means ``["/health"]``. Any other value, including an
                empty list, is passed through as-is.
            tolerance_sec: Timestamp tolerance forwarded to the middleware.
        """
        from starlette.applications import Starlette
        from starlette.responses import PlainTextResponse
        from starlette.routing import Route

        async def protected(request: Request):
            return PlainTextResponse("OK - authenticated")

        async def health(request: Request):
            return PlainTextResponse("healthy")

        app = Starlette(
            routes=[
                Route("/protected", endpoint=protected, methods=["GET"]),
                Route("/health", endpoint=health, methods=["GET"]),
            ]
        )
        app.add_middleware(
            DapiAuthMiddleware,
            # `is None` rather than `or`, so an explicit empty list is not replaced.
            exclude_paths=["/health"] if exclude_paths is None else exclude_paths,
            tolerance_sec=tolerance_sec,
        )
        return app

    def test_unauthenticated_request_rejected(self):
        app = self._build_test_app()
        client = TestClient(app)
        resp = client.get("/protected")
        assert resp.status_code == 401
        assert "error" in resp.json()

    def test_excluded_path_bypasses_auth(self):
        app = self._build_test_app()
        client = TestClient(app)
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.text == "healthy"

    def test_valid_request_accepted(self):
        ts = str(time.time())
        secret = "test-secret"
        sig = compute_dapi_signature("GET", "/protected", ts, secret)
        key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None)
        app = self._build_test_app()
        client = TestClient(app)
        # Explicit AsyncMock: the middleware awaits lookup_api_key.
        with patch(
            "sageapi.middleware.dapi_auth.lookup_api_key",
            new_callable=AsyncMock,
            return_value=key_rec,
        ):
            resp = client.get(
                "/protected",
                headers={
                    "X-DAPI-Key": "valid-key",
                    "X-DAPI-Timestamp": ts,
                    "X-DAPI-Signature": sig,
                },
            )
        assert resp.status_code == 200
        assert resp.text == "OK - authenticated"