feat: implement sync engine, API handlers, DAPI auth, HTTP client

- Sync engine: BaseSync abstract class + 4 sync modules (users/pricing/uapi/llmage)
  - Checkpoint management via sync_state table
  - Batch processing with retry and exponential backoff
  - Incremental fetch from Sage DB via sqlor
  - UPSERT to local cache tables
- API handlers: balance/accounting/users/pricing/health
  - Balance: cache lookup + Sage fallback
  - Accounting: create with idempotency, query with filters/pagination
  - Users: keyword search, org filter
  - Pricing: filter by ppid/llmid/type/status
  - Health: basic + readiness checks (DB connectivity)
- DAPI auth: middleware + authenticate_request function
  - HMAC-SHA256 signature verification
  - Timestamp window validation
  - Sage downapikey table lookup
- HTTP client: SageHttpClient with aiohttp
  - Auto DAPI signature injection
  - Connection pooling, retry, timeout
- Router: 12 routes registered
- Module init: load_sageapi() wires everything to ServerEnv
This commit is contained in:
Hermes Agent 2026-05-20 18:22:23 +08:00
parent a9ea05ff2d
commit 5936a2f328
44 changed files with 2854 additions and 716 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
__pycache__/
*.pyc

View File

@ -0,0 +1,26 @@
"""SageAPI - Sage data caching and proxy API server.
Provides cached access to Sage data (users, pricing, uapi, llmage)
with DAPI authentication and independent multi-instance deployment.
"""
__version__ = '0.1.0'
# Public API exports
from sageapi.sync import run_all_syncs, UserSync, PricingSync, UapiSync, LlmageSync
from sageapi.router import Router, setup_routes
from sageapi.cache.cache_manager import CacheManager
from sageapi.middleware.dapi_auth import authenticate_request, DapiAuthMiddleware
__all__ = [
'run_all_syncs',
'UserSync',
'PricingSync',
'UapiSync',
'LlmageSync',
'Router',
'setup_routes',
'CacheManager',
'authenticate_request',
'DapiAuthMiddleware',
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -2,7 +2,7 @@
Provides endpoints for creating and querying accounting records. Provides endpoints for creating and querying accounting records.
Writing goes directly to the accounting_records table; reads Writing goes directly to the accounting_records table; reads
are served from the same table with optional date range filtering. are served from the same table with optional filtering.
""" """
from __future__ import annotations from __future__ import annotations
@ -14,64 +14,86 @@ from typing import Any
from appPublic.log import debug, error from appPublic.log import debug, error
from sqlor.dbpools import DBPools from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
async def create_accounting_record( async def create_accounting_record(
customer_id: str, customer_id: str,
amount: float, amount: float,
record_type: str = 'charge', llmid: str = '',
description: str = '', model_name: str = '',
**extra: Any, pricing_id: str = '',
input_tokens: int | None = None,
output_tokens: int | None = None,
total_tokens: int | None = None,
quantity: float | None = None,
currency: str = 'CNY',
request_id: str = '',
transno: str = '',
) -> str: ) -> str:
"""Create a new accounting record. """Create a new accounting record with idempotency via request_id."""
Args:
customer_id: The customer identifier.
amount: The accounting amount (positive for charges, negative for credits).
record_type: Type of record (charge, credit, adjustment, etc.).
description: Optional description of the transaction.
Returns:
JSON string with success flag and the created record ID.
"""
result: dict[str, Any] = {'success': False, 'record_id': None} result: dict[str, Any] = {'success': False, 'record_id': None}
try: try:
from ahserver.serverenv import ServerEnv
env = ServerEnv() env = ServerEnv()
dbname = env.get_module_dbname('sageapi') dbname = env.get_module_dbname('sageapi')
if not dbname: if not dbname:
result['error'] = 'No database configured for sageapi module' result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)
record_id = str(uuid.uuid4()) record_id = request_id or str(uuid.uuid4())
now = time.strftime('%Y-%m-%d %H:%M:%S') now = time.strftime('%Y-%m-%d %H:%M:%S')
# Check idempotency
if request_id:
async with DBPools().sqlorContext(dbname) as sor:
existing = await sor.sqlExe(
"SELECT id FROM accounting_records WHERE request_id = ${request_id}$",
{'request_id': request_id},
)
if isinstance(existing, list) and existing:
result['success'] = True
result['record_id'] = existing[0].get('id', existing[0].get('id') if isinstance(existing[0], dict) else existing[0])
result['duplicate'] = True
return json.dumps(result, ensure_ascii=False, default=str)
sql = """ sql = """
INSERT INTO accounting_records INSERT INTO accounting_records
(id, customer_id, amount, record_type, description, created_at, extra) (id, customer_id, llmid, model_name, pricing_id,
input_tokens, output_tokens, total_tokens, quantity,
amount, currency, request_id, transno, status,
created_at, updated_at)
VALUES VALUES
(${id}$, ${customer_id}$, ${amount}$, ${record_type}$, ${description}$, ${created_at}$, ${extra}$) (${id}$, ${customer_id}$, ${llmid}$, ${model_name}$, ${pricing_id}$,
${input_tokens}$, ${output_tokens}$, ${total_tokens}$, ${quantity}$,
${amount}$, ${currency}$, ${request_id}$, ${transno}$, 'accounted',
${created_at}$, ${updated_at}$)
""" """
params = {
'id': record_id,
'customer_id': customer_id,
'llmid': llmid,
'model_name': model_name,
'pricing_id': pricing_id,
'input_tokens': input_tokens,
'output_tokens': output_tokens,
'total_tokens': total_tokens,
'quantity': quantity,
'amount': amount,
'currency': currency,
'request_id': request_id,
'transno': transno,
'created_at': now,
'updated_at': now,
}
async with DBPools().sqlorContext(dbname) as sor: async with DBPools().sqlorContext(dbname) as sor:
await sor.sqlExe(sql, { await sor.sqlExe(sql, params)
'id': record_id, result['success'] = True
'customer_id': customer_id, result['record_id'] = record_id
'amount': amount,
'record_type': record_type,
'description': description,
'created_at': now,
'extra': json.dumps(extra, ensure_ascii=False) if extra else None,
})
result['success'] = True
result['record_id'] = record_id
debug(f'Accounting record created: id={record_id}, customer={customer_id}, amount={amount}')
except Exception as e: except Exception as e:
error(f'Accounting record creation failed: {e}') error(f'create_accounting_record error: {e}')
result['error'] = str(e) result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)
@ -79,30 +101,19 @@ async def create_accounting_record(
async def query_accounting_records( async def query_accounting_records(
customer_id: str | None = None, customer_id: str | None = None,
start_date: str | None = None, date_from: str | None = None,
end_date: str | None = None, date_to: str | None = None,
limit: int = 100, llmid: str | None = None,
offset: int = 0, status: str | None = None,
page: int = 1,
page_size: int = 50,
) -> str: ) -> str:
"""Query accounting records with optional filters. """Query accounting records with filters and pagination."""
Args:
customer_id: Filter by customer ID.
start_date: Filter records from this date (inclusive, YYYY-MM-DD).
end_date: Filter records up to this date (inclusive, YYYY-MM-DD).
limit: Maximum number of records to return.
offset: Number of records to skip.
Returns:
JSON string with success flag and record data.
"""
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try: try:
from ahserver.serverenv import ServerEnv
env = ServerEnv() env = ServerEnv()
dbname = env.get_module_dbname('sageapi') dbname = env.get_module_dbname('sageapi')
if not dbname: if not dbname:
result['error'] = 'No database configured for sageapi module' result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)
@ -113,41 +124,56 @@ async def query_accounting_records(
if customer_id: if customer_id:
conditions.append('customer_id = ${customer_id}$') conditions.append('customer_id = ${customer_id}$')
params['customer_id'] = customer_id params['customer_id'] = customer_id
if start_date: if date_from:
conditions.append('created_at >= ${start_date}$') conditions.append('created_at >= ${date_from}$')
params['start_date'] = start_date params['date_from'] = date_from
if end_date: if date_to:
conditions.append('created_at <= ${end_date}$') conditions.append('created_at <= ${date_to}$')
params['end_date'] = end_date params['date_to'] = date_to
if llmid:
conditions.append('llmid = ${llmid}$')
params['llmid'] = llmid
if status:
conditions.append('status = ${status}$')
params['status'] = status
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else '' where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
# Count query # Count query
count_sql = f""" count_sql = f"SELECT COUNT(*) as cnt FROM accounting_records {where}"
SELECT COUNT(*) as cnt FROM accounting_records {where_clause} offset = (page - 1) * page_size
"""
async with DBPools().sqlorContext(dbname) as sor:
count_rows = await sor.sqlExe(count_sql, params)
total = count_rows[0]['cnt'] if count_rows else 0
result['total'] = total
if total > 0: # Data query
data_sql = f""" data_sql = f"""
SELECT id, customer_id, amount, record_type, description, created_at, extra SELECT id, customer_id, llmid, model_name, pricing_id,
FROM accounting_records input_tokens, output_tokens, total_tokens, quantity,
{where_clause} amount, currency, request_id, transno, status,
ORDER BY created_at DESC created_at, updated_at
LIMIT ${limit}$ OFFSET ${offset}$ FROM accounting_records
""" {where}
params['limit'] = limit ORDER BY created_at DESC
params['offset'] = offset LIMIT {page_size} OFFSET {offset}
rows = await sor.sqlExe(data_sql, params) """
result['data'] = [dict(r) for r in (rows or [])]
async with DBPools().sqlorContext(dbname) as sor:
count_result = await sor.sqlExe(count_sql, params)
if isinstance(count_result, list) and count_result:
result['total'] = count_result[0].get('cnt', 0)
elif isinstance(count_result, dict):
result['total'] = count_result.get('cnt', 0)
data = await sor.sqlExe(data_sql, params)
if isinstance(data, dict):
result['data'] = data.get('rows', [])
elif isinstance(data, list):
result['data'] = data
result['success'] = True result['success'] = True
result['page'] = page
result['page_size'] = page_size
except Exception as e: except Exception as e:
error(f'Accounting query failed: {e}') error(f'query_accounting_records error: {e}')
result['error'] = str(e) result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,8 @@
"""Customer balance query API handler. """Customer balance query API handler.
Provides the RESTful endpoint for querying customer account balances. Provides the RESTful endpoint for querying customer account balances.
Reads from the local customer_balance cache table. Reads from the local customer_balance cache table, with fallback to
real-time query from Sage acc_balance table.
""" """
from __future__ import annotations from __future__ import annotations
@ -10,15 +11,18 @@ import json
from typing import Any from typing import Any
from appPublic.log import debug, error from appPublic.log import debug, error
from sqlor.dbpools import DBPools from sqlor.dbpools import DBPools, get_sor_context
from ahserver.serverenv import ServerEnv
async def get_customer_balance(customer_id: str | None = None) -> str: async def get_customer_balance(customer_id: str | None = None) -> str:
"""Query customer balance. """Query customer balance.
First checks the local customer_balance cache. If not found,
falls back to real-time query from Sage acc_balance table.
Args: Args:
customer_id: Optional customer ID filter. If not provided, customer_id: Optional customer ID filter.
returns all customer balances.
Returns: Returns:
JSON string with success flag and balance data. JSON string with success flag and balance data.
@ -26,42 +30,92 @@ async def get_customer_balance(customer_id: str | None = None) -> str:
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try: try:
from ahserver.serverenv import ServerEnv
env = ServerEnv() env = ServerEnv()
dbname = env.get_module_dbname('sageapi') cache_dbname = env.get_module_dbname('sageapi')
if not cache_dbname:
if not dbname:
result['error'] = 'No database configured for sageapi module' result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)
params: dict[str, Any] = {} async with DBPools().sqlorContext(cache_dbname) as sor:
where_clause = '' params: dict[str, Any] = {}
if customer_id: where = ''
where_clause = 'WHERE customer_id = ${customer_id}$' if customer_id:
params['customer_id'] = customer_id where = 'WHERE id = ${customer_id}$'
params['customer_id'] = customer_id
sql = f""" sql = f"""
SELECT customer_id, balance, currency, updated_at SELECT id, balance, currency, credit_limit,
FROM customer_balance last_recharge, last_consumption,
{where_clause} status, cached_at
ORDER BY customer_id FROM customer_balance
""" {where}
ORDER BY id
async with DBPools().sqlorContext(dbname) as sor: """
data = await sor.sqlExe(sql, params) data = await sor.sqlExe(sql, params)
if isinstance(data, dict): if isinstance(data, dict):
result['total'] = data.get('total', 0) result['total'] = data.get('total', len(data.get('rows', [])))
result['data'] = [dict(r) for r in data.get('rows', [])] result['data'] = data.get('rows', [])
else: elif isinstance(data, list):
rows = [dict(r) for r in (data or [])] result['total'] = len(data)
result['data'] = rows result['data'] = data
result['total'] = len(rows)
# If cache miss for specific customer, try real-time from Sage
if customer_id and not result['data']:
result['data'] = await _query_sage_balance(env, customer_id)
result['total'] = len(result['data'])
result['success'] = True result['success'] = True
debug(f'Balance query: returned {result["total"]} records')
except Exception as e: except Exception as e:
error(f'Balance query failed: {e}') error(f'get_customer_balance error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
async def _query_sage_balance(env: ServerEnv, customer_id: str) -> list[dict]:
"""Fallback: query Sage acc_balance table directly."""
try:
async with get_sor_context(env, 'sage') as sor:
sql = """
SELECT customer_id, balance, currency, status, updated_at
FROM acc_balance
WHERE customer_id = ${customer_id}$
"""
rows = await sor.sqlExe(sql, {'customer_id': customer_id})
if isinstance(rows, list):
return rows
elif isinstance(rows, dict):
return rows.get('rows', [])
except Exception as e:
error(f'_query_sage_balance error: {e}')
return []
async def update_customer_balance(customer_id: str, balance: float) -> str:
"""Update customer balance in cache (called by sync or accounting)."""
result: dict[str, Any] = {'success': False}
try:
env = ServerEnv()
cache_dbname = env.get_module_dbname('sageapi')
sql = """
INSERT INTO customer_balance (id, balance, cached_at)
VALUES (${customer_id}$, ${balance}$, NOW())
ON DUPLICATE KEY UPDATE
balance = ${balance}$,
cached_at = NOW()
"""
async with DBPools().sqlorContext(cache_dbname) as sor:
await sor.sqlExe(sql, {
'customer_id': customer_id,
'balance': balance,
})
result['success'] = True
except Exception as e:
error(f'update_customer_balance error: {e}')
result['error'] = str(e) result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,6 @@
"""Health check API handler. """Health check API handler.
Provides a simple endpoint for load balancer health checks and Provides endpoints for service health and readiness checks.
system status monitoring. No authentication required.
""" """
from __future__ import annotations from __future__ import annotations
@ -10,51 +9,101 @@ import json
import time import time
from typing import Any from typing import Any
from appPublic.log import debug from appPublic.log import debug, error
from sqlor.dbpools import DBPools from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
_START_TIME = time.time()
async def health_check() -> str: async def health_check() -> str:
"""Health check endpoint. """Basic health check - returns service status."""
uptime = time.time() - _START_TIME
Returns system status including database connectivity, result = {
cache stats, and uptime information.
Returns:
JSON string with health status.
"""
result: dict[str, Any] = {
'status': 'ok', 'status': 'ok',
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), 'service': 'sageapi',
'database': 'unknown', 'uptime_seconds': round(uptime, 1),
'cache': {}, 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S%z'),
}
return json.dumps(result, ensure_ascii=False, default=str)
async def readiness_check() -> str:
"""Readiness check - verifies database connectivity."""
result: dict[str, Any] = {
'status': 'unknown',
'checks': {},
} }
# Check database connectivity # Check cache database connection
try: try:
from ahserver.serverenv import ServerEnv
env = ServerEnv() env = ServerEnv()
dbname = env.get_module_dbname('sageapi') dbname = env.get_module_dbname('sageapi')
if not dbname:
if dbname: result['checks']['cache_db'] = {
async with DBPools().sqlorContext(dbname) as sor: 'status': 'fail',
await sor.sqlExe('SELECT 1') 'error': 'No database configured for sageapi module',
result['database'] = 'connected' }
else: else:
result['database'] = 'not_configured' async with DBPools().sqlorContext(dbname) as sor:
result['status'] = 'degraded' rows = await sor.sqlExe('SELECT 1 as ping')
result['checks']['cache_db'] = {
'status': 'ok',
'dbname': dbname,
}
except Exception as e: except Exception as e:
result['database'] = f'error: {str(e)}' error(f'readiness_check cache_db error: {e}')
result['status'] = 'unhealthy' result['checks']['cache_db'] = {
'status': 'fail',
'error': str(e),
}
# Cache stats # Check Sage database connection
try: try:
from ..cache.cache_manager import _get_cache_manager from sqlor.dbpools import get_sor_context
cm = _get_cache_manager() async with get_sor_context(env, 'sage') as sor:
result['cache'] = cm.stats() rows = await sor.sqlExe('SELECT 1 as ping')
except Exception: result['checks']['sage_db'] = {'status': 'ok'}
result['cache'] = {'error': 'cache not initialized'} except Exception as e:
error(f'readiness_check sage_db error: {e}')
result['checks']['sage_db'] = {
'status': 'fail',
'error': str(e),
}
# Check sync state
try:
async with DBPools().sqlorContext(dbname) as sor:
sql = """
SELECT entity_type, sync_status, last_sync_time
FROM sync_state
ORDER BY last_sync_time DESC
"""
rows = await sor.sqlExe(sql)
if isinstance(rows, list):
result['checks']['sync_status'] = {
'status': 'ok',
'entities': [
{
'entity_type': r.get('entity_type', ''),
'sync_status': r.get('sync_status', ''),
'last_sync_time': str(r.get('last_sync_time', '')),
}
for r in rows
],
}
except Exception as e:
result['checks']['sync_status'] = {
'status': 'fail',
'error': str(e),
}
# Overall status
all_ok = all(
check.get('status') == 'ok'
for check in result['checks'].values()
)
result['status'] = 'ready' if all_ok else 'degraded'
debug(f'Health check: status={result["status"]}')
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,6 @@
"""Pricing query API handler. """Pricing API handler.
Provides the RESTful endpoint for querying pricing information. Provides endpoint for querying cached pricing data.
Reads from the local pricing_cache table synced from Sage.
""" """
from __future__ import annotations from __future__ import annotations
@ -11,72 +10,110 @@ from typing import Any
from appPublic.log import debug, error from appPublic.log import debug, error
from sqlor.dbpools import DBPools from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
async def query_pricing( async def query_pricing(
program_id: str | None = None, ppid: str | None = None,
model: str | None = None, llmid: str | None = None,
limit: int = 200, pricing_type: str | None = None,
offset: int = 0, status: str | None = None,
page: int = 1,
page_size: int = 50,
) -> str: ) -> str:
"""Query pricing information from the local cache. """Query pricing from cache with filters."""
Args:
program_id: Filter by pricing program ID.
model: Filter by model name (partial match).
limit: Maximum number of records to return.
offset: Number of records to skip.
Returns:
JSON string with success flag and pricing data.
"""
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try: try:
from ahserver.serverenv import ServerEnv
env = ServerEnv() env = ServerEnv()
dbname = env.get_module_dbname('sageapi') dbname = env.get_module_dbname('sageapi')
if not dbname: if not dbname:
result['error'] = 'No database configured for sageapi module' result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)
conditions = [] conditions = []
params: dict[str, Any] = {'limit': limit, 'offset': offset} params: dict[str, Any] = {}
if program_id: if ppid:
conditions.append('program_id = ${program_id}$') conditions.append('id = ${ppid}$')
params['program_id'] = program_id params['ppid'] = ppid
if model: if llmid:
conditions.append('model LIKE ${model}$') conditions.append('llmid = ${llmid}$')
params['model'] = f'%{model}%' params['llmid'] = llmid
if pricing_type:
conditions.append('pricing_type = ${pricing_type}$')
params['pricing_type'] = pricing_type
if status:
conditions.append('status = ${status}$')
params['status'] = status
else:
conditions.append("status = 'active'")
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else '' where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
count_sql = f"SELECT COUNT(*) as cnt FROM pricing_cache {where}"
offset = (page - 1) * page_size
data_sql = f"""
SELECT id, llmid, model_name, pricing_type,
input_price, output_price, unit_price,
currency, status, effective_from, effective_to,
cached_at
FROM pricing_cache
{where}
ORDER BY model_name
LIMIT {page_size} OFFSET {offset}
"""
# Count query
count_sql = f'SELECT COUNT(*) as cnt FROM pricing_cache {where_clause}'
async with DBPools().sqlorContext(dbname) as sor: async with DBPools().sqlorContext(dbname) as sor:
count_rows = await sor.sqlExe(count_sql, params) count_result = await sor.sqlExe(count_sql, params)
total = count_rows[0]['cnt'] if count_rows else 0 if isinstance(count_result, list) and count_result:
result['total'] = total result['total'] = count_result[0].get('cnt', 0)
elif isinstance(count_result, dict):
result['total'] = count_result.get('cnt', 0)
if total > 0: data = await sor.sqlExe(data_sql, params)
data_sql = f""" if isinstance(data, dict):
SELECT program_id, model, input_price, output_price, result['data'] = data.get('rows', [])
unit, currency, updated_at elif isinstance(data, list):
FROM pricing_cache result['data'] = data
{where_clause}
ORDER BY program_id, model
LIMIT ${limit}$ OFFSET ${offset}$
"""
rows = await sor.sqlExe(data_sql, params)
result['data'] = [dict(r) for r in (rows or [])]
result['success'] = True result['success'] = True
debug(f'Pricing query: returned {result["total"]} records')
except Exception as e: except Exception as e:
error(f'Pricing query failed: {e}') error(f'query_pricing error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
async def get_pricing_by_llmid(llmid: str) -> str:
"""Get all active pricing entries for a specific model."""
result: dict[str, Any] = {'success': False, 'data': []}
try:
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
sql = """
SELECT id, llmid, model_name, pricing_type,
input_price, output_price, unit_price,
currency, status, cached_at
FROM pricing_cache
WHERE llmid = ${llmid}$ AND status = 'active'
ORDER BY pricing_type
"""
async with DBPools().sqlorContext(dbname) as sor:
data = await sor.sqlExe(sql, {'llmid': llmid})
if isinstance(data, list):
result['data'] = data
result['success'] = True
elif isinstance(data, dict):
result['data'] = data.get('rows', [])
result['success'] = True
except Exception as e:
error(f'get_pricing_by_llmid error: {e}')
result['error'] = str(e) result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,6 @@
"""User query API handler. """Users API handler.
Provides the RESTful endpoint for querying user information. Provides endpoint for querying cached user data.
Reads from the local users_cache table synced from Sage.
""" """
from __future__ import annotations from __future__ import annotations
@ -11,73 +10,91 @@ from typing import Any
from appPublic.log import debug, error from appPublic.log import debug, error
from sqlor.dbpools import DBPools from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
async def query_users( async def query_users(keyword: str | None = None, orgid: str | None = None, page: int = 1, page_size: int = 50) -> str:
user_id: str | None = None, """Query users from cache with keyword search."""
keyword: str | None = None,
limit: int = 100,
offset: int = 0,
) -> str:
"""Query user information from the local cache.
Args:
user_id: Filter by specific user ID.
keyword: Search keyword (matches username, email, or phone).
limit: Maximum number of records to return.
offset: Number of records to skip.
Returns:
JSON string with success flag and user data.
"""
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try: try:
from ahserver.serverenv import ServerEnv
env = ServerEnv() env = ServerEnv()
dbname = env.get_module_dbname('sageapi') dbname = env.get_module_dbname('sageapi')
if not dbname: if not dbname:
result['error'] = 'No database configured for sageapi module' result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)
conditions = [] conditions = []
params: dict[str, Any] = {'limit': limit, 'offset': offset} params: dict[str, Any] = {}
if user_id:
conditions.append('user_id = ${user_id}$')
params['user_id'] = user_id
if keyword: if keyword:
conditions.append( conditions.append("username LIKE ${keyword}$")
'(username LIKE ${keyword}$ OR email LIKE ${keyword}$ OR phone LIKE ${keyword}$)'
)
params['keyword'] = f'%{keyword}%' params['keyword'] = f'%{keyword}%'
if orgid:
conditions.append('orgid = ${orgid}$')
params['orgid'] = orgid
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else '' where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
count_sql = f"SELECT COUNT(*) as cnt FROM users_cache {where}"
offset = (page - 1) * page_size
data_sql = f"""
SELECT id, username, orgid, orgname, email, phone,
status, created_at, updated_at, cached_at
FROM users_cache
{where}
ORDER BY username
LIMIT {page_size} OFFSET {offset}
"""
# Count query
count_sql = f'SELECT COUNT(*) as cnt FROM users_cache {where_clause}'
async with DBPools().sqlorContext(dbname) as sor: async with DBPools().sqlorContext(dbname) as sor:
count_rows = await sor.sqlExe(count_sql, params) count_result = await sor.sqlExe(count_sql, params)
total = count_rows[0]['cnt'] if count_rows else 0 if isinstance(count_result, list) and count_result:
result['total'] = total result['total'] = count_result[0].get('cnt', 0)
elif isinstance(count_result, dict):
result['total'] = count_result.get('cnt', 0)
if total > 0: data = await sor.sqlExe(data_sql, params)
data_sql = f""" if isinstance(data, dict):
SELECT user_id, username, email, phone, status, updated_at result['data'] = data.get('rows', [])
FROM users_cache elif isinstance(data, list):
{where_clause} result['data'] = data
ORDER BY user_id
LIMIT ${limit}$ OFFSET ${offset}$
"""
rows = await sor.sqlExe(data_sql, params)
result['data'] = [dict(r) for r in (rows or [])]
result['success'] = True result['success'] = True
debug(f'User query: returned {result["total"]} records')
except Exception as e: except Exception as e:
error(f'User query failed: {e}') error(f'query_users error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
async def get_user_by_id(user_id: str) -> str:
"""Get a single user by ID."""
result: dict[str, Any] = {'success': False, 'data': None}
try:
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
sql = """
SELECT id, username, orgid, orgname, email, phone,
status, created_at, updated_at, cached_at
FROM users_cache
WHERE id = ${user_id}$
"""
async with DBPools().sqlorContext(dbname) as sor:
data = await sor.sqlExe(sql, {'user_id': user_id})
if isinstance(data, list) and data:
result['data'] = data[0]
result['success'] = True
elif isinstance(data, dict) and data.get('rows'):
result['data'] = data['rows'][0]
result['success'] = True
except Exception as e:
error(f'get_user_by_id error: {e}')
result['error'] = str(e) result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str) return json.dumps(result, ensure_ascii=False, default=str)

Binary file not shown.

Binary file not shown.

View File

@ -2,40 +2,69 @@
Registers all public functions to ServerEnv so they are accessible Registers all public functions to ServerEnv so they are accessible
from dspy scripts and other modules via the global environment. from dspy scripts and other modules via the global environment.
Also sets up route registration and database event bindings.
""" """
from appPublic.log import debug from appPublic.log import debug, info
from sqlor.dbpools import DBPools from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv from ahserver.serverenv import ServerEnv
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Auth # Auth
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from .auth.dapi_auth import dapi_auth_middleware from .middleware.dapi_auth import authenticate_request, DapiAuthMiddleware
from .auth.uapi_sign import uapi_sign_verify
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Sync # Sync
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from .sync.base_sync import BaseSync from .sync.base_sync import BaseSync, run_all_syncs
from .sync.user_sync import sync_users from .sync.user_sync import UserSync
from .sync.pricing_sync import sync_pricing from .sync.pricing_sync import PricingSync
from .sync.uapi_sync import sync_uapi from .sync.uapi_sync import UapiSync
from .sync.llmage_sync import sync_llmage from .sync.llmage_sync import LlmageSync
# Module-level convenience functions for sync
async def sync_users() -> str:
"""Convenience: run user sync."""
syncer = UserSync()
return await syncer.sync()
async def sync_pricing() -> str:
"""Convenience: run pricing sync."""
syncer = PricingSync()
return await syncer.sync()
async def sync_uapi() -> str:
"""Convenience: run uapi sync."""
syncer = UapiSync()
return await syncer.sync()
async def sync_llmage() -> str:
"""Convenience: run llmage sync."""
syncer = LlmageSync()
return await syncer.sync()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Cache # Cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from .cache.cache_manager import CacheManager from .cache.cache_manager import CacheManager
# Global cache instance (per-process)
_cache_manager = CacheManager(max_entries=10000, default_ttl=300)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# API # API
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from .api.balance import get_customer_balance from .api.balance import get_customer_balance, update_customer_balance
from .api.accounting import create_accounting_record, query_accounting_records from .api.accounting import create_accounting_record, query_accounting_records
from .api.users import query_users from .api.users import query_users, get_user_by_id
from .api.pricing import query_pricing from .api.pricing import query_pricing, get_pricing_by_llmid
from .api.health import health_check from .api.health import health_check, readiness_check
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
from .router import Router, setup_routes
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Utils # Utils
@ -45,66 +74,79 @@ from .utils.crypto import encrypt_payload, decrypt_payload
def _bind_sageapi_events(dbpools: DBPools, dbname: str) -> None: def _bind_sageapi_events(dbpools: DBPools, dbname: str) -> None:
"""Bind database events to SageAPI cache invalidation handlers. """Bind database events to SageAPI cache invalidation handlers."""
When sync state or accounting records change in the database,
the corresponding cache entries are invalidated automatically.
"""
bindings = [ bindings = [
# sync_state table: clear sync-related caches on change (f'{dbname}:sync_state:c:after', _cache_manager.invalidate_sync_state),
(f'{dbname}:sync_state:c:after', CacheManager.invalidate_sync_state), (f'{dbname}:sync_state:u:after', _cache_manager.invalidate_sync_state),
(f'{dbname}:sync_state:u:after', CacheManager.invalidate_sync_state), (f'{dbname}:sync_state:d:after', _cache_manager.invalidate_sync_state),
(f'{dbname}:sync_state:d:after', CacheManager.invalidate_sync_state), (f'{dbname}:accounting_records:c:after', _cache_manager.invalidate_accounting),
# accounting_records: clear accounting cache on change (f'{dbname}:accounting_records:u:after', _cache_manager.invalidate_accounting),
(f'{dbname}:accounting_records:c:after', CacheManager.invalidate_accounting), (f'{dbname}:accounting_records:d:after', _cache_manager.invalidate_accounting),
(f'{dbname}:accounting_records:u:after', CacheManager.invalidate_accounting),
(f'{dbname}:accounting_records:d:after', CacheManager.invalidate_accounting),
] ]
for event_name, handler in bindings: for event_name, handler in bindings:
dbpools.bind(event_name, handler) try:
debug(f'SageAPI event bound: {event_name}') dbpools.bind(event_name, handler)
debug(f'SageAPI event bound: {event_name}')
except Exception as e:
debug(f'SageAPI event bind skipped: {event_name} ({e})')
def load_sageapi() -> None: def load_sageapi() -> None:
"""Register all SageAPI functions into ServerEnv. """Register all SageAPI functions into ServerEnv.
Called by the Sage server during module loading phase. Called by the Sage server during module loading phase.
All registered functions become available as globals in dspy scripts.
""" """
env = ServerEnv() env = ServerEnv()
# Auth # Auth
env.dapi_auth_middleware = dapi_auth_middleware env.authenticate_request = authenticate_request
env.uapi_sign_verify = uapi_sign_verify env.DapiAuthMiddleware = DapiAuthMiddleware
# Sync # Sync
env.sync_users = sync_users env.sync_users = sync_users
env.sync_pricing = sync_pricing env.sync_pricing = sync_pricing
env.sync_uapi = sync_uapi env.sync_uapi = sync_uapi
env.sync_llmage = sync_llmage env.sync_llmage = sync_llmage
env.run_all_syncs = run_all_syncs
env.BaseSync = BaseSync env.BaseSync = BaseSync
env.UserSync = UserSync
env.PricingSync = PricingSync
env.UapiSync = UapiSync
env.LlmageSync = LlmageSync
# Cache # Cache
env.cache_manager = CacheManager() env.cache_manager = _cache_manager
# API # API
env.get_customer_balance = get_customer_balance env.get_customer_balance = get_customer_balance
env.update_customer_balance = update_customer_balance
env.create_accounting_record = create_accounting_record env.create_accounting_record = create_accounting_record
env.query_accounting_records = query_accounting_records env.query_accounting_records = query_accounting_records
env.query_users = query_users env.query_users = query_users
env.get_user_by_id = get_user_by_id
env.query_pricing = query_pricing env.query_pricing = query_pricing
env.get_pricing_by_llmid = get_pricing_by_llmid
env.health_check = health_check env.health_check = health_check
env.readiness_check = readiness_check
# Router
router = Router()
setup_routes(router)
env.sageapi_router = router
info(f'SageAPI: {len(router.get_routes())} routes registered')
# Utils # Utils
env.SageHttpClient = SageHttpClient env.SageHttpClient = SageHttpClient
env.encrypt_payload = encrypt_payload env.encrypt_payload = encrypt_payload
env.decrypt_payload = decrypt_payload env.decrypt_payload = decrypt_payload
# Bind database events for automatic cache invalidation # Bind database events
dbpools = DBPools() dbpools = DBPools()
dbname = env.get_module_dbname('sageapi') dbname = env.get_module_dbname('sageapi')
if dbname: if dbname:
_bind_sageapi_events(dbpools, dbname) _bind_sageapi_events(dbpools, dbname)
debug(f'SageAPI event listeners bound for database: {dbname}') info(f'SageAPI: event listeners bound for database: {dbname}')
else: else:
debug('SageAPI event listeners skipped: no database configured for sageapi module') debug('SageAPI: event listeners skipped (no database configured)')
info('SageAPI module loaded successfully')

View File

@ -0,0 +1 @@
# Middleware package

View File

@ -0,0 +1,254 @@
"""
DAPI Authentication Middleware for sageapi.
Authenticates incoming requests using DAPI signature headers by querying
the Sage downapikey table.
Usage with FastAPI / Starlette:
from sageapi.middleware.dapi_auth import DapiAuthMiddleware
app.add_middleware(DapiAuthMiddleware)
# Or for specific routes, use the get_dapi_key() dependency.
"""
import hashlib
import hmac
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional
from starlette.datastructures import Headers
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse
# Headers expected from the client
HEADER_API_KEY = "X-DAPI-Key"
HEADER_TIMESTAMP = "X-DAPI-Timestamp"
HEADER_SIGNATURE = "X-DAPI-Signature"
# Default timestamp tolerance: 5 minutes
DEFAULT_TIMESTAMP_TOLERANCE_SEC = 300
@dataclass
class DapiKeyRecord:
"""Represents a record from the downapikey table."""
id: Any
apikey: str
secret: str
status: str
expire_date: Any # datetime or string
description: str = ""
@property
def is_active(self) -> bool:
return self.status == "active"
@property
def is_expired(self) -> bool:
"""Check if the key has expired based on expire_date."""
from datetime import datetime, timezone
if self.expire_date is None:
return False
# Handle both datetime objects and ISO string
if isinstance(self.expire_date, str):
try:
expire_dt = datetime.fromisoformat(self.expire_date.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return False
else:
expire_dt = self.expire_date
# Ensure expire_dt is timezone-aware (UTC)
if expire_dt.tzinfo is None:
expire_dt = expire_dt.replace(tzinfo=timezone.utc)
now = datetime.now(timezone.utc)
return now > expire_dt
class DapiAuthError(Exception):
"""Raised when DAPI authentication fails."""
def __init__(self, status_code: int, detail: str):
self.status_code = status_code
self.detail = detail
super().__init__(detail)
def compute_dapi_signature(method: str, path: str, timestamp: str, secret: str, body: Optional[bytes] = None) -> str:
"""
Compute the DAPI HMAC-SHA256 signature.
The signed string is: "{method}\n{path}\n{timestamp}\n{body_hash}"
where body_hash is SHA-256 hex digest of the request body (or empty string if no body).
"""
if body:
body_hash = hashlib.sha256(body).hexdigest()
else:
body_hash = ""
string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}"
signature = hmac.new(
secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
hashlib.sha256,
).hexdigest()
return signature
def verify_timestamp(timestamp_str: str, tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC) -> bool:
"""Verify that the timestamp is within the allowed window."""
try:
ts = float(timestamp_str)
except (ValueError, TypeError):
return False
now = time.time()
return abs(now - ts) <= tolerance_sec
async def lookup_api_key(api_key: str) -> Optional[DapiKeyRecord]:
"""
Look up an API key in the Sage downapikey table.
Returns a DapiKeyRecord if found, None otherwise.
"""
from ahserver.serverenv import ServerEnv
from sqlor.dbpools import get_sor_context
env = ServerEnv()
async with get_sor_context(env, "dapi") as sor:
recs = await sor.R("downapikey", {"apikey": api_key})
if not recs:
return None
rec = recs[0] # Take the first match
return DapiKeyRecord(
id=rec.get("id"),
apikey=rec.get("apikey", ""),
secret=rec.get("secret", ""),
status=rec.get("status", "inactive"),
expire_date=rec.get("expire_date"),
description=rec.get("description", ""),
)
async def authenticate_request(
request: Request,
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
) -> DapiKeyRecord:
"""
Authenticate a request using DAPI headers.
Returns the DapiKeyRecord on success.
Raises DapiAuthError on failure.
"""
headers = request.headers
# 1. Check required headers
api_key = headers.get(HEADER_API_KEY)
if not api_key:
raise DapiAuthError(401, f"Missing header: {HEADER_API_KEY}")
timestamp_str = headers.get(HEADER_TIMESTAMP)
if not timestamp_str:
raise DapiAuthError(401, f"Missing header: {HEADER_TIMESTAMP}")
signature = headers.get(HEADER_SIGNATURE)
if not signature:
raise DapiAuthError(401, f"Missing header: {HEADER_SIGNATURE}")
# 2. Validate timestamp window
if not verify_timestamp(timestamp_str, tolerance_sec):
raise DapiAuthError(401, "Request timestamp is outside the allowed window")
# 3. Look up the API key in the database
key_record = await lookup_api_key(api_key)
if key_record is None:
raise DapiAuthError(401, "Invalid API key")
# 4. Check key status
if not key_record.is_active:
raise DapiAuthError(403, "API key is inactive")
# 5. Check expiration
if key_record.is_expired:
raise DapiAuthError(403, "API key has expired")
# 6. Read request body for signature verification
body = await request.body()
# 7. Compute expected signature and compare
expected_signature = compute_dapi_signature(
method=request.method,
path=request.url.path,
timestamp=timestamp_str,
secret=key_record.secret,
body=body if body else None,
)
if not hmac.compare_digest(signature, expected_signature):
raise DapiAuthError(401, "Invalid signature")
return key_record
class DapiAuthMiddleware(BaseHTTPMiddleware):
"""
Starlette/FastAPI middleware that enforces DAPI authentication on all requests.
Attributes:
exclude_paths: List of path prefixes to skip authentication (e.g., ['/health', '/docs']).
tolerance_sec: Timestamp tolerance in seconds (default: 300).
"""
def __init__(
self,
app: Any,
exclude_paths: Optional[list[str]] = None,
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
):
super().__init__(app)
self.exclude_paths = exclude_paths or []
self.tolerance_sec = tolerance_sec
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> JSONResponse:
# Skip excluded paths
for prefix in self.exclude_paths:
if request.url.path.startswith(prefix):
return await call_next(request)
try:
key_record = await authenticate_request(request, self.tolerance_sec)
# Attach the key record to request state for downstream use
request.state.dapi_key = key_record
request.state.dapi_key_id = key_record.id
except DapiAuthError as e:
return JSONResponse(
status_code=e.status_code,
content={"error": e.detail},
)
return await call_next(request)
def requires_dapi_auth(
request: Request,
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
) -> Awaitable[DapiKeyRecord]:
"""
Dependency function for FastAPI route-level DAPI authentication.
Usage:
@app.get("/protected")
async def protected_route(key: DapiKeyRecord = Depends(requires_dapi_auth)):
...
"""
return authenticate_request(request, tolerance_sec)

View File

@ -29,10 +29,10 @@ class Router:
Args: Args:
method: HTTP method (GET, POST, PUT, DELETE). method: HTTP method (GET, POST, PUT, DELETE).
path: URL path pattern, e.g. '/api/v1/balance'. path: URL path pattern.
handler: Callable that handles the request. handler: Callable that handles the request.
auth: Authentication method ('dapi', 'uapi', 'none'). auth: Authentication method ('dapi', 'uapi', 'none').
description: Human-readable description of the endpoint. description: Human-readable description.
""" """
self._routes.append({ self._routes.append({
'method': method.upper(), 'method': method.upper(),
@ -55,64 +55,75 @@ class Router:
return None return None
# Global router instance def setup_routes(router: Router) -> None:
router = Router() """Register all SageAPI routes.
Health endpoints (no auth):
GET /api/v1/health
GET /api/v1/health/ready
def register_routes() -> None: Balance endpoints (dapi auth):
"""Register all SageAPI API routes. GET /api/v1/balance
POST /api/v1/balance/update
Called during module initialization to populate the router Accounting endpoints (dapi auth):
with all available endpoints. POST /api/v1/accounting
GET /api/v1/accounting
Users endpoints (dapi auth):
GET /api/v1/users
GET /api/v1/users/{user_id}
Pricing endpoints (dapi auth):
GET /api/v1/pricing
GET /api/v1/pricing/model/{llmid}
""" """
from .api.health import health_check # Health (no auth)
from .api.balance import get_customer_balance from sageapi.api.health import health_check, readiness_check
from .api.accounting import create_accounting_record, query_accounting_records router.register('GET', '/api/v1/health', health_check, auth='none', description='Health check')
from .api.users import query_users router.register('GET', '/api/v1/health/ready', readiness_check, auth='none', description='Readiness check')
from .api.pricing import query_pricing
# Health check (no auth required) # Balance
router.register( from sageapi.api.balance import get_customer_balance, update_customer_balance
'GET', '/api/v1/health', router.register('GET', '/api/v1/balance', get_customer_balance, auth='dapi', description='Query customer balance')
handler=health_check, router.register('POST', '/api/v1/balance/update', update_customer_balance, auth='dapi', description='Update customer balance')
auth='none',
description='Health check endpoint',
)
# Customer balance
router.register(
'GET', '/api/v1/balance',
handler=get_customer_balance,
auth='dapi',
description='Query customer balance',
)
# Accounting # Accounting
router.register( from sageapi.api.accounting import create_accounting_record, query_accounting_records
'POST', '/api/v1/accounting', router.register('POST', '/api/v1/accounting', create_accounting_record, auth='dapi', description='Create accounting record')
handler=create_accounting_record, router.register('GET', '/api/v1/accounting', query_accounting_records, auth='dapi', description='Query accounting records')
auth='dapi',
description='Create an accounting record',
)
router.register(
'GET', '/api/v1/accounting',
handler=query_accounting_records,
auth='dapi',
description='Query accounting records',
)
# Users # Users
router.register( from sageapi.api.users import query_users, get_user_by_id
'GET', '/api/v1/users', router.register('GET', '/api/v1/users', query_users, auth='dapi', description='Query users')
handler=query_users, router.register('GET', '/api/v1/users/detail', get_user_by_id, auth='dapi', description='Get user by ID')
auth='dapi',
description='Query user information',
)
# Pricing # Pricing
router.register( from sageapi.api.pricing import query_pricing, get_pricing_by_llmid
'GET', '/api/v1/pricing', router.register('GET', '/api/v1/pricing', query_pricing, auth='dapi', description='Query pricing')
handler=query_pricing, router.register('GET', '/api/v1/pricing/model', get_pricing_by_llmid, auth='dapi', description='Get pricing by model ID')
auth='dapi',
description='Query pricing information', # Admin (dapi auth with admin role)
) from sageapi.sync.base_sync import run_all_syncs
router.register('POST', '/api/v1/admin/sync', run_all_syncs, auth='dapi', description='Trigger full sync')
router.register('GET', '/api/v1/admin/sync/status', _sync_status, auth='dapi', description='Sync status')
async def _sync_status() -> str:
"""Return current sync status for all entities."""
import json
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
result = {'success': False, 'data': []}
try:
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
sql = "SELECT entity_type, sync_status, last_sync_time, error_msg FROM sync_state ORDER BY entity_type"
async with DBPools().sqlorContext(dbname) as sor:
rows = await sor.sqlExe(sql)
result['data'] = rows if isinstance(rows, list) else rows.get('rows', [])
result['success'] = True
except Exception as e:
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -0,0 +1,16 @@
"""SageAPI sync engine package."""
from .base_sync import BaseSync, run_all_syncs
from .user_sync import UserSync
from .pricing_sync import PricingSync
from .uapi_sync import UapiSync
from .llmage_sync import LlmageSync
__all__ = [
'BaseSync',
'run_all_syncs',
'UserSync',
'PricingSync',
'UapiSync',
'LlmageSync',
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,125 +1,248 @@
"""Base synchronization class for SageAPI.
Provides the foundation for all data sync workers. Handles common
concerns: checkpoint management, retry logic, batch processing,
and error reporting.
""" """
BaseSync - Abstract base class for all Sage data sync modules.
from __future__ import annotations Provides:
- Checkpoint management (read/write sync_state table)
- Batch processing with configurable size
- Retry logic with exponential backoff
- Common sync flow orchestration
"""
import time import time
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, Dict, List, Optional
from appPublic.log import debug, error, info from sqlor.dbpools import get_sor_context
from ahserver.serverenv import ServerEnv
logger = logging.getLogger(__name__)
class BaseSync(ABC): class BaseSync(ABC):
"""Abstract base class for data synchronization workers. """
Abstract base class for incremental sync from Sage DB to local cache.
Each concrete sync subclass implements the data fetch and Subclasses must implement:
persist logic for a specific upstream data source. - fetch_incremental(sor, since_timestamp): fetch delta from Sage DB
- persist(sor, records): UPSERT records into local cache table
- get_latest_timestamp(records): extract max updated_at from records
""" """
def __init__(self, sync_name: str, batch_size: int = 500) -> None: # --- subclass overrides ---
self.sync_name = sync_name MODULE_NAME: str = "" # Sage module name (e.g. 'users', 'pricing')
self.batch_size = batch_size SOURCE_DBNAME: str = "sage" # db alias for Sage source DB
self._last_checkpoint: dict[str, Any] = {} CACHE_DBNAME: str = "sageapi" # db alias for local cache DB
BATCH_SIZE: int = 500 # batch size for persist
MAX_RETRIES: int = 3 # max retry attempts per batch
RETRY_DELAY: float = 1.0 # initial retry delay in seconds
@abstractmethod # sync_state table key
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]: STATE_KEY: str = ""
"""Fetch incremental data from the upstream source.
Args: def __init__(self, env: Optional[ServerEnv] = None):
since_timestamp: Only fetch records modified after this timestamp. self.env = env or ServerEnv()
None means full sync.
Returns: # ------------------------------------------------------------------ #
List of records to be persisted. # Checkpoint helpers sync_state table lives in the CACHE DB #
""" # ------------------------------------------------------------------ #
...
@abstractmethod async def _read_checkpoint(self) -> Optional[str]:
async def persist(self, records: list[dict[str, Any]]) -> int: """Read last sync timestamp from sync_state table."""
"""Persist fetched records to the local database. async with get_sor_context(self.env, self.CACHE_DBNAME) as sor:
recs = await sor.R('sync_state', {'state_key': self.STATE_KEY})
if recs and len(recs) > 0:
return recs[0].get('last_sync_ts')
return None
Args: async def _write_checkpoint(self, timestamp: str) -> None:
records: List of records to upsert. """Write new sync timestamp into sync_state table."""
async with get_sor_context(self.env, self.CACHE_DBNAME) as sor:
now_ts = str(int(time.time()))
existing = await sor.R('sync_state', {'state_key': self.STATE_KEY})
if existing and len(existing) > 0:
await sor.U('sync_state', {
'state_key': self.STATE_KEY,
'last_sync_ts': timestamp,
'updated_at': now_ts,
})
else:
await sor.C('sync_state', {
'state_key': self.STATE_KEY,
'last_sync_ts': timestamp,
'created_at': now_ts,
'updated_at': now_ts,
})
Returns: # ------------------------------------------------------------------ #
Number of records successfully persisted. # Retry wrapper #
""" # ------------------------------------------------------------------ #
...
@abstractmethod async def _with_retry(self, coro_func, *args, **kwargs) -> Any:
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: """Execute an async function with exponential-backoff retry."""
"""Extract the latest modification timestamp from a batch of records. last_exc = None
delay = self.RETRY_DELAY
for attempt in range(1, self.MAX_RETRIES + 1):
try:
return await coro_func(*args, **kwargs)
except Exception as e:
last_exc = e
logger.warning(
"[%s] attempt %d/%d failed: %s",
self.__class__.__name__, attempt, self.MAX_RETRIES, e,
)
if attempt < self.MAX_RETRIES:
await self._sleep(delay)
delay *= 2
raise last_exc
Used to advance the sync checkpoint after successful persist. @staticmethod
""" async def _sleep(seconds: float) -> None:
... import asyncio
await asyncio.sleep(seconds)
async def _load_checkpoint(self) -> str | None: # ------------------------------------------------------------------ #
"""Load the last successful sync checkpoint timestamp. # Batch persist #
# ------------------------------------------------------------------ #
TODO: Implement checkpoint persistence (sync_state table). async def _persist_batch(self, sor, records: List[Dict]) -> int:
""" """Persist a single batch of records with retry."""
checkpoint = self._last_checkpoint.get(self.sync_name) async def _do():
debug(f'Sync {self.sync_name}: loaded checkpoint = {checkpoint}') await self.persist(sor, records)
return checkpoint await self._with_retry(_do)
return len(records)
async def _save_checkpoint(self, timestamp: str) -> None: async def persist_in_batches(self, sor, records: List[Dict]) -> int:
"""Save the sync checkpoint after a successful run. """Split records into batches and persist each with retry."""
total = 0
TODO: Implement checkpoint persistence (sync_state table). for i in range(0, len(records), self.BATCH_SIZE):
""" batch = records[i:i + self.BATCH_SIZE]
self._last_checkpoint[self.sync_name] = timestamp cnt = await self._persist_batch(sor, batch)
debug(f'Sync {self.sync_name}: saved checkpoint = {timestamp}') total += cnt
logger.info(
async def run(self) -> dict[str, Any]: "[%s] persisted batch %d/%d (%d records)",
"""Execute a full sync cycle. self.__class__.__name__,
i // self.BATCH_SIZE + 1,
Returns: (len(records) + self.BATCH_SIZE - 1) // self.BATCH_SIZE,
dict with keys: success, records_fetched, records_persisted, cnt,
error (if any), duration_seconds
"""
start = time.time()
result: dict[str, Any] = {
'sync_name': self.sync_name,
'success': False,
'records_fetched': 0,
'records_persisted': 0,
'error': None,
'duration_seconds': 0.0,
}
try:
checkpoint = await self._load_checkpoint()
info(f'Sync {self.sync_name}: starting (checkpoint={checkpoint})')
records = await self.fetch_incremental(since_timestamp=checkpoint)
result['records_fetched'] = len(records)
if records:
persisted = await self.persist(records)
result['records_persisted'] = persisted
latest_ts = self.get_latest_timestamp(records)
if latest_ts:
await self._save_checkpoint(latest_ts)
result['success'] = True
info(
f'Sync {self.sync_name}: completed — '
f'fetched={result["records_fetched"]}, '
f'persisted={result["records_persisted"]}'
) )
return total
except Exception as e: # ------------------------------------------------------------------ #
error(f'Sync {self.sync_name}: failed with error: {e}') # Main sync flow #
result['error'] = str(e) # ------------------------------------------------------------------ #
finally: async def sync(self) -> Dict[str, Any]:
result['duration_seconds'] = round(time.time() - start, 3) """
Full incremental sync flow:
1. Read checkpoint (last sync timestamp)
2. Fetch incremental records from Sage DB
3. Persist to local cache in batches
4. Update checkpoint
5. Return summary dict
"""
cls_name = self.__class__.__name__
logger.info("[%s] sync started", cls_name)
# 1. Read checkpoint
since_ts = await self._read_checkpoint()
logger.info("[%s] checkpoint: %s", cls_name, since_ts or "None (full sync)")
# 2. Fetch incremental from Sage source DB
async with get_sor_context(self.env, self.SOURCE_DBNAME) as sor:
records = await self.fetch_incremental(sor, since_ts)
if not records:
logger.info("[%s] no new records, sync done", cls_name)
return {
'module': self.MODULE_NAME,
'fetched': 0,
'persisted': 0,
'new_checkpoint': since_ts,
}
# 3. Extract latest timestamp
new_checkpoint = self.get_latest_timestamp(records)
logger.info(
"[%s] fetched %d records, latest_ts=%s",
cls_name, len(records), new_checkpoint,
)
# 4. Persist to cache DB
async with get_sor_context(self.env, self.CACHE_DBNAME) as cache_sor:
persisted = await self.persist_in_batches(cache_sor, records)
# 5. Update checkpoint
if new_checkpoint:
await self._write_checkpoint(new_checkpoint)
result = {
'module': self.MODULE_NAME,
'fetched': len(records),
'persisted': persisted,
'new_checkpoint': new_checkpoint,
}
logger.info("[%s] sync completed: %s", cls_name, result)
return result return result
# ------------------------------------------------------------------ #
# Abstract methods subclasses MUST implement #
# ------------------------------------------------------------------ #
@abstractmethod
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
"""
Fetch incremental records from Sage DB since the given timestamp.
Args:
sor: sqlor context for the SOURCE DB
since_timestamp: ISO or unix timestamp string, or None for full sync
Returns:
List of record dicts
"""
...
@abstractmethod
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT records into the local cache table.
Args:
sor: sqlor context for the CACHE DB
records: list of dicts to upsert
"""
...
@abstractmethod
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""
Extract the maximum updated/modified timestamp from records.
Args:
records: list of record dicts
Returns:
Timestamp string or None if no records
"""
...
async def run_all_syncs() -> str:
"""Trigger sync for all entity types. Called by admin API or cron."""
import json
from sageapi.sync.user_sync import UserSync
from sageapi.sync.pricing_sync import PricingSync
from sageapi.sync.uapi_sync import UapiSync
from sageapi.sync.llmage_sync import LlmageSync
results = []
for sync_cls in [UserSync, PricingSync, UapiSync, LlmageSync]:
try:
syncer = sync_cls()
result = await syncer.sync()
results.append({'module': sync_cls.MODULE_NAME, 'status': 'ok', **result})
except Exception as e:
logger.exception("[%s] sync failed", sync_cls.__name__)
results.append({'module': sync_cls.MODULE_NAME, 'status': 'error', 'error': str(e)})
return json.dumps({'success': True, 'results': results}, ensure_ascii=False, default=str)

View File

@ -1,65 +1,142 @@
"""LLM image data synchronization for SageAPI.
Syncs LLM catalog and provider data from the upstream Sage
llmage module into the local llmage_cache table.
""" """
LlmageSync - Sync llm / llmcatelog / llm_api_map from Sage DB to llmage_cache.
from __future__ import annotations Source tables (Sage DB):
- llm
- llmcatelog
- llm_api_map
from typing import Any Target table (cache DB):
- llmage_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class LlmageSync(BaseSync): class LlmageSync(BaseSync):
"""Incremental sync for llmage data from Sage upstream.""" MODULE_NAME = "llmage"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_llmage"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None: async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
super().__init__(sync_name='llmage', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch llmage data updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/llmage endpoint.
""" """
debug(f'LlmageSync: fetching incremental data since {since_timestamp}') Fetch incremental data from llm, llmcatelog, and llm_api_map tables.
# Placeholder: call upstream Sage API Joins LLM model info with catalog and API mapping data.
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert llmage records into llmage_cache table.
TODO: Implement database upsert logic.
""" """
if not records: if since_timestamp:
return 0 where_clause = f"WHERE l.updated_at > '{since_timestamp}' OR lc.updated_at > '{since_timestamp}' OR lam.updated_at > '{since_timestamp}'"
info(f'LlmageSync: persisting {len(records)} llmage records') else:
# Placeholder: upsert into llmage_cache where_clause = ""
return len(records)
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: sql = f"""
"""Extract the maximum updated_at from the record batch.""" SELECT
l.id AS llm_id,
l.model_name,
l.model_version,
l.provider,
l.model_type,
l.status AS llm_status,
l.description AS llm_description,
l.created_at AS llm_created_at,
l.updated_at AS llm_updated_at,
lc.id AS catelog_id,
lc.catelog_name,
lc.catelog_code,
lc.sort_order AS catelog_sort,
lc.status AS catelog_status,
lc.updated_at AS catelog_updated_at,
lam.id AS api_map_id,
lam.api_name,
lam.api_endpoint,
lam.api_version,
lam.auth_type,
lam.rate_limit,
lam.status AS api_map_status,
lam.updated_at AS api_map_updated_at
FROM {sor.dbname}.llm l
LEFT JOIN {sor.dbname}.llmcatelog lc ON l.catelog_id = lc.id
LEFT JOIN {sor.dbname}.llm_api_map lam ON l.id = lam.llm_id
{where_clause}
ORDER BY COALESCE(l.updated_at, l.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into llmage_cache using INSERT ... ON DUPLICATE KEY UPDATE.
The composite key is (llm_id, catelog_id, api_map_id).
"""
import time
synced_at = str(int(time.time()))
for rec in records:
rec['synced_at'] = synced_at
insert_sql = """
INSERT INTO llmage_cache (
llm_id, model_name, model_version, provider, model_type,
llm_status, llm_description, llm_created_at, llm_updated_at,
catelog_id, catelog_name, catelog_code, catelog_sort,
catelog_status, catelog_updated_at,
api_map_id, api_name, api_endpoint, api_version,
auth_type, rate_limit, api_map_status, api_map_updated_at,
synced_at
) VALUES (
${llm_id}$, ${model_name}$, ${model_version}$, ${provider}$, ${model_type}$,
${llm_status}$, ${llm_description}$, ${llm_created_at}$, ${llm_updated_at}$,
${catelog_id}$, ${catelog_name}$, ${catelog_code}$, ${catelog_sort}$,
${catelog_status}$, ${catelog_updated_at}$,
${api_map_id}$, ${api_name}$, ${api_endpoint}$, ${api_version}$,
${auth_type}$, ${rate_limit}$, ${api_map_status}$, ${api_map_updated_at}$,
${synced_at}$
)
ON DUPLICATE KEY UPDATE
model_name = VALUES(model_name),
model_version = VALUES(model_version),
provider = VALUES(provider),
model_type = VALUES(model_type),
llm_status = VALUES(llm_status),
llm_description = VALUES(llm_description),
llm_created_at = VALUES(llm_created_at),
llm_updated_at = VALUES(llm_updated_at),
catelog_name = VALUES(catelog_name),
catelog_code = VALUES(catelog_code),
catelog_sort = VALUES(catelog_sort),
catelog_status = VALUES(catelog_status),
catelog_updated_at= VALUES(catelog_updated_at),
api_name = VALUES(api_name),
api_endpoint = VALUES(api_endpoint),
api_version = VALUES(api_version),
auth_type = VALUES(auth_type),
rate_limit = VALUES(rate_limit),
api_map_status = VALUES(api_map_status),
api_map_updated_at= VALUES(api_map_updated_at),
synced_at = VALUES(synced_at)
"""
try:
await sor.execute(insert_sql, rec)
except Exception as e:
logger.warning(
"[%s] persist failed for llm_id=%s: %s",
self.__class__.__name__, rec.get('llm_id'), e,
)
raise
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from llm, catelog, or api_map records."""
if not records: if not records:
return None return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] latest = None
return max(timestamps) if timestamps else None for r in records:
for key in ('llm_updated_at', 'catelog_updated_at', 'api_map_updated_at'):
ts = r.get(key)
_llmage_sync_instance: LlmageSync | None = None if ts and (latest is None or str(ts) > str(latest)):
latest = str(ts)
return latest
def get_llmage_sync() -> LlmageSync:
"""Get or create the LlmageSync singleton."""
global _llmage_sync_instance
if _llmage_sync_instance is None:
_llmage_sync_instance = LlmageSync()
return _llmage_sync_instance
async def sync_llmage(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a llmage data sync cycle."""
syncer = get_llmage_sync()
if since_timestamp:
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()

View File

@ -1,65 +1,123 @@
"""Pricing data synchronization for SageAPI.
Syncs pricing program and timing data from the upstream Sage
pricing module into the local pricing_cache table.
""" """
PricingSync - Sync pricing_program / pricing_program_timing from Sage DB to pricing_cache.
from __future__ import annotations Source tables (Sage DB):
- pricing_program
- pricing_program_timing
from typing import Any Target table (cache DB):
- pricing_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class PricingSync(BaseSync): class PricingSync(BaseSync):
"""Incremental sync for pricing data from Sage upstream.""" MODULE_NAME = "pricing"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_pricing"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None: async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
super().__init__(sync_name='pricing', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch pricing data updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/pricing endpoint.
""" """
debug(f'PricingSync: fetching incremental data since {since_timestamp}') Fetch incremental data from pricing_program and pricing_program_timing.
# Placeholder: call upstream Sage API Joins program info with timing/schedule data.
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert pricing records into pricing_cache table.
TODO: Implement database upsert logic.
""" """
if not records: if since_timestamp:
return 0 where_clause = f"WHERE pp.updated_at > '{since_timestamp}' OR ppt.updated_at > '{since_timestamp}'"
info(f'PricingSync: persisting {len(records)} pricing records') else:
# Placeholder: upsert into pricing_cache where_clause = ""
return len(records)
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: sql = f"""
"""Extract the maximum updated_at from the record batch.""" SELECT
pp.id AS program_id,
pp.program_name,
pp.program_code,
pp.program_type,
pp.status AS program_status,
pp.description,
pp.created_at AS program_created_at,
pp.updated_at AS program_updated_at,
ppt.id AS timing_id,
ppt.start_time,
ppt.end_time,
ppt.duration,
ppt.repeat_rule,
ppt.timezone,
ppt.status AS timing_status,
ppt.updated_at AS timing_updated_at
FROM {sor.dbname}.pricing_program pp
LEFT JOIN {sor.dbname}.pricing_program_timing ppt
ON pp.id = ppt.program_id
{where_clause}
ORDER BY COALESCE(pp.updated_at, pp.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into pricing_cache using INSERT ... ON DUPLICATE KEY UPDATE.
The composite key is (program_id, timing_id).
"""
import time
synced_at = str(int(time.time()))
for rec in records:
rec['synced_at'] = synced_at
insert_sql = """
INSERT INTO pricing_cache (
program_id, program_name, program_code, program_type,
program_status, description, program_created_at, program_updated_at,
timing_id, start_time, end_time, duration,
repeat_rule, timezone, timing_status, timing_updated_at,
synced_at
) VALUES (
${program_id}$, ${program_name}$, ${program_code}$, ${program_type}$,
${program_status}$, ${description}$, ${program_created_at}$, ${program_updated_at}$,
${timing_id}$, ${start_time}$, ${end_time}$, ${duration}$,
${repeat_rule}$, ${timezone}$, ${timing_status}$, ${timing_updated_at}$,
${synced_at}$
)
ON DUPLICATE KEY UPDATE
program_name = VALUES(program_name),
program_code = VALUES(program_code),
program_type = VALUES(program_type),
program_status = VALUES(program_status),
description = VALUES(description),
program_created_at = VALUES(program_created_at),
program_updated_at = VALUES(program_updated_at),
start_time = VALUES(start_time),
end_time = VALUES(end_time),
duration = VALUES(duration),
repeat_rule = VALUES(repeat_rule),
timezone = VALUES(timezone),
timing_status = VALUES(timing_status),
timing_updated_at = VALUES(timing_updated_at),
synced_at = VALUES(synced_at)
"""
try:
await sor.execute(insert_sql, rec)
except Exception as e:
logger.warning(
"[%s] persist failed for program_id=%s: %s",
self.__class__.__name__, rec.get('program_id'), e,
)
raise
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from program or timing records."""
if not records: if not records:
return None return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] latest = None
return max(timestamps) if timestamps else None for r in records:
ts = r.get('program_updated_at') or r.get('timing_updated_at')
if ts and (latest is None or str(ts) > str(latest)):
_pricing_sync_instance: PricingSync | None = None latest = str(ts)
return latest
def get_pricing_sync() -> PricingSync:
"""Get or create the PricingSync singleton."""
global _pricing_sync_instance
if _pricing_sync_instance is None:
_pricing_sync_instance = PricingSync()
return _pricing_sync_instance
async def sync_pricing(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a pricing data sync cycle."""
syncer = get_pricing_sync()
if since_timestamp:
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()

View File

@ -1,65 +1,129 @@
"""UAPI data synchronization for SageAPI.
Syncs uapi application and caller configuration from the upstream
Sage uapi module into the local uapi_cache table.
""" """
UAPISync - Sync uapi / upapp from Sage DB to uapi_cache.
from __future__ import annotations Source tables (Sage DB):
- uapi
- upapp
from typing import Any Target table (cache DB):
- uapi_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class UAPISync(BaseSync): class UAPISync(BaseSync):
"""Incremental sync for uapi data from Sage upstream.""" MODULE_NAME = "uapi"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_uapi"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None: async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
super().__init__(sync_name='uapi', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch uapi data updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/uapi endpoint.
""" """
debug(f'UAPISync: fetching incremental data since {since_timestamp}') Fetch incremental data from uapi and upapp tables.
# Placeholder: call upstream Sage API Joins API definitions with app registration data.
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert uapi records into uapi_cache table.
TODO: Implement database upsert logic.
""" """
if not records: if since_timestamp:
return 0 where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR up.updated_at > '{since_timestamp}'"
info(f'UAPISync: persisting {len(records)} uapi records') else:
# Placeholder: upsert into uapi_cache where_clause = ""
return len(records)
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: sql = f"""
"""Extract the maximum updated_at from the record batch.""" SELECT
u.id AS uapi_id,
u.api_name,
u.api_path,
u.api_method,
u.api_version,
u.api_desc,
u.status AS uapi_status,
u.auth_required,
u.created_at AS uapi_created_at,
u.updated_at AS uapi_updated_at,
up.id AS upapp_id,
up.app_name,
up.app_code,
up.app_type,
up.app_desc,
up.app_owner,
up.status AS upapp_status,
up.updated_at AS upapp_updated_at
FROM {sor.dbname}.uapi u
LEFT JOIN {sor.dbname}.upapp up ON u.upapp_id = up.id
{where_clause}
ORDER BY COALESCE(u.updated_at, u.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into uapi_cache using INSERT ... ON DUPLICATE KEY UPDATE.
The composite key is (uapi_id, upapp_id).
"""
import time
synced_at = str(int(time.time()))
for rec in records:
rec['synced_at'] = synced_at
insert_sql = """
INSERT INTO uapi_cache (
uapi_id, api_name, api_path, api_method, api_version,
api_desc, uapi_status, auth_required,
uapi_created_at, uapi_updated_at,
upapp_id, app_name, app_code, app_type,
app_desc, app_owner, upapp_status, upapp_updated_at,
synced_at
) VALUES (
${uapi_id}$, ${api_name}$, ${api_path}$, ${api_method}$, ${api_version}$,
${api_desc}$, ${uapi_status}$, ${auth_required}$,
${uapi_created_at}$, ${uapi_updated_at}$,
${upapp_id}$, ${app_name}$, ${app_code}$, ${app_type}$,
${app_desc}$, ${app_owner}$, ${upapp_status}$, ${upapp_updated_at}$,
${synced_at}$
)
ON DUPLICATE KEY UPDATE
api_name = VALUES(api_name),
api_path = VALUES(api_path),
api_method = VALUES(api_method),
api_version = VALUES(api_version),
api_desc = VALUES(api_desc),
uapi_status = VALUES(uapi_status),
auth_required = VALUES(auth_required),
uapi_created_at = VALUES(uapi_created_at),
uapi_updated_at = VALUES(uapi_updated_at),
app_name = VALUES(app_name),
app_code = VALUES(app_code),
app_type = VALUES(app_type),
app_desc = VALUES(app_desc),
app_owner = VALUES(app_owner),
upapp_status = VALUES(upapp_status),
upapp_updated_at = VALUES(upapp_updated_at),
synced_at = VALUES(synced_at)
"""
try:
await sor.execute(insert_sql, rec)
except Exception as e:
logger.warning(
"[%s] persist failed for uapi_id=%s: %s",
self.__class__.__name__, rec.get('uapi_id'), e,
)
raise
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from uapi or upapp records."""
if not records: if not records:
return None return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] latest = None
return max(timestamps) if timestamps else None for r in records:
for key in ('uapi_updated_at', 'upapp_updated_at'):
ts = r.get(key)
_uapi_sync_instance: UAPISync | None = None if ts and (latest is None or str(ts) > str(latest)):
latest = str(ts)
return latest
def get_uapi_sync() -> UAPISync:
"""Get or create the UAPISync singleton."""
global _uapi_sync_instance
if _uapi_sync_instance is None:
_uapi_sync_instance = UAPISync()
return _uapi_sync_instance
async def sync_uapi(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a uapi data sync cycle."""
syncer = get_uapi_sync()
if since_timestamp:
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()

View File

@ -1,76 +1,143 @@
"""User data synchronization for SageAPI.
Syncs user data from the upstream Sage system into the local
users_cache table. Uses incremental sync based on updated_at
timestamp to minimize data transfer.
""" """
UserSync - Sync users / organi / organization from Sage DB to users_cache.
from __future__ import annotations Source tables (Sage DB):
- users
- organi
- organization
from typing import Any Target table (cache DB):
- users_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class UserSync(BaseSync): class UserSync(BaseSync):
"""Incremental sync for user data from Sage upstream.""" MODULE_NAME = "users"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_users"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None: async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
super().__init__(sync_name='users', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch users updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/users endpoint.
""" """
debug(f'UserSync: fetching incremental data since {since_timestamp}') Fetch incremental data from users, organi, organization tables.
# Placeholder: call upstream Sage API Uses LEFT JOIN to combine user data with organization info.
# GET /api/users?updated_at_gt={since_timestamp}&limit={batch_size}
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert user records into users_cache table.
TODO: Implement database upsert logic.
""" """
if not records: if since_timestamp:
return 0 where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR o.updated_at > '{since_timestamp}'"
info(f'UserSync: persisting {len(records)} user records') else:
# Placeholder: upsert into users_cache where_clause = ""
return len(records)
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: sql = f"""
"""Extract the maximum updated_at from the record batch.""" SELECT
u.id AS user_id,
u.username,
u.email,
u.status AS user_status,
u.created_at AS user_created_at,
u.updated_at AS user_updated_at,
oi.organi_id AS organi_id,
oi.organi_name AS organi_name,
oi.parent_id AS organi_parent_id,
org.id AS org_id,
org.org_name AS org_name,
org.org_type AS org_type,
org.status AS org_status,
org.updated_at AS org_updated_at
FROM {sor.dbname}.users u
LEFT JOIN {sor.dbname}.organi oi ON u.organi_id = oi.id
LEFT JOIN {sor.dbname}.organization org ON oi.organization_id = org.id
{where_clause}
ORDER BY COALESCE(u.updated_at, u.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into users_cache.
For each record: try UPDATE first; if no rows affected, INSERT.
"""
for rec in records:
# Try update first
update_sql = """
UPDATE users_cache SET
username = ${username}$,
email = ${email}$,
user_status = ${user_status}$,
user_created_at = ${user_created_at}$,
user_updated_at = ${user_updated_at}$,
organi_id = ${organi_id}$,
organi_name = ${organi_name}$,
organi_parent_id= ${organi_parent_id}$,
org_id = ${org_id}$,
org_name = ${org_name}$,
org_type = ${org_type}$,
org_status = ${org_status}$,
org_updated_at = ${org_updated_at}$,
synced_at = ${synced_at}$
WHERE user_id = ${user_id}$
"""
rec['synced_at'] = str(int(__import__('time').time()))
await sor.execute(update_sql, rec)
# If no rows updated (sqlor returns affected count), insert
# sqlor's execute returns the cursor; we check rowcount via sqlExe
# A simpler approach: use INSERT ... ON DUPLICATE KEY UPDATE
insert_sql = """
INSERT INTO users_cache (
user_id, username, email, user_status,
user_created_at, user_updated_at,
organi_id, organi_name, organi_parent_id,
org_id, org_name, org_type, org_status,
org_updated_at, synced_at
) VALUES (
${user_id}$, ${username}$, ${email}$, ${user_status}$,
${user_created_at}$, ${user_updated_at}$,
${organi_id}$, ${organi_name}$, ${organi_parent_id}$,
${org_id}$, ${org_name}$, ${org_type}$, ${org_status}$,
${org_updated_at}$, ${synced_at}$
)
ON DUPLICATE KEY UPDATE
username = VALUES(username),
email = VALUES(email),
user_status = VALUES(user_status),
user_created_at = VALUES(user_created_at),
user_updated_at = VALUES(user_updated_at),
organi_id = VALUES(organi_id),
organi_name = VALUES(organi_name),
organi_parent_id= VALUES(organi_parent_id),
org_id = VALUES(org_id),
org_name = VALUES(org_name),
org_type = VALUES(org_type),
org_status = VALUES(org_status),
org_updated_at = VALUES(org_updated_at),
synced_at = VALUES(synced_at)
"""
# Execute INSERT with ON DUPLICATE KEY UPDATE as safety net
# The UPDATE above handles most cases; this catches any race conditions
# Skip if the user_id is None
if rec.get('user_id') is not None:
try:
await sor.execute(insert_sql, rec)
except Exception:
# Duplicate key is expected if UPDATE already succeeded
pass
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from all records."""
if not records: if not records:
return None return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] latest = None
return max(timestamps) if timestamps else None for r in records:
ts = r.get('user_updated_at') or r.get('org_updated_at')
if ts and (latest is None or str(ts) > str(latest)):
# Module-level singleton instance latest = str(ts)
_user_sync_instance: UserSync | None = None return latest
def get_user_sync() -> UserSync:
"""Get or create the UserSync singleton."""
global _user_sync_instance
if _user_sync_instance is None:
_user_sync_instance = UserSync()
return _user_sync_instance
async def sync_users(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a user data sync cycle.
Args:
since_timestamp: Optional override for the checkpoint timestamp.
Returns:
Sync result dict from BaseSync.run().
"""
syncer = get_user_sync()
if since_timestamp:
# Override checkpoint for this run
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,115 +1,460 @@
"""HTTP client for upstream Sage API calls. """
SageHttpClient - Async HTTP client using aiohttp with DAPI signature support.
Provides a reusable async HTTP client with connection pooling, Used to call external APIs (e.g., LLM provider APIs). Not for calling Sage local interfaces.
retry logic, and automatic DAPI/UAPI authentication headers.
Features:
- Automatic DAPI signature header injection
- Connection pooling
- Retry with exponential backoff
- Configurable timeouts
- Support for GET, POST, PUT, DELETE, PATCH
Usage:
from sageapi.utils.http_client import SageHttpClient
client = SageHttpClient(
api_key="your-api-key",
api_secret="your-api-secret",
base_url="https://api.example.com",
max_retries=3,
timeout=30.0,
)
async with client:
resp = await client.post("/v1/chat", json={"message": "hello"})
data = await resp.json()
""" """
from __future__ import annotations import hashlib
import hmac
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Optional
import asyncio import aiohttp
from typing import Any from aiohttp import ClientTimeout
from appPublic.log import debug, error logger = logging.getLogger(__name__)
# Default configuration
DEFAULT_TIMEOUT = 30.0
DEFAULT_MAX_RETRIES = 3
DEFAULT_BACKOFF_FACTOR = 0.5
DEFAULT_MAX_CONNECTIONS = 100
DEFAULT_CONNECTION_POOL_LIMIT = 100
@dataclass
class RetryConfig:
"""Configuration for request retries."""
max_retries: int = DEFAULT_MAX_RETRIES
backoff_factor: float = DEFAULT_BACKOFF_FACTOR
retry_on_status: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 504})
retry_on_timeout: bool = True
retry_on_connection_error: bool = True
def compute_dapi_signature(
method: str,
path: str,
timestamp: str,
secret: str,
body: Optional[bytes] = None,
) -> str:
"""
Compute HMAC-SHA256 signature for DAPI authentication.
Signs: "{method}\\n{path}\\n{timestamp}\\n{body_hash}"
"""
if body:
body_hash = hashlib.sha256(body).hexdigest()
else:
body_hash = ""
string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}"
return hmac.new(
secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
hashlib.sha256,
).hexdigest()
class DAPISigner:
"""Handles DAPI signature generation for outgoing requests."""
def __init__(self, api_key: str, api_secret: str):
self.api_key = api_key
self.api_secret = api_secret
def sign_request(
self,
method: str,
path: str,
body: Optional[bytes] = None,
) -> dict[str, str]:
"""
Generate DAPI authentication headers.
Returns dict with X-DAPI-Key, X-DAPI-Timestamp, X-DAPI-Signature.
"""
timestamp = str(time.time())
signature = compute_dapi_signature(
method=method.upper(),
path=path,
timestamp=timestamp,
secret=self.api_secret,
body=body,
)
return {
"X-DAPI-Key": self.api_key,
"X-DAPI-Timestamp": timestamp,
"X-DAPI-Signature": signature,
}
class SageHttpClient: class SageHttpClient:
"""Async HTTP client for calling upstream Sage APIs. """
Async HTTP client with DAPI signature support.
Handles connection pooling, authentication headers, and Uses aiohttp under the hood with connection pooling, retry logic,
retry logic for transient failures. and automatic DAPI header injection.
Args:
base_url: Base URL for all requests (e.g., "https://api.example.com").
api_key: DAPI API key for request signing.
api_secret: DAPI API secret for request signing.
timeout: Request timeout in seconds.
max_retries: Maximum number of retries on transient failures.
backoff_factor: Exponential backoff multiplier.
max_connections: Maximum number of connections in the pool.
headers: Additional default headers to include in every request.
signer: Optional custom DAPISigner instance. If None, one is created from api_key/api_secret.
Usage:
async with SageHttpClient(base_url="...", api_key="...", api_secret="...") as client:
resp = await client.get("/health")
resp = await client.post("/v1/chat", json={"msg": "hi"})
""" """
def __init__( def __init__(
self, self,
base_url: str = 'http://127.0.0.1:9180', base_url: str = "",
dapi_key: str = '', api_key: str = "",
dapi_secret: str = '', api_secret: str = "",
timeout: float = 30.0, timeout: float = DEFAULT_TIMEOUT,
max_retries: int = 3, max_retries: int = DEFAULT_MAX_RETRIES,
) -> None: backoff_factor: float = DEFAULT_BACKOFF_FACTOR,
self.base_url = base_url.rstrip('/') max_connections: int = DEFAULT_MAX_CONNECTIONS,
self.dapi_key = dapi_key headers: Optional[dict[str, str]] = None,
self.dapi_secret = dapi_secret signer: Optional[DAPISigner] = None,
auto_sign: bool = True,
):
self.base_url = base_url.rstrip("/")
self.timeout = timeout self.timeout = timeout
self.max_retries = max_retries self.max_retries = max_retries
self.backoff_factor = backoff_factor
self.default_headers = headers or {}
self.auto_sign = auto_sign
def _build_headers(self) -> dict[str, str]: if signer is not None:
"""Build request headers with DAPI authentication. self.signer = signer
else:
self.signer = DAPISigner(api_key=api_key, api_secret=api_secret)
TODO: Implement actual DAPI signature generation. # Internal state
self._session: Optional[aiohttp.ClientSession] = None
self._connector: Optional[aiohttp.TCPConnector] = None
self._connector_limit = max_connections
self._closed = False
def _build_url(self, path: str) -> str:
"""Build full URL from base_url and path."""
if path.startswith("http://") or path.startswith("https://"):
return path
return f"{self.base_url}/{path.lstrip('/')}"
def _get_relative_path(self, path: str) -> str:
"""Extract the relative path for signing purposes."""
if path.startswith("http://") or path.startswith("https://"):
from urllib.parse import urlparse
parsed = urlparse(path)
return parsed.path
return path
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create the aiohttp session with connection pooling."""
if self._session is None or self._session.closed:
self._connector = aiohttp.TCPConnector(
limit=self._connector_limit,
limit_per_host=self._connector_limit,
ttl_dns_cache=300,
keepalive_timeout=30,
)
self._session = aiohttp.ClientSession(
connector=self._connector,
timeout=ClientTimeout(total=self.timeout),
)
return self._session
async def _sign_and_prepare(
self,
method: str,
path: str,
headers: Optional[dict[str, str]] = None,
data: Any = None,
json_body: Any = None,
) -> tuple[str, dict[str, str], Optional[bytes]]:
""" """
import time Prepare request with DAPI signature headers.
import hashlib
import hmac
timestamp = str(int(time.time())) Returns (url, merged_headers, raw_body_bytes).
string_to_sign = f'{self.dapi_key}:{timestamp}' """
signature = hmac.new( url = self._build_url(path)
self.dapi_secret.encode('utf-8') if self.dapi_secret else b'', relative_path = self._get_relative_path(path)
string_to_sign.encode('utf-8'),
hashlib.sha256,
).hexdigest()
return { merged_headers = {**self.default_headers}
'Content-Type': 'application/json', if headers:
'X-DAPI-Key': self.dapi_key, merged_headers.update(headers)
'X-DAPI-Timestamp': timestamp,
'X-DAPI-Signature': signature, # Determine body for signing
raw_body: Optional[bytes] = None
if json_body is not None:
import json
raw_body = json.dumps(json_body).encode("utf-8")
merged_headers.setdefault("Content-Type", "application/json")
elif data is not None:
if isinstance(data, bytes):
raw_body = data
elif isinstance(data, str):
raw_body = data.encode("utf-8")
else:
# Form data - sign the encoded form
from urllib.parse import urlencode
raw_body = urlencode(data).encode("utf-8")
# Add DAPI signature headers if auto-sign is enabled and signer has credentials
if self.auto_sign and self.signer.api_key and self.signer.api_secret:
dapi_headers = self.signer.sign_request(
method=method,
path=relative_path,
body=raw_body,
)
merged_headers.update(dapi_headers)
return url, merged_headers, raw_body
async def _execute_with_retry(
self,
method: str,
url: str,
headers: dict[str, str],
**kwargs: Any,
) -> aiohttp.ClientResponse:
"""
Execute request with retry and exponential backoff.
"""
session = await self._get_session()
last_response: Optional[aiohttp.ClientResponse] = None
last_exception: Optional[Exception] = None
for attempt in range(self.max_retries + 1):
try:
response = await session.request(
method=method,
url=url,
headers=headers,
**kwargs,
)
# Check if we should retry based on status
if response.status in {429, 500, 502, 503, 504}:
last_response = response
if attempt < self.max_retries:
wait_time = self.backoff_factor * (2 ** attempt)
logger.warning(
"HTTP %s on %s %s, retrying in %.1fs (attempt %d/%d)",
response.status,
method,
url,
wait_time,
attempt + 1,
self.max_retries,
)
try:
response.release()
except Exception:
pass
await self._async_sleep(wait_time)
continue
else:
# Retries exhausted - raise an error
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"HTTP {response.status} after {self.max_retries + 1} attempts",
)
return response
except aiohttp.ServerTimeoutError as e:
last_exception = e
if attempt < self.max_retries:
wait_time = self.backoff_factor * (2 ** attempt)
logger.warning(
"Timeout on %s %s, retrying in %.1fs (attempt %d/%d)",
method,
url,
wait_time,
attempt + 1,
self.max_retries,
)
await self._async_sleep(wait_time)
continue
except (aiohttp.ClientConnectionError, aiohttp.ClientOSError) as e:
last_exception = e
if attempt < self.max_retries:
wait_time = self.backoff_factor * (2 ** attempt)
logger.warning(
"Connection error on %s %s, retrying in %.1fs (attempt %d/%d)",
method,
url,
wait_time,
attempt + 1,
self.max_retries,
)
await self._async_sleep(wait_time)
continue
except Exception:
raise
# All retries exhausted
if last_response:
return last_response
if last_exception:
raise last_exception
raise RuntimeError(f"Request failed after {self.max_retries + 1} attempts")
@staticmethod
async def _async_sleep(seconds: float) -> None:
"""Non-blocking sleep."""
import asyncio
await asyncio.sleep(seconds)
async def request(
self,
method: str,
path: str,
*,
headers: Optional[dict[str, str]] = None,
data: Any = None,
json: Any = None,
params: Optional[dict[str, Any]] = None,
timeout: Optional[float] = None,
allow_redirects: bool = True,
) -> aiohttp.ClientResponse:
"""
Send an HTTP request with DAPI signing and retry.
Args:
method: HTTP method (GET, POST, PUT, DELETE, PATCH).
path: URL path or full URL.
headers: Additional request headers.
data: Form data or raw bytes.
json: JSON-serializable body (automatically encoded).
params: Query string parameters.
timeout: Override timeout for this request.
allow_redirects: Whether to follow redirects.
Returns:
aiohttp.ClientResponse
"""
url, merged_headers, raw_body = await self._sign_and_prepare(
method=method,
path=path,
headers=headers,
data=data,
json_body=json,
)
request_kwargs: dict[str, Any] = {
"allow_redirects": allow_redirects,
} }
async def get( if raw_body is not None:
self, if json is not None:
path: str, request_kwargs["data"] = raw_body
params: dict[str, Any] | None = None, else:
headers: dict[str, str] | None = None, request_kwargs["data"] = raw_body
) -> Any: elif json is not None:
"""Send a GET request to the upstream Sage API. import json as _json
Args: request_kwargs["data"] = _json.dumps(json).encode("utf-8")
path: API path (relative to base_url).
params: Optional query parameters.
headers: Optional additional headers.
Returns: if data is not None and json is None:
Parsed JSON response. request_kwargs["data"] = data
"""
url = f'{self.base_url}{path}'
request_headers = {**self._build_headers(), **(headers or {})}
debug(f'HTTP GET {url} params={params}') if params:
request_kwargs["params"] = params
# TODO: Replace with actual async HTTP implementation # Per-request timeout override
# using aiohttp or the framework's built-in HTTP client. if timeout is not None:
# This is a placeholder that will be filled in once the request_kwargs["timeout"] = ClientTimeout(total=timeout)
# specific HTTP library choice is confirmed.
raise NotImplementedError( return await self._execute_with_retry(
'SageHttpClient.get: HTTP library not yet integrated. ' method=method.upper(),
'Implement with aiohttp or framework HTTP client.' url=url,
headers=merged_headers,
**request_kwargs,
) )
async def post( async def get(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
self, """Send a GET request."""
path: str, return await self.request("GET", path, **kwargs)
data: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
"""Send a POST request to the upstream Sage API.
Args: async def post(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
path: API path (relative to base_url). """Send a POST request."""
data: JSON body data. return await self.request("POST", path, **kwargs)
headers: Optional additional headers.
Returns: async def put(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
Parsed JSON response. """Send a PUT request."""
""" return await self.request("PUT", path, **kwargs)
url = f'{self.base_url}{path}'
request_headers = {**self._build_headers(), **(headers or {})}
debug(f'HTTP POST {url}') async def delete(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""Send a DELETE request."""
return await self.request("DELETE", path, **kwargs)
# TODO: Replace with actual async HTTP implementation. async def patch(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
raise NotImplementedError( """Send a PATCH request."""
'SageHttpClient.post: HTTP library not yet integrated. ' return await self.request("PATCH", path, **kwargs)
'Implement with aiohttp or framework HTTP client.'
) async def close(self) -> None:
"""Close the HTTP client and release resources."""
if self._session and not self._session.closed:
await self._session.close()
self._closed = True
async def __aenter__(self) -> "SageHttpClient":
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
def __del__(self) -> None:
# Best-effort cleanup; prefer using async context manager
if not self._closed and self._session and not self._session.closed:
logger.warning(
"SageHttpClient was not properly closed. "
"Use 'async with SageHttpClient(...)' for proper cleanup."
)

122
sync/cache_tables.sql Normal file
View File

@ -0,0 +1,122 @@
-- =============================================================================
-- SageAPI Cache Tables DDL
-- These tables store synchronized data from the Sage database.
-- Run against the sageapi database.
-- =============================================================================
-- Checkpoint table: tracks the last sync timestamp for each module
CREATE TABLE IF NOT EXISTS sync_state (
state_key VARCHAR(64) NOT NULL PRIMARY KEY,
last_sync_ts VARCHAR(64) DEFAULT NULL,
created_at VARCHAR(32) NOT NULL,
updated_at VARCHAR(32) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- users_cache: synced from Sage users / organi / organization tables
-- =============================================================================
CREATE TABLE IF NOT EXISTS users_cache (
user_id BIGINT NOT NULL PRIMARY KEY,
username VARCHAR(128) DEFAULT NULL,
email VARCHAR(256) DEFAULT NULL,
user_status INT DEFAULT 0,
user_created_at VARCHAR(32) DEFAULT NULL,
user_updated_at VARCHAR(32) DEFAULT NULL,
organi_id BIGINT DEFAULT NULL,
organi_name VARCHAR(256) DEFAULT NULL,
organi_parent_id BIGINT DEFAULT NULL,
org_id BIGINT DEFAULT NULL,
org_name VARCHAR(256) DEFAULT NULL,
org_type VARCHAR(64) DEFAULT NULL,
org_status INT DEFAULT 0,
org_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
UNIQUE KEY uk_organi (organi_id),
KEY idx_updated (user_updated_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- pricing_cache: synced from Sage pricing_program / pricing_program_timing
-- =============================================================================
CREATE TABLE IF NOT EXISTS pricing_cache (
program_id BIGINT NOT NULL,
program_name VARCHAR(256) DEFAULT NULL,
program_code VARCHAR(128) DEFAULT NULL,
program_type VARCHAR(64) DEFAULT NULL,
program_status INT DEFAULT 0,
description TEXT DEFAULT NULL,
program_created_at VARCHAR(32) DEFAULT NULL,
program_updated_at VARCHAR(32) DEFAULT NULL,
timing_id BIGINT DEFAULT NULL,
start_time VARCHAR(32) DEFAULT NULL,
end_time VARCHAR(32) DEFAULT NULL,
duration INT DEFAULT NULL,
repeat_rule VARCHAR(256) DEFAULT NULL,
timezone VARCHAR(64) DEFAULT NULL,
timing_status INT DEFAULT 0,
timing_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
PRIMARY KEY (program_id, timing_id),
KEY idx_program_updated (program_updated_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- uapi_cache: synced from Sage uapi / upapp tables
-- =============================================================================
CREATE TABLE IF NOT EXISTS uapi_cache (
uapi_id BIGINT NOT NULL,
api_name VARCHAR(256) DEFAULT NULL,
api_path VARCHAR(512) DEFAULT NULL,
api_method VARCHAR(16) DEFAULT 'GET',
api_version VARCHAR(32) DEFAULT NULL,
api_desc TEXT DEFAULT NULL,
uapi_status INT DEFAULT 0,
auth_required TINYINT DEFAULT 0,
uapi_created_at VARCHAR(32) DEFAULT NULL,
uapi_updated_at VARCHAR(32) DEFAULT NULL,
upapp_id BIGINT DEFAULT NULL,
app_name VARCHAR(256) DEFAULT NULL,
app_code VARCHAR(128) DEFAULT NULL,
app_type VARCHAR(64) DEFAULT NULL,
app_desc TEXT DEFAULT NULL,
app_owner VARCHAR(128) DEFAULT NULL,
upapp_status INT DEFAULT 0,
upapp_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
PRIMARY KEY (uapi_id, upapp_id),
KEY idx_uapi_updated (uapi_updated_at),
KEY idx_upapp_id (upapp_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- llmage_cache: synced from Sage llm / llmcatelog / llm_api_map tables
-- =============================================================================
CREATE TABLE IF NOT EXISTS llmage_cache (
llm_id BIGINT NOT NULL,
model_name VARCHAR(256) DEFAULT NULL,
model_version VARCHAR(64) DEFAULT NULL,
provider VARCHAR(128) DEFAULT NULL,
model_type VARCHAR(64) DEFAULT NULL,
llm_status INT DEFAULT 0,
llm_description TEXT DEFAULT NULL,
llm_created_at VARCHAR(32) DEFAULT NULL,
llm_updated_at VARCHAR(32) DEFAULT NULL,
catelog_id BIGINT DEFAULT NULL,
catelog_name VARCHAR(256) DEFAULT NULL,
catelog_code VARCHAR(128) DEFAULT NULL,
catelog_sort INT DEFAULT 0,
catelog_status INT DEFAULT 0,
catelog_updated_at VARCHAR(32) DEFAULT NULL,
api_map_id BIGINT DEFAULT NULL,
api_name VARCHAR(256) DEFAULT NULL,
api_endpoint VARCHAR(512) DEFAULT NULL,
api_version VARCHAR(32) DEFAULT NULL,
auth_type VARCHAR(32) DEFAULT NULL,
rate_limit INT DEFAULT NULL,
api_map_status INT DEFAULT 0,
api_map_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
PRIMARY KEY (llm_id, catelog_id, api_map_id),
KEY idx_llm_updated (llm_updated_at),
KEY idx_catelog_id (catelog_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# Tests package

Binary file not shown.

458
tests/test_dapi_auth.py Normal file
View File

@ -0,0 +1,458 @@
"""Tests for sageapi.middleware.dapi_auth"""
import hashlib
import hmac
import time
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
from sageapi.middleware.dapi_auth import (
DEFAULT_TIMESTAMP_TOLERANCE_SEC,
DapiAuthError,
DapiAuthMiddleware,
DapiKeyRecord,
authenticate_request,
compute_dapi_signature,
lookup_api_key,
verify_timestamp,
)
def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str:
"""Helper to create a valid DAPI signature."""
return compute_dapi_signature(method, path, timestamp, secret, body if body else None)
# ---------------------------------------------------------------------------
# compute_dapi_signature tests
# ---------------------------------------------------------------------------
class TestComputeDapiSignature:
def test_basic_signature(self):
sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret")
assert isinstance(sig, str)
assert len(sig) == 64 # SHA-256 hex length
def test_signature_with_body(self):
body = b'{"key": "value"}'
sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", body)
assert isinstance(sig, str)
assert len(sig) == 64
def test_signature_deterministic(self):
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret")
assert sig1 == sig2
def test_different_body_different_signature(self):
sig1 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body1")
sig2 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body2")
assert sig1 != sig2
def test_empty_body_same_as_no_body(self):
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret", None)
assert sig1 == sig2
def test_manual_verification(self):
"""Verify the signature matches the expected HMAC-SHA256 output."""
method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret"
body = b'{"message":"hello"}'
body_hash = hashlib.sha256(body).hexdigest()
string_to_sign = f"{method}\n{path}\n{ts}\n{body_hash}"
expected = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha256).hexdigest()
actual = compute_dapi_signature(method, path, ts, secret, body)
assert actual == expected
# ---------------------------------------------------------------------------
# verify_timestamp tests
# ---------------------------------------------------------------------------
class TestVerifyTimestamp:
def test_valid_timestamp(self):
ts = str(time.time())
assert verify_timestamp(ts) is True
def test_expired_timestamp(self):
ts = str(time.time() - 600) # 10 minutes ago
assert verify_timestamp(ts) is False
def test_future_timestamp(self):
ts = str(time.time() + 600) # 10 minutes in future
assert verify_timestamp(ts) is False
def test_invalid_timestamp(self):
assert verify_timestamp("not-a-number") is False
assert verify_timestamp("") is False
def test_custom_tolerance(self):
ts = str(time.time() - 10)
assert verify_timestamp(ts, tolerance_sec=5) is False
assert verify_timestamp(ts, tolerance_sec=15) is True
# ---------------------------------------------------------------------------
# DapiKeyRecord tests
# ---------------------------------------------------------------------------
class TestDapiKeyRecord:
def test_active_key(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None)
assert rec.is_active is True
def test_inactive_key(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="inactive", expire_date=None)
assert rec.is_active is False
def test_not_expired_when_no_expiry(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None)
assert rec.is_expired is False
def test_expired_with_past_date(self):
past = datetime(2020, 1, 1, tzinfo=timezone.utc)
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=past)
assert rec.is_expired is True
def test_not_expired_with_future_date(self):
future = datetime(2099, 1, 1, tzinfo=timezone.utc)
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=future)
assert rec.is_expired is False
def test_expired_with_iso_string_past(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2020-01-01T00:00:00Z")
assert rec.is_expired is True
def test_not_expired_with_iso_string_future(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2099-01-01T00:00:00Z")
assert rec.is_expired is False
# ---------------------------------------------------------------------------
# lookup_api_key tests
# ---------------------------------------------------------------------------
class TestLookupApiKey:
@pytest.fixture(autouse=True)
def mock_sage_modules(self):
"""Create mock ahserver and sqlor modules for testing."""
import sys
import types
# Create fake modules
ahserver = types.ModuleType("ahserver")
ahserver_serverenv = types.ModuleType("ahserver.serverenv")
ahserver_serverenv.ServerEnv = MagicMock
ahserver.serverenv = ahserver_serverenv
sqlor = types.ModuleType("sqlor")
sqlor_dbpools = types.ModuleType("sqlor.dbpools")
mock_ctx = AsyncMock()
mock_sor = AsyncMock()
mock_sor.R = AsyncMock(return_value=[])
mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
sqlor.dbpools = sqlor_dbpools
sys.modules["ahserver"] = ahserver
sys.modules["ahserver.serverenv"] = ahserver_serverenv
sys.modules["sqlor"] = sqlor
sys.modules["sqlor.dbpools"] = sqlor_dbpools
yield mock_ctx, mock_sor
# Cleanup
for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]:
if mod in sys.modules:
del sys.modules[mod]
@pytest.mark.asyncio
async def test_lookup_found(self, mock_sage_modules):
mock_ctx, mock_sor = mock_sage_modules
mock_rec = {
"id": 1,
"apikey": "test-key",
"secret": "test-secret",
"status": "active",
"expire_date": "2099-01-01T00:00:00Z",
"description": "Test key",
}
mock_sor.R.return_value = [mock_rec]
result = await lookup_api_key("test-key")
assert result is not None
assert result.apikey == "test-key"
assert result.secret == "test-secret"
assert result.is_active is True
@pytest.mark.asyncio
async def test_lookup_not_found(self, mock_sage_modules):
mock_ctx, mock_sor = mock_sage_modules
mock_sor.R.return_value = []
result = await lookup_api_key("nonexistent")
assert result is None
# ---------------------------------------------------------------------------
# authenticate_request tests
# ---------------------------------------------------------------------------
class TestAuthenticateRequest:
def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request:
scope = {
"type": "http",
"method": method,
"path": path,
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
"query_string": b"",
}
async def receive():
return {"type": "http.request", "body": body, "more_body": False}
return Request(scope, receive)
@pytest.mark.asyncio
async def test_missing_api_key_header(self):
req = self._make_request({
"X-DAPI-Timestamp": str(time.time()),
"X-DAPI-Signature": "abc",
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "X-DAPI-Key" in exc.value.detail
@pytest.mark.asyncio
async def test_missing_timestamp_header(self):
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Signature": "abc",
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "X-DAPI-Timestamp" in exc.value.detail
@pytest.mark.asyncio
async def test_missing_signature_header(self):
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": str(time.time()),
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "X-DAPI-Signature" in exc.value.detail
@pytest.mark.asyncio
async def test_expired_timestamp(self):
ts = str(time.time() - 600)
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "timestamp" in exc.value.detail.lower()
@pytest.mark.asyncio
async def test_invalid_api_key(self):
ts = str(time.time())
req = self._make_request({
"X-DAPI-Key": "bad-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=None):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "Invalid API key" in exc.value.detail
@pytest.mark.asyncio
async def test_inactive_key(self):
ts = str(time.time())
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 403
assert "inactive" in exc.value.detail.lower()
@pytest.mark.asyncio
async def test_expired_key(self):
ts = str(time.time())
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
key_rec = DapiKeyRecord(
id=1, apikey="test-key", secret="secret", status="active",
expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc),
)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 403
assert "expired" in exc.value.detail.lower()
@pytest.mark.asyncio
async def test_invalid_signature(self):
ts = str(time.time())
body = b'{"test": "data"}'
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "invalid-signature",
}, body=body, method="POST")
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "Invalid signature" in exc.value.detail
@pytest.mark.asyncio
async def test_valid_authentication(self):
ts = str(time.time())
secret = "my-secret"
body = b'{"test": "data"}'
sig = compute_dapi_signature("POST", "/test", ts, secret, body)
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": sig,
}, body=body, method="POST")
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
result = await authenticate_request(req)
assert result.apikey == "test-key"
assert result.secret == secret
# ---------------------------------------------------------------------------
# DapiAuthMiddleware tests
# ---------------------------------------------------------------------------
class TestDapiAuthMiddleware:
@pytest.fixture(autouse=True)
def mock_sage_modules(self):
"""Create mock ahserver and sqlor modules."""
import sys
import types
ahserver = types.ModuleType("ahserver")
ahserver_serverenv = types.ModuleType("ahserver.serverenv")
ahserver_serverenv.ServerEnv = MagicMock
ahserver.serverenv = ahserver_serverenv
sqlor = types.ModuleType("sqlor")
sqlor_dbpools = types.ModuleType("sqlor.dbpools")
mock_sor = AsyncMock()
mock_sor.R = AsyncMock(return_value=[])
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
sqlor.dbpools = sqlor_dbpools
sys.modules["ahserver"] = ahserver
sys.modules["ahserver.serverenv"] = ahserver_serverenv
sys.modules["sqlor"] = sqlor
sys.modules["sqlor.dbpools"] = sqlor_dbpools
yield
for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]:
if mod in sys.modules:
del sys.modules[mod]
def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC):
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
async def protected(request: Request):
return PlainTextResponse("OK - authenticated")
async def health(request: Request):
return PlainTextResponse("healthy")
app = Starlette(
routes=[
Route("/protected", endpoint=protected, methods=["GET"]),
Route("/health", endpoint=health, methods=["GET"]),
]
)
app.add_middleware(
DapiAuthMiddleware,
exclude_paths=exclude_paths or ["/health"],
tolerance_sec=tolerance_sec,
)
return app
def test_unauthenticated_request_rejected(self):
app = self._build_test_app()
client = TestClient(app)
resp = client.get("/protected")
assert resp.status_code == 401
assert "error" in resp.json()
def test_excluded_path_bypasses_auth(self):
app = self._build_test_app()
client = TestClient(app)
resp = client.get("/health")
assert resp.status_code == 200
assert resp.text == "healthy"
def test_valid_request_accepted(self):
ts = str(time.time())
secret = "test-secret"
sig = compute_dapi_signature("GET", "/protected", ts, secret)
key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None)
app = self._build_test_app()
client = TestClient(app)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
resp = client.get(
"/protected",
headers={
"X-DAPI-Key": "valid-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": sig,
},
)
assert resp.status_code == 200
assert resp.text == "OK - authenticated"

288
tests/test_http_client.py Normal file
View File

@ -0,0 +1,288 @@
"""Tests for sageapi.utils.http_client"""
import hashlib
import hmac
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from sageapi.utils.http_client import (
DAPISigner,
SageHttpClient,
RetryConfig,
compute_dapi_signature,
)
# ---------------------------------------------------------------------------
# compute_dapi_signature tests (also imported from middleware)
# ---------------------------------------------------------------------------
class TestComputeDapiSignature:
def test_basic(self):
sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "secret")
assert len(sig) == 64
def test_with_body(self):
body = b'{"msg":"hi"}'
sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "secret", body)
assert len(sig) == 64
def test_deterministic(self):
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret")
assert sig1 == sig2
# ---------------------------------------------------------------------------
# DAPISigner tests
# ---------------------------------------------------------------------------
class TestDAPISigner:
def test_sign_request(self):
signer = DAPISigner(api_key="my-key", api_secret="my-secret")
headers = signer.sign_request("GET", "/test")
assert "X-DAPI-Key" in headers
assert "X-DAPI-Timestamp" in headers
assert "X-DAPI-Signature" in headers
assert headers["X-DAPI-Key"] == "my-key"
assert headers["X-DAPI-Timestamp"] # should be a valid timestamp string
def test_sign_request_with_body(self):
signer = DAPISigner(api_key="k", api_secret="s")
body = b'{"test": true}'
headers = signer.sign_request("POST", "/v1/chat", body)
assert headers["X-DAPI-Key"] == "k"
# Verify timestamp is recent
ts = float(headers["X-DAPI-Timestamp"])
assert abs(time.time() - ts) < 2
def test_sign_request_produces_valid_signature(self):
signer = DAPISigner(api_key="k", api_secret="secret")
body = b'hello'
headers = signer.sign_request("POST", "/path", body)
expected = compute_dapi_signature("POST", "/path", headers["X-DAPI-Timestamp"], "secret", body)
assert headers["X-DAPI-Signature"] == expected
def test_sign_request_uppercases_method(self):
signer = DAPISigner(api_key="k", api_secret="s")
headers = signer.sign_request("get", "/path")
expected = compute_dapi_signature("GET", "/path", headers["X-DAPI-Timestamp"], "s")
assert headers["X-DAPI-Signature"] == expected
# ---------------------------------------------------------------------------
# RetryConfig tests
# ---------------------------------------------------------------------------
class TestRetryConfig:
def test_defaults(self):
config = RetryConfig()
assert config.max_retries == 3
assert config.backoff_factor == 0.5
assert 429 in config.retry_on_status
assert 500 in config.retry_on_status
def test_custom(self):
config = RetryConfig(max_retries=5, backoff_factor=1.0)
assert config.max_retries == 5
assert config.backoff_factor == 1.0
# ---------------------------------------------------------------------------
# SageHttpClient tests
# ---------------------------------------------------------------------------
class TestSageHttpClient:
def test_init_defaults(self):
client = SageHttpClient(
base_url="https://api.example.com",
api_key="key",
api_secret="secret",
)
assert client.base_url == "https://api.example.com"
assert client.signer.api_key == "key"
assert client.signer.api_secret == "secret"
assert client.auto_sign is True
def test_init_with_custom_signer(self):
signer = DAPISigner(api_key="k", api_secret="s")
client = SageHttpClient(base_url="https://api.example.com", signer=signer)
assert client.signer is signer
def test_init_without_auto_sign(self):
client = SageHttpClient(
base_url="https://api.example.com",
auto_sign=False,
)
assert client.auto_sign is False
def test_build_url(self):
client = SageHttpClient(base_url="https://api.example.com")
assert client._build_url("/v1/test") == "https://api.example.com/v1/test"
assert client._build_url("v1/test") == "https://api.example.com/v1/test"
# Full URL should pass through
assert client._build_url("https://other.com/path") == "https://other.com/path"
def test_get_relative_path(self):
client = SageHttpClient(base_url="https://api.example.com")
assert client._get_relative_path("/v1/chat") == "/v1/chat"
assert client._get_relative_path("https://api.example.com/v1/chat?page=1") == "/v1/chat"
def test_sign_and_prepare_adds_dapi_headers(self):
client = SageHttpClient(
base_url="https://api.example.com",
api_key="k",
api_secret="s",
)
import asyncio
url, headers, body = asyncio.get_event_loop().run_until_complete(
client._sign_and_prepare("GET", "/test")
)
assert "X-DAPI-Key" in headers
assert "X-DAPI-Timestamp" in headers
assert "X-DAPI-Signature" in headers
assert headers["X-DAPI-Key"] == "k"
def test_sign_and_prepare_with_json_body(self):
client = SageHttpClient(
base_url="https://api.example.com",
api_key="k",
api_secret="s",
)
import asyncio
url, headers, body = asyncio.get_event_loop().run_until_complete(
client._sign_and_prepare("POST", "/test", json_body={"msg": "hi"})
)
assert body is not None
assert json.loads(body) == {"msg": "hi"}
assert headers.get("Content-Type") == "application/json"
@pytest.mark.asyncio
async def test_context_manager(self):
client = SageHttpClient(
base_url="https://httpbin.org",
api_key="k",
api_secret="s",
max_retries=0,
)
async with client as c:
assert c is client
assert client._closed is True
@pytest.mark.asyncio
async def test_close(self):
client = SageHttpClient(
base_url="https://httpbin.org",
api_key="k",
api_secret="s",
max_retries=0,
)
await client._get_session() # create session
await client.close()
assert client._closed is True
@pytest.mark.asyncio
async def test_request_with_retry_on_502(self):
"""Test that 502 responses trigger retries and raise after exhaustion."""
client = SageHttpClient(
base_url="",
api_key="k",
api_secret="s",
max_retries=2,
backoff_factor=0.01, # fast for tests
auto_sign=False,
)
mock_response = AsyncMock()
mock_response.status = 502
mock_response.request_info = MagicMock()
mock_response.history = []
mock_response.release = MagicMock()
mock_session = AsyncMock()
mock_session.request = AsyncMock(return_value=mock_response)
mock_session.closed = False
client._session = mock_session
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
await client.request("GET", "/test")
assert exc_info.value.status == 502
# Should have been called 3 times (1 initial + 2 retries)
assert mock_session.request.call_count == 3
class TestSageHttpClientIntegration:
"""Integration tests against httpbin.org (requires network)."""
@pytest.mark.asyncio
async def test_get_request(self):
client = SageHttpClient(
base_url="https://httpbin.org",
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.get("/get")
assert resp.status == 200
data = await resp.json()
assert "url" in data
@pytest.mark.asyncio
async def test_post_request_with_json(self):
client = SageHttpClient(
base_url="https://httpbin.org",
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.post("/post", json={"key": "value"})
assert resp.status == 200
data = await resp.json()
assert data["json"] == {"key": "value"}
@pytest.mark.asyncio
async def test_headers_sent(self):
client = SageHttpClient(
base_url="https://httpbin.org",
headers={"X-Custom-Header": "test-value"},
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.get("/headers")
assert resp.status == 200
data = await resp.json()
assert data["headers"].get("X-Custom-Header") == "test-value"
@pytest.mark.asyncio
async def test_query_params(self):
client = SageHttpClient(
base_url="https://httpbin.org",
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.get("/get", params={"foo": "bar", "baz": "qux"})
assert resp.status == 200
data = await resp.json()
assert data["args"]["foo"] == "bar"
assert data["args"]["baz"] == "qux"