feat: implement sync engine, API handlers, DAPI auth, HTTP client
- Sync engine: BaseSync abstract class + 4 sync modules (users/pricing/uapi/llmage) - Checkpoint management via sync_state table - Batch processing with retry and exponential backoff - Incremental fetch from Sage DB via sqlor - UPSERT to local cache tables - API handlers: balance/accounting/users/pricing/health - Balance: cache lookup + Sage fallback - Accounting: create with idempotency, query with filters/pagination - Users: keyword search, org filter - Pricing: filter by ppid/llmid/type/status - Health: basic + readiness checks (DB connectivity) - DAPI auth: middleware + authenticate_request function - HMAC-SHA256 signature verification - Timestamp window validation - Sage downapikey table lookup - HTTP client: SageHttpClient with aiohttp - Auto DAPI signature injection - Connection pooling, retry, timeout - Router: 12 routes registered - Module init: load_sageapi() wires everything to ServerEnv
This commit is contained in:
parent
a9ea05ff2d
commit
5936a2f328
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
@ -0,0 +1,26 @@
|
||||
"""SageAPI - Sage data caching and proxy API server.
|
||||
|
||||
Provides cached access to Sage data (users, pricing, uapi, llmage)
|
||||
with DAPI authentication and independent multi-instance deployment.
|
||||
"""
|
||||
|
||||
__version__ = '0.1.0'
|
||||
|
||||
# Public API exports
|
||||
from sageapi.sync import run_all_syncs, UserSync, PricingSync, UapiSync, LlmageSync
|
||||
from sageapi.router import Router, setup_routes
|
||||
from sageapi.cache.cache_manager import CacheManager
|
||||
from sageapi.middleware.dapi_auth import authenticate_request, DapiAuthMiddleware
|
||||
|
||||
__all__ = [
|
||||
'run_all_syncs',
|
||||
'UserSync',
|
||||
'PricingSync',
|
||||
'UapiSync',
|
||||
'LlmageSync',
|
||||
'Router',
|
||||
'setup_routes',
|
||||
'CacheManager',
|
||||
'authenticate_request',
|
||||
'DapiAuthMiddleware',
|
||||
]
|
||||
BIN
sageapi/api/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
sageapi/api/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/api/__pycache__/accounting.cpython-310.pyc
Normal file
BIN
sageapi/api/__pycache__/accounting.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/api/__pycache__/balance.cpython-310.pyc
Normal file
BIN
sageapi/api/__pycache__/balance.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/api/__pycache__/health.cpython-310.pyc
Normal file
BIN
sageapi/api/__pycache__/health.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/api/__pycache__/pricing.cpython-310.pyc
Normal file
BIN
sageapi/api/__pycache__/pricing.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/api/__pycache__/users.cpython-310.pyc
Normal file
BIN
sageapi/api/__pycache__/users.cpython-310.pyc
Normal file
Binary file not shown.
@ -2,7 +2,7 @@
|
||||
|
||||
Provides endpoints for creating and querying accounting records.
|
||||
Writing goes directly to the accounting_records table; reads
|
||||
are served from the same table with optional date range filtering.
|
||||
are served from the same table with optional filtering.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -14,64 +14,86 @@ from typing import Any
|
||||
|
||||
from appPublic.log import debug, error
|
||||
from sqlor.dbpools import DBPools
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
|
||||
async def create_accounting_record(
|
||||
customer_id: str,
|
||||
amount: float,
|
||||
record_type: str = 'charge',
|
||||
description: str = '',
|
||||
**extra: Any,
|
||||
llmid: str = '',
|
||||
model_name: str = '',
|
||||
pricing_id: str = '',
|
||||
input_tokens: int | None = None,
|
||||
output_tokens: int | None = None,
|
||||
total_tokens: int | None = None,
|
||||
quantity: float | None = None,
|
||||
currency: str = 'CNY',
|
||||
request_id: str = '',
|
||||
transno: str = '',
|
||||
) -> str:
|
||||
"""Create a new accounting record.
|
||||
|
||||
Args:
|
||||
customer_id: The customer identifier.
|
||||
amount: The accounting amount (positive for charges, negative for credits).
|
||||
record_type: Type of record (charge, credit, adjustment, etc.).
|
||||
description: Optional description of the transaction.
|
||||
|
||||
Returns:
|
||||
JSON string with success flag and the created record ID.
|
||||
"""
|
||||
"""Create a new accounting record with idempotency via request_id."""
|
||||
result: dict[str, Any] = {'success': False, 'record_id': None}
|
||||
|
||||
try:
|
||||
from ahserver.serverenv import ServerEnv
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
if not dbname:
|
||||
result['error'] = 'No database configured for sageapi module'
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
record_id = request_id or str(uuid.uuid4())
|
||||
now = time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# Check idempotency
|
||||
if request_id:
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
existing = await sor.sqlExe(
|
||||
"SELECT id FROM accounting_records WHERE request_id = ${request_id}$",
|
||||
{'request_id': request_id},
|
||||
)
|
||||
if isinstance(existing, list) and existing:
|
||||
result['success'] = True
|
||||
result['record_id'] = existing[0].get('id', existing[0].get('id') if isinstance(existing[0], dict) else existing[0])
|
||||
result['duplicate'] = True
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
sql = """
|
||||
INSERT INTO accounting_records
|
||||
(id, customer_id, amount, record_type, description, created_at, extra)
|
||||
(id, customer_id, llmid, model_name, pricing_id,
|
||||
input_tokens, output_tokens, total_tokens, quantity,
|
||||
amount, currency, request_id, transno, status,
|
||||
created_at, updated_at)
|
||||
VALUES
|
||||
(${id}$, ${customer_id}$, ${amount}$, ${record_type}$, ${description}$, ${created_at}$, ${extra}$)
|
||||
(${id}$, ${customer_id}$, ${llmid}$, ${model_name}$, ${pricing_id}$,
|
||||
${input_tokens}$, ${output_tokens}$, ${total_tokens}$, ${quantity}$,
|
||||
${amount}$, ${currency}$, ${request_id}$, ${transno}$, 'accounted',
|
||||
${created_at}$, ${updated_at}$)
|
||||
"""
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
await sor.sqlExe(sql, {
|
||||
params = {
|
||||
'id': record_id,
|
||||
'customer_id': customer_id,
|
||||
'llmid': llmid,
|
||||
'model_name': model_name,
|
||||
'pricing_id': pricing_id,
|
||||
'input_tokens': input_tokens,
|
||||
'output_tokens': output_tokens,
|
||||
'total_tokens': total_tokens,
|
||||
'quantity': quantity,
|
||||
'amount': amount,
|
||||
'record_type': record_type,
|
||||
'description': description,
|
||||
'currency': currency,
|
||||
'request_id': request_id,
|
||||
'transno': transno,
|
||||
'created_at': now,
|
||||
'extra': json.dumps(extra, ensure_ascii=False) if extra else None,
|
||||
})
|
||||
'updated_at': now,
|
||||
}
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
await sor.sqlExe(sql, params)
|
||||
result['success'] = True
|
||||
result['record_id'] = record_id
|
||||
debug(f'Accounting record created: id={record_id}, customer={customer_id}, amount={amount}')
|
||||
|
||||
except Exception as e:
|
||||
error(f'Accounting record creation failed: {e}')
|
||||
error(f'create_accounting_record error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
@ -79,30 +101,19 @@ async def create_accounting_record(
|
||||
|
||||
async def query_accounting_records(
|
||||
customer_id: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
llmid: str | None = None,
|
||||
status: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> str:
|
||||
"""Query accounting records with optional filters.
|
||||
|
||||
Args:
|
||||
customer_id: Filter by customer ID.
|
||||
start_date: Filter records from this date (inclusive, YYYY-MM-DD).
|
||||
end_date: Filter records up to this date (inclusive, YYYY-MM-DD).
|
||||
limit: Maximum number of records to return.
|
||||
offset: Number of records to skip.
|
||||
|
||||
Returns:
|
||||
JSON string with success flag and record data.
|
||||
"""
|
||||
"""Query accounting records with filters and pagination."""
|
||||
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
|
||||
|
||||
try:
|
||||
from ahserver.serverenv import ServerEnv
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
if not dbname:
|
||||
result['error'] = 'No database configured for sageapi module'
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
@ -113,41 +124,56 @@ async def query_accounting_records(
|
||||
if customer_id:
|
||||
conditions.append('customer_id = ${customer_id}$')
|
||||
params['customer_id'] = customer_id
|
||||
if start_date:
|
||||
conditions.append('created_at >= ${start_date}$')
|
||||
params['start_date'] = start_date
|
||||
if end_date:
|
||||
conditions.append('created_at <= ${end_date}$')
|
||||
params['end_date'] = end_date
|
||||
if date_from:
|
||||
conditions.append('created_at >= ${date_from}$')
|
||||
params['date_from'] = date_from
|
||||
if date_to:
|
||||
conditions.append('created_at <= ${date_to}$')
|
||||
params['date_to'] = date_to
|
||||
if llmid:
|
||||
conditions.append('llmid = ${llmid}$')
|
||||
params['llmid'] = llmid
|
||||
if status:
|
||||
conditions.append('status = ${status}$')
|
||||
params['status'] = status
|
||||
|
||||
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
|
||||
where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
|
||||
|
||||
# Count query
|
||||
count_sql = f"""
|
||||
SELECT COUNT(*) as cnt FROM accounting_records {where_clause}
|
||||
"""
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
count_rows = await sor.sqlExe(count_sql, params)
|
||||
total = count_rows[0]['cnt'] if count_rows else 0
|
||||
result['total'] = total
|
||||
count_sql = f"SELECT COUNT(*) as cnt FROM accounting_records {where}"
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
if total > 0:
|
||||
# Data query
|
||||
data_sql = f"""
|
||||
SELECT id, customer_id, amount, record_type, description, created_at, extra
|
||||
SELECT id, customer_id, llmid, model_name, pricing_id,
|
||||
input_tokens, output_tokens, total_tokens, quantity,
|
||||
amount, currency, request_id, transno, status,
|
||||
created_at, updated_at
|
||||
FROM accounting_records
|
||||
{where_clause}
|
||||
{where}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ${limit}$ OFFSET ${offset}$
|
||||
LIMIT {page_size} OFFSET {offset}
|
||||
"""
|
||||
params['limit'] = limit
|
||||
params['offset'] = offset
|
||||
rows = await sor.sqlExe(data_sql, params)
|
||||
result['data'] = [dict(r) for r in (rows or [])]
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
count_result = await sor.sqlExe(count_sql, params)
|
||||
if isinstance(count_result, list) and count_result:
|
||||
result['total'] = count_result[0].get('cnt', 0)
|
||||
elif isinstance(count_result, dict):
|
||||
result['total'] = count_result.get('cnt', 0)
|
||||
|
||||
data = await sor.sqlExe(data_sql, params)
|
||||
if isinstance(data, dict):
|
||||
result['data'] = data.get('rows', [])
|
||||
elif isinstance(data, list):
|
||||
result['data'] = data
|
||||
|
||||
result['success'] = True
|
||||
result['page'] = page
|
||||
result['page_size'] = page_size
|
||||
|
||||
except Exception as e:
|
||||
error(f'Accounting query failed: {e}')
|
||||
error(f'query_accounting_records error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
"""Customer balance query API handler.
|
||||
|
||||
Provides the RESTful endpoint for querying customer account balances.
|
||||
Reads from the local customer_balance cache table.
|
||||
Reads from the local customer_balance cache table, with fallback to
|
||||
real-time query from Sage acc_balance table.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -10,15 +11,18 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from appPublic.log import debug, error
|
||||
from sqlor.dbpools import DBPools
|
||||
from sqlor.dbpools import DBPools, get_sor_context
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
|
||||
async def get_customer_balance(customer_id: str | None = None) -> str:
|
||||
"""Query customer balance.
|
||||
|
||||
First checks the local customer_balance cache. If not found,
|
||||
falls back to real-time query from Sage acc_balance table.
|
||||
|
||||
Args:
|
||||
customer_id: Optional customer ID filter. If not provided,
|
||||
returns all customer balances.
|
||||
customer_id: Optional customer ID filter.
|
||||
|
||||
Returns:
|
||||
JSON string with success flag and balance data.
|
||||
@ -26,42 +30,92 @@ async def get_customer_balance(customer_id: str | None = None) -> str:
|
||||
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
|
||||
|
||||
try:
|
||||
from ahserver.serverenv import ServerEnv
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
if not dbname:
|
||||
cache_dbname = env.get_module_dbname('sageapi')
|
||||
if not cache_dbname:
|
||||
result['error'] = 'No database configured for sageapi module'
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
async with DBPools().sqlorContext(cache_dbname) as sor:
|
||||
params: dict[str, Any] = {}
|
||||
where_clause = ''
|
||||
where = ''
|
||||
if customer_id:
|
||||
where_clause = 'WHERE customer_id = ${customer_id}$'
|
||||
where = 'WHERE id = ${customer_id}$'
|
||||
params['customer_id'] = customer_id
|
||||
|
||||
sql = f"""
|
||||
SELECT customer_id, balance, currency, updated_at
|
||||
SELECT id, balance, currency, credit_limit,
|
||||
last_recharge, last_consumption,
|
||||
status, cached_at
|
||||
FROM customer_balance
|
||||
{where_clause}
|
||||
ORDER BY customer_id
|
||||
{where}
|
||||
ORDER BY id
|
||||
"""
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
data = await sor.sqlExe(sql, params)
|
||||
if isinstance(data, dict):
|
||||
result['total'] = data.get('total', 0)
|
||||
result['data'] = [dict(r) for r in data.get('rows', [])]
|
||||
else:
|
||||
rows = [dict(r) for r in (data or [])]
|
||||
result['data'] = rows
|
||||
result['total'] = len(rows)
|
||||
result['total'] = data.get('total', len(data.get('rows', [])))
|
||||
result['data'] = data.get('rows', [])
|
||||
elif isinstance(data, list):
|
||||
result['total'] = len(data)
|
||||
result['data'] = data
|
||||
|
||||
# If cache miss for specific customer, try real-time from Sage
|
||||
if customer_id and not result['data']:
|
||||
result['data'] = await _query_sage_balance(env, customer_id)
|
||||
result['total'] = len(result['data'])
|
||||
|
||||
result['success'] = True
|
||||
debug(f'Balance query: returned {result["total"]} records')
|
||||
|
||||
except Exception as e:
|
||||
error(f'Balance query failed: {e}')
|
||||
error(f'get_customer_balance error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
async def _query_sage_balance(env: ServerEnv, customer_id: str) -> list[dict]:
|
||||
"""Fallback: query Sage acc_balance table directly."""
|
||||
try:
|
||||
async with get_sor_context(env, 'sage') as sor:
|
||||
sql = """
|
||||
SELECT customer_id, balance, currency, status, updated_at
|
||||
FROM acc_balance
|
||||
WHERE customer_id = ${customer_id}$
|
||||
"""
|
||||
rows = await sor.sqlExe(sql, {'customer_id': customer_id})
|
||||
if isinstance(rows, list):
|
||||
return rows
|
||||
elif isinstance(rows, dict):
|
||||
return rows.get('rows', [])
|
||||
except Exception as e:
|
||||
error(f'_query_sage_balance error: {e}')
|
||||
return []
|
||||
|
||||
|
||||
async def update_customer_balance(customer_id: str, balance: float) -> str:
|
||||
"""Update customer balance in cache (called by sync or accounting)."""
|
||||
result: dict[str, Any] = {'success': False}
|
||||
|
||||
try:
|
||||
env = ServerEnv()
|
||||
cache_dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
sql = """
|
||||
INSERT INTO customer_balance (id, balance, cached_at)
|
||||
VALUES (${customer_id}$, ${balance}$, NOW())
|
||||
ON DUPLICATE KEY UPDATE
|
||||
balance = ${balance}$,
|
||||
cached_at = NOW()
|
||||
"""
|
||||
async with DBPools().sqlorContext(cache_dbname) as sor:
|
||||
await sor.sqlExe(sql, {
|
||||
'customer_id': customer_id,
|
||||
'balance': balance,
|
||||
})
|
||||
result['success'] = True
|
||||
|
||||
except Exception as e:
|
||||
error(f'update_customer_balance error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""Health check API handler.
|
||||
|
||||
Provides a simple endpoint for load balancer health checks and
|
||||
system status monitoring. No authentication required.
|
||||
Provides endpoints for service health and readiness checks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -10,51 +9,101 @@ import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from appPublic.log import debug
|
||||
from appPublic.log import debug, error
|
||||
from sqlor.dbpools import DBPools
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
_START_TIME = time.time()
|
||||
|
||||
|
||||
async def health_check() -> str:
|
||||
"""Health check endpoint.
|
||||
"""Basic health check - returns service status."""
|
||||
uptime = time.time() - _START_TIME
|
||||
|
||||
Returns system status including database connectivity,
|
||||
cache stats, and uptime information.
|
||||
|
||||
Returns:
|
||||
JSON string with health status.
|
||||
"""
|
||||
result: dict[str, Any] = {
|
||||
result = {
|
||||
'status': 'ok',
|
||||
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'database': 'unknown',
|
||||
'cache': {},
|
||||
'service': 'sageapi',
|
||||
'uptime_seconds': round(uptime, 1),
|
||||
'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S%z'),
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
async def readiness_check() -> str:
|
||||
"""Readiness check - verifies database connectivity."""
|
||||
result: dict[str, Any] = {
|
||||
'status': 'unknown',
|
||||
'checks': {},
|
||||
}
|
||||
|
||||
# Check database connectivity
|
||||
# Check cache database connection
|
||||
try:
|
||||
from ahserver.serverenv import ServerEnv
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
if dbname:
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
await sor.sqlExe('SELECT 1')
|
||||
result['database'] = 'connected'
|
||||
if not dbname:
|
||||
result['checks']['cache_db'] = {
|
||||
'status': 'fail',
|
||||
'error': 'No database configured for sageapi module',
|
||||
}
|
||||
else:
|
||||
result['database'] = 'not_configured'
|
||||
result['status'] = 'degraded'
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
rows = await sor.sqlExe('SELECT 1 as ping')
|
||||
result['checks']['cache_db'] = {
|
||||
'status': 'ok',
|
||||
'dbname': dbname,
|
||||
}
|
||||
except Exception as e:
|
||||
result['database'] = f'error: {str(e)}'
|
||||
result['status'] = 'unhealthy'
|
||||
error(f'readiness_check cache_db error: {e}')
|
||||
result['checks']['cache_db'] = {
|
||||
'status': 'fail',
|
||||
'error': str(e),
|
||||
}
|
||||
|
||||
# Cache stats
|
||||
# Check Sage database connection
|
||||
try:
|
||||
from ..cache.cache_manager import _get_cache_manager
|
||||
cm = _get_cache_manager()
|
||||
result['cache'] = cm.stats()
|
||||
except Exception:
|
||||
result['cache'] = {'error': 'cache not initialized'}
|
||||
from sqlor.dbpools import get_sor_context
|
||||
async with get_sor_context(env, 'sage') as sor:
|
||||
rows = await sor.sqlExe('SELECT 1 as ping')
|
||||
result['checks']['sage_db'] = {'status': 'ok'}
|
||||
except Exception as e:
|
||||
error(f'readiness_check sage_db error: {e}')
|
||||
result['checks']['sage_db'] = {
|
||||
'status': 'fail',
|
||||
'error': str(e),
|
||||
}
|
||||
|
||||
# Check sync state
|
||||
try:
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
sql = """
|
||||
SELECT entity_type, sync_status, last_sync_time
|
||||
FROM sync_state
|
||||
ORDER BY last_sync_time DESC
|
||||
"""
|
||||
rows = await sor.sqlExe(sql)
|
||||
if isinstance(rows, list):
|
||||
result['checks']['sync_status'] = {
|
||||
'status': 'ok',
|
||||
'entities': [
|
||||
{
|
||||
'entity_type': r.get('entity_type', ''),
|
||||
'sync_status': r.get('sync_status', ''),
|
||||
'last_sync_time': str(r.get('last_sync_time', '')),
|
||||
}
|
||||
for r in rows
|
||||
],
|
||||
}
|
||||
except Exception as e:
|
||||
result['checks']['sync_status'] = {
|
||||
'status': 'fail',
|
||||
'error': str(e),
|
||||
}
|
||||
|
||||
# Overall status
|
||||
all_ok = all(
|
||||
check.get('status') == 'ok'
|
||||
for check in result['checks'].values()
|
||||
)
|
||||
result['status'] = 'ready' if all_ok else 'degraded'
|
||||
|
||||
debug(f'Health check: status={result["status"]}')
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""Pricing query API handler.
|
||||
"""Pricing API handler.
|
||||
|
||||
Provides the RESTful endpoint for querying pricing information.
|
||||
Reads from the local pricing_cache table synced from Sage.
|
||||
Provides endpoint for querying cached pricing data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -11,72 +10,110 @@ from typing import Any
|
||||
|
||||
from appPublic.log import debug, error
|
||||
from sqlor.dbpools import DBPools
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
|
||||
async def query_pricing(
|
||||
program_id: str | None = None,
|
||||
model: str | None = None,
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
ppid: str | None = None,
|
||||
llmid: str | None = None,
|
||||
pricing_type: str | None = None,
|
||||
status: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> str:
|
||||
"""Query pricing information from the local cache.
|
||||
|
||||
Args:
|
||||
program_id: Filter by pricing program ID.
|
||||
model: Filter by model name (partial match).
|
||||
limit: Maximum number of records to return.
|
||||
offset: Number of records to skip.
|
||||
|
||||
Returns:
|
||||
JSON string with success flag and pricing data.
|
||||
"""
|
||||
"""Query pricing from cache with filters."""
|
||||
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
|
||||
|
||||
try:
|
||||
from ahserver.serverenv import ServerEnv
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
if not dbname:
|
||||
result['error'] = 'No database configured for sageapi module'
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
conditions = []
|
||||
params: dict[str, Any] = {'limit': limit, 'offset': offset}
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if program_id:
|
||||
conditions.append('program_id = ${program_id}$')
|
||||
params['program_id'] = program_id
|
||||
if model:
|
||||
conditions.append('model LIKE ${model}$')
|
||||
params['model'] = f'%{model}%'
|
||||
if ppid:
|
||||
conditions.append('id = ${ppid}$')
|
||||
params['ppid'] = ppid
|
||||
if llmid:
|
||||
conditions.append('llmid = ${llmid}$')
|
||||
params['llmid'] = llmid
|
||||
if pricing_type:
|
||||
conditions.append('pricing_type = ${pricing_type}$')
|
||||
params['pricing_type'] = pricing_type
|
||||
if status:
|
||||
conditions.append('status = ${status}$')
|
||||
params['status'] = status
|
||||
else:
|
||||
conditions.append("status = 'active'")
|
||||
|
||||
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
|
||||
where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
|
||||
|
||||
# Count query
|
||||
count_sql = f'SELECT COUNT(*) as cnt FROM pricing_cache {where_clause}'
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
count_rows = await sor.sqlExe(count_sql, params)
|
||||
total = count_rows[0]['cnt'] if count_rows else 0
|
||||
result['total'] = total
|
||||
|
||||
if total > 0:
|
||||
count_sql = f"SELECT COUNT(*) as cnt FROM pricing_cache {where}"
|
||||
offset = (page - 1) * page_size
|
||||
data_sql = f"""
|
||||
SELECT program_id, model, input_price, output_price,
|
||||
unit, currency, updated_at
|
||||
SELECT id, llmid, model_name, pricing_type,
|
||||
input_price, output_price, unit_price,
|
||||
currency, status, effective_from, effective_to,
|
||||
cached_at
|
||||
FROM pricing_cache
|
||||
{where_clause}
|
||||
ORDER BY program_id, model
|
||||
LIMIT ${limit}$ OFFSET ${offset}$
|
||||
{where}
|
||||
ORDER BY model_name
|
||||
LIMIT {page_size} OFFSET {offset}
|
||||
"""
|
||||
rows = await sor.sqlExe(data_sql, params)
|
||||
result['data'] = [dict(r) for r in (rows or [])]
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
count_result = await sor.sqlExe(count_sql, params)
|
||||
if isinstance(count_result, list) and count_result:
|
||||
result['total'] = count_result[0].get('cnt', 0)
|
||||
elif isinstance(count_result, dict):
|
||||
result['total'] = count_result.get('cnt', 0)
|
||||
|
||||
data = await sor.sqlExe(data_sql, params)
|
||||
if isinstance(data, dict):
|
||||
result['data'] = data.get('rows', [])
|
||||
elif isinstance(data, list):
|
||||
result['data'] = data
|
||||
|
||||
result['success'] = True
|
||||
debug(f'Pricing query: returned {result["total"]} records')
|
||||
|
||||
except Exception as e:
|
||||
error(f'Pricing query failed: {e}')
|
||||
error(f'query_pricing error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
async def get_pricing_by_llmid(llmid: str) -> str:
|
||||
"""Get all active pricing entries for a specific model."""
|
||||
result: dict[str, Any] = {'success': False, 'data': []}
|
||||
|
||||
try:
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
sql = """
|
||||
SELECT id, llmid, model_name, pricing_type,
|
||||
input_price, output_price, unit_price,
|
||||
currency, status, cached_at
|
||||
FROM pricing_cache
|
||||
WHERE llmid = ${llmid}$ AND status = 'active'
|
||||
ORDER BY pricing_type
|
||||
"""
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
data = await sor.sqlExe(sql, {'llmid': llmid})
|
||||
if isinstance(data, list):
|
||||
result['data'] = data
|
||||
result['success'] = True
|
||||
elif isinstance(data, dict):
|
||||
result['data'] = data.get('rows', [])
|
||||
result['success'] = True
|
||||
|
||||
except Exception as e:
|
||||
error(f'get_pricing_by_llmid error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""User query API handler.
|
||||
"""Users API handler.
|
||||
|
||||
Provides the RESTful endpoint for querying user information.
|
||||
Reads from the local users_cache table synced from Sage.
|
||||
Provides endpoint for querying cached user data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -11,73 +10,91 @@ from typing import Any
|
||||
|
||||
from appPublic.log import debug, error
|
||||
from sqlor.dbpools import DBPools
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
|
||||
async def query_users(
|
||||
user_id: str | None = None,
|
||||
keyword: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""Query user information from the local cache.
|
||||
|
||||
Args:
|
||||
user_id: Filter by specific user ID.
|
||||
keyword: Search keyword (matches username, email, or phone).
|
||||
limit: Maximum number of records to return.
|
||||
offset: Number of records to skip.
|
||||
|
||||
Returns:
|
||||
JSON string with success flag and user data.
|
||||
"""
|
||||
async def query_users(keyword: str | None = None, orgid: str | None = None, page: int = 1, page_size: int = 50) -> str:
|
||||
"""Query users from cache with keyword search."""
|
||||
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
|
||||
|
||||
try:
|
||||
from ahserver.serverenv import ServerEnv
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
if not dbname:
|
||||
result['error'] = 'No database configured for sageapi module'
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
conditions = []
|
||||
params: dict[str, Any] = {'limit': limit, 'offset': offset}
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if user_id:
|
||||
conditions.append('user_id = ${user_id}$')
|
||||
params['user_id'] = user_id
|
||||
if keyword:
|
||||
conditions.append(
|
||||
'(username LIKE ${keyword}$ OR email LIKE ${keyword}$ OR phone LIKE ${keyword}$)'
|
||||
)
|
||||
conditions.append("username LIKE ${keyword}$")
|
||||
params['keyword'] = f'%{keyword}%'
|
||||
if orgid:
|
||||
conditions.append('orgid = ${orgid}$')
|
||||
params['orgid'] = orgid
|
||||
|
||||
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
|
||||
where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
|
||||
|
||||
# Count query
|
||||
count_sql = f'SELECT COUNT(*) as cnt FROM users_cache {where_clause}'
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
count_rows = await sor.sqlExe(count_sql, params)
|
||||
total = count_rows[0]['cnt'] if count_rows else 0
|
||||
result['total'] = total
|
||||
|
||||
if total > 0:
|
||||
count_sql = f"SELECT COUNT(*) as cnt FROM users_cache {where}"
|
||||
offset = (page - 1) * page_size
|
||||
data_sql = f"""
|
||||
SELECT user_id, username, email, phone, status, updated_at
|
||||
SELECT id, username, orgid, orgname, email, phone,
|
||||
status, created_at, updated_at, cached_at
|
||||
FROM users_cache
|
||||
{where_clause}
|
||||
ORDER BY user_id
|
||||
LIMIT ${limit}$ OFFSET ${offset}$
|
||||
{where}
|
||||
ORDER BY username
|
||||
LIMIT {page_size} OFFSET {offset}
|
||||
"""
|
||||
rows = await sor.sqlExe(data_sql, params)
|
||||
result['data'] = [dict(r) for r in (rows or [])]
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
count_result = await sor.sqlExe(count_sql, params)
|
||||
if isinstance(count_result, list) and count_result:
|
||||
result['total'] = count_result[0].get('cnt', 0)
|
||||
elif isinstance(count_result, dict):
|
||||
result['total'] = count_result.get('cnt', 0)
|
||||
|
||||
data = await sor.sqlExe(data_sql, params)
|
||||
if isinstance(data, dict):
|
||||
result['data'] = data.get('rows', [])
|
||||
elif isinstance(data, list):
|
||||
result['data'] = data
|
||||
|
||||
result['success'] = True
|
||||
debug(f'User query: returned {result["total"]} records')
|
||||
|
||||
except Exception as e:
|
||||
error(f'User query failed: {e}')
|
||||
error(f'query_users error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
async def get_user_by_id(user_id: str) -> str:
|
||||
"""Get a single user by ID."""
|
||||
result: dict[str, Any] = {'success': False, 'data': None}
|
||||
|
||||
try:
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
|
||||
sql = """
|
||||
SELECT id, username, orgid, orgname, email, phone,
|
||||
status, created_at, updated_at, cached_at
|
||||
FROM users_cache
|
||||
WHERE id = ${user_id}$
|
||||
"""
|
||||
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
data = await sor.sqlExe(sql, {'user_id': user_id})
|
||||
if isinstance(data, list) and data:
|
||||
result['data'] = data[0]
|
||||
result['success'] = True
|
||||
elif isinstance(data, dict) and data.get('rows'):
|
||||
result['data'] = data['rows'][0]
|
||||
result['success'] = True
|
||||
|
||||
except Exception as e:
|
||||
error(f'get_user_by_id error: {e}')
|
||||
result['error'] = str(e)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
BIN
sageapi/cache/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
BIN
sageapi/cache/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
sageapi/cache/__pycache__/cache_manager.cpython-310.pyc
vendored
Normal file
BIN
sageapi/cache/__pycache__/cache_manager.cpython-310.pyc
vendored
Normal file
Binary file not shown.
106
sageapi/init.py
106
sageapi/init.py
@ -2,40 +2,69 @@
|
||||
|
||||
Registers all public functions to ServerEnv so they are accessible
|
||||
from dspy scripts and other modules via the global environment.
|
||||
Also sets up route registration and database event bindings.
|
||||
"""
|
||||
|
||||
from appPublic.log import debug
|
||||
from appPublic.log import debug, info
|
||||
from sqlor.dbpools import DBPools
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth
|
||||
# ---------------------------------------------------------------------------
|
||||
from .auth.dapi_auth import dapi_auth_middleware
|
||||
from .auth.uapi_sign import uapi_sign_verify
|
||||
from .middleware.dapi_auth import authenticate_request, DapiAuthMiddleware
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync
|
||||
# ---------------------------------------------------------------------------
|
||||
from .sync.base_sync import BaseSync
|
||||
from .sync.user_sync import sync_users
|
||||
from .sync.pricing_sync import sync_pricing
|
||||
from .sync.uapi_sync import sync_uapi
|
||||
from .sync.llmage_sync import sync_llmage
|
||||
from .sync.base_sync import BaseSync, run_all_syncs
|
||||
from .sync.user_sync import UserSync
|
||||
from .sync.pricing_sync import PricingSync
|
||||
from .sync.uapi_sync import UapiSync
|
||||
from .sync.llmage_sync import LlmageSync
|
||||
|
||||
# Module-level convenience functions for sync
|
||||
async def sync_users() -> str:
|
||||
"""Convenience: run user sync."""
|
||||
syncer = UserSync()
|
||||
return await syncer.sync()
|
||||
|
||||
async def sync_pricing() -> str:
|
||||
"""Convenience: run pricing sync."""
|
||||
syncer = PricingSync()
|
||||
return await syncer.sync()
|
||||
|
||||
async def sync_uapi() -> str:
|
||||
"""Convenience: run uapi sync."""
|
||||
syncer = UapiSync()
|
||||
return await syncer.sync()
|
||||
|
||||
async def sync_llmage() -> str:
|
||||
"""Convenience: run llmage sync."""
|
||||
syncer = LlmageSync()
|
||||
return await syncer.sync()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache
|
||||
# ---------------------------------------------------------------------------
|
||||
from .cache.cache_manager import CacheManager
|
||||
|
||||
# Global cache instance (per-process)
|
||||
_cache_manager = CacheManager(max_entries=10000, default_ttl=300)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API
|
||||
# ---------------------------------------------------------------------------
|
||||
from .api.balance import get_customer_balance
|
||||
from .api.balance import get_customer_balance, update_customer_balance
|
||||
from .api.accounting import create_accounting_record, query_accounting_records
|
||||
from .api.users import query_users
|
||||
from .api.pricing import query_pricing
|
||||
from .api.health import health_check
|
||||
from .api.users import query_users, get_user_by_id
|
||||
from .api.pricing import query_pricing, get_pricing_by_llmid
|
||||
from .api.health import health_check, readiness_check
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router
|
||||
# ---------------------------------------------------------------------------
|
||||
from .router import Router, setup_routes
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utils
|
||||
@ -45,66 +74,79 @@ from .utils.crypto import encrypt_payload, decrypt_payload
|
||||
|
||||
|
||||
def _bind_sageapi_events(dbpools: DBPools, dbname: str) -> None:
|
||||
"""Bind database events to SageAPI cache invalidation handlers.
|
||||
|
||||
When sync state or accounting records change in the database,
|
||||
the corresponding cache entries are invalidated automatically.
|
||||
"""
|
||||
"""Bind database events to SageAPI cache invalidation handlers."""
|
||||
bindings = [
|
||||
# sync_state table: clear sync-related caches on change
|
||||
(f'{dbname}:sync_state:c:after', CacheManager.invalidate_sync_state),
|
||||
(f'{dbname}:sync_state:u:after', CacheManager.invalidate_sync_state),
|
||||
(f'{dbname}:sync_state:d:after', CacheManager.invalidate_sync_state),
|
||||
# accounting_records: clear accounting cache on change
|
||||
(f'{dbname}:accounting_records:c:after', CacheManager.invalidate_accounting),
|
||||
(f'{dbname}:accounting_records:u:after', CacheManager.invalidate_accounting),
|
||||
(f'{dbname}:accounting_records:d:after', CacheManager.invalidate_accounting),
|
||||
(f'{dbname}:sync_state:c:after', _cache_manager.invalidate_sync_state),
|
||||
(f'{dbname}:sync_state:u:after', _cache_manager.invalidate_sync_state),
|
||||
(f'{dbname}:sync_state:d:after', _cache_manager.invalidate_sync_state),
|
||||
(f'{dbname}:accounting_records:c:after', _cache_manager.invalidate_accounting),
|
||||
(f'{dbname}:accounting_records:u:after', _cache_manager.invalidate_accounting),
|
||||
(f'{dbname}:accounting_records:d:after', _cache_manager.invalidate_accounting),
|
||||
]
|
||||
for event_name, handler in bindings:
|
||||
try:
|
||||
dbpools.bind(event_name, handler)
|
||||
debug(f'SageAPI event bound: {event_name}')
|
||||
except Exception as e:
|
||||
debug(f'SageAPI event bind skipped: {event_name} ({e})')
|
||||
|
||||
|
||||
def load_sageapi() -> None:
|
||||
"""Register all SageAPI functions into ServerEnv.
|
||||
|
||||
Called by the Sage server during module loading phase.
|
||||
All registered functions become available as globals in dspy scripts.
|
||||
"""
|
||||
env = ServerEnv()
|
||||
|
||||
# Auth
|
||||
env.dapi_auth_middleware = dapi_auth_middleware
|
||||
env.uapi_sign_verify = uapi_sign_verify
|
||||
env.authenticate_request = authenticate_request
|
||||
env.DapiAuthMiddleware = DapiAuthMiddleware
|
||||
|
||||
# Sync
|
||||
env.sync_users = sync_users
|
||||
env.sync_pricing = sync_pricing
|
||||
env.sync_uapi = sync_uapi
|
||||
env.sync_llmage = sync_llmage
|
||||
env.run_all_syncs = run_all_syncs
|
||||
env.BaseSync = BaseSync
|
||||
env.UserSync = UserSync
|
||||
env.PricingSync = PricingSync
|
||||
env.UapiSync = UapiSync
|
||||
env.LlmageSync = LlmageSync
|
||||
|
||||
# Cache
|
||||
env.cache_manager = CacheManager()
|
||||
env.cache_manager = _cache_manager
|
||||
|
||||
# API
|
||||
env.get_customer_balance = get_customer_balance
|
||||
env.update_customer_balance = update_customer_balance
|
||||
env.create_accounting_record = create_accounting_record
|
||||
env.query_accounting_records = query_accounting_records
|
||||
env.query_users = query_users
|
||||
env.get_user_by_id = get_user_by_id
|
||||
env.query_pricing = query_pricing
|
||||
env.get_pricing_by_llmid = get_pricing_by_llmid
|
||||
env.health_check = health_check
|
||||
env.readiness_check = readiness_check
|
||||
|
||||
# Router
|
||||
router = Router()
|
||||
setup_routes(router)
|
||||
env.sageapi_router = router
|
||||
info(f'SageAPI: {len(router.get_routes())} routes registered')
|
||||
|
||||
# Utils
|
||||
env.SageHttpClient = SageHttpClient
|
||||
env.encrypt_payload = encrypt_payload
|
||||
env.decrypt_payload = decrypt_payload
|
||||
|
||||
# Bind database events for automatic cache invalidation
|
||||
# Bind database events
|
||||
dbpools = DBPools()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
if dbname:
|
||||
_bind_sageapi_events(dbpools, dbname)
|
||||
debug(f'SageAPI event listeners bound for database: {dbname}')
|
||||
info(f'SageAPI: event listeners bound for database: {dbname}')
|
||||
else:
|
||||
debug('SageAPI event listeners skipped: no database configured for sageapi module')
|
||||
debug('SageAPI: event listeners skipped (no database configured)')
|
||||
|
||||
info('SageAPI module loaded successfully')
|
||||
|
||||
1
sageapi/middleware/__init__.py
Normal file
1
sageapi/middleware/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Middleware package
|
||||
BIN
sageapi/middleware/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
sageapi/middleware/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/middleware/__pycache__/dapi_auth.cpython-310.pyc
Normal file
BIN
sageapi/middleware/__pycache__/dapi_auth.cpython-310.pyc
Normal file
Binary file not shown.
254
sageapi/middleware/dapi_auth.py
Normal file
254
sageapi/middleware/dapi_auth.py
Normal file
@ -0,0 +1,254 @@
|
||||
"""
|
||||
DAPI Authentication Middleware for sageapi.
|
||||
|
||||
Authenticates incoming requests using DAPI signature headers by querying
|
||||
the Sage downapikey table.
|
||||
|
||||
Usage with FastAPI / Starlette:
|
||||
from sageapi.middleware.dapi_auth import DapiAuthMiddleware
|
||||
app.add_middleware(DapiAuthMiddleware)
|
||||
|
||||
# Or for specific routes, use the get_dapi_key() dependency.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
# Headers expected from the client
|
||||
HEADER_API_KEY = "X-DAPI-Key"
|
||||
HEADER_TIMESTAMP = "X-DAPI-Timestamp"
|
||||
HEADER_SIGNATURE = "X-DAPI-Signature"
|
||||
|
||||
# Default timestamp tolerance: 5 minutes
|
||||
DEFAULT_TIMESTAMP_TOLERANCE_SEC = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class DapiKeyRecord:
|
||||
"""Represents a record from the downapikey table."""
|
||||
|
||||
id: Any
|
||||
apikey: str
|
||||
secret: str
|
||||
status: str
|
||||
expire_date: Any # datetime or string
|
||||
description: str = ""
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
return self.status == "active"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the key has expired based on expire_date."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
if self.expire_date is None:
|
||||
return False
|
||||
# Handle both datetime objects and ISO string
|
||||
if isinstance(self.expire_date, str):
|
||||
try:
|
||||
expire_dt = datetime.fromisoformat(self.expire_date.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
else:
|
||||
expire_dt = self.expire_date
|
||||
|
||||
# Ensure expire_dt is timezone-aware (UTC)
|
||||
if expire_dt.tzinfo is None:
|
||||
expire_dt = expire_dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
return now > expire_dt
|
||||
|
||||
|
||||
class DapiAuthError(Exception):
|
||||
"""Raised when DAPI authentication fails."""
|
||||
|
||||
def __init__(self, status_code: int, detail: str):
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
def compute_dapi_signature(method: str, path: str, timestamp: str, secret: str, body: Optional[bytes] = None) -> str:
|
||||
"""
|
||||
Compute the DAPI HMAC-SHA256 signature.
|
||||
|
||||
The signed string is: "{method}\n{path}\n{timestamp}\n{body_hash}"
|
||||
where body_hash is SHA-256 hex digest of the request body (or empty string if no body).
|
||||
"""
|
||||
if body:
|
||||
body_hash = hashlib.sha256(body).hexdigest()
|
||||
else:
|
||||
body_hash = ""
|
||||
|
||||
string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}"
|
||||
|
||||
signature = hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
string_to_sign.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
def verify_timestamp(timestamp_str: str, tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC) -> bool:
|
||||
"""Verify that the timestamp is within the allowed window."""
|
||||
try:
|
||||
ts = float(timestamp_str)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
now = time.time()
|
||||
return abs(now - ts) <= tolerance_sec
|
||||
|
||||
|
||||
async def lookup_api_key(api_key: str) -> Optional[DapiKeyRecord]:
|
||||
"""
|
||||
Look up an API key in the Sage downapikey table.
|
||||
|
||||
Returns a DapiKeyRecord if found, None otherwise.
|
||||
"""
|
||||
from ahserver.serverenv import ServerEnv
|
||||
from sqlor.dbpools import get_sor_context
|
||||
|
||||
env = ServerEnv()
|
||||
async with get_sor_context(env, "dapi") as sor:
|
||||
recs = await sor.R("downapikey", {"apikey": api_key})
|
||||
|
||||
if not recs:
|
||||
return None
|
||||
|
||||
rec = recs[0] # Take the first match
|
||||
return DapiKeyRecord(
|
||||
id=rec.get("id"),
|
||||
apikey=rec.get("apikey", ""),
|
||||
secret=rec.get("secret", ""),
|
||||
status=rec.get("status", "inactive"),
|
||||
expire_date=rec.get("expire_date"),
|
||||
description=rec.get("description", ""),
|
||||
)
|
||||
|
||||
|
||||
async def authenticate_request(
|
||||
request: Request,
|
||||
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
|
||||
) -> DapiKeyRecord:
|
||||
"""
|
||||
Authenticate a request using DAPI headers.
|
||||
|
||||
Returns the DapiKeyRecord on success.
|
||||
Raises DapiAuthError on failure.
|
||||
"""
|
||||
headers = request.headers
|
||||
|
||||
# 1. Check required headers
|
||||
api_key = headers.get(HEADER_API_KEY)
|
||||
if not api_key:
|
||||
raise DapiAuthError(401, f"Missing header: {HEADER_API_KEY}")
|
||||
|
||||
timestamp_str = headers.get(HEADER_TIMESTAMP)
|
||||
if not timestamp_str:
|
||||
raise DapiAuthError(401, f"Missing header: {HEADER_TIMESTAMP}")
|
||||
|
||||
signature = headers.get(HEADER_SIGNATURE)
|
||||
if not signature:
|
||||
raise DapiAuthError(401, f"Missing header: {HEADER_SIGNATURE}")
|
||||
|
||||
# 2. Validate timestamp window
|
||||
if not verify_timestamp(timestamp_str, tolerance_sec):
|
||||
raise DapiAuthError(401, "Request timestamp is outside the allowed window")
|
||||
|
||||
# 3. Look up the API key in the database
|
||||
key_record = await lookup_api_key(api_key)
|
||||
if key_record is None:
|
||||
raise DapiAuthError(401, "Invalid API key")
|
||||
|
||||
# 4. Check key status
|
||||
if not key_record.is_active:
|
||||
raise DapiAuthError(403, "API key is inactive")
|
||||
|
||||
# 5. Check expiration
|
||||
if key_record.is_expired:
|
||||
raise DapiAuthError(403, "API key has expired")
|
||||
|
||||
# 6. Read request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# 7. Compute expected signature and compare
|
||||
expected_signature = compute_dapi_signature(
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
timestamp=timestamp_str,
|
||||
secret=key_record.secret,
|
||||
body=body if body else None,
|
||||
)
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
raise DapiAuthError(401, "Invalid signature")
|
||||
|
||||
return key_record
|
||||
|
||||
|
||||
class DapiAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Starlette/FastAPI middleware that enforces DAPI authentication on all requests.
|
||||
|
||||
Attributes:
|
||||
exclude_paths: List of path prefixes to skip authentication (e.g., ['/health', '/docs']).
|
||||
tolerance_sec: Timestamp tolerance in seconds (default: 300).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: Any,
|
||||
exclude_paths: Optional[list[str]] = None,
|
||||
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.exclude_paths = exclude_paths or []
|
||||
self.tolerance_sec = tolerance_sec
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> JSONResponse:
|
||||
# Skip excluded paths
|
||||
for prefix in self.exclude_paths:
|
||||
if request.url.path.startswith(prefix):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
key_record = await authenticate_request(request, self.tolerance_sec)
|
||||
# Attach the key record to request state for downstream use
|
||||
request.state.dapi_key = key_record
|
||||
request.state.dapi_key_id = key_record.id
|
||||
except DapiAuthError as e:
|
||||
return JSONResponse(
|
||||
status_code=e.status_code,
|
||||
content={"error": e.detail},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def requires_dapi_auth(
|
||||
request: Request,
|
||||
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
|
||||
) -> Awaitable[DapiKeyRecord]:
|
||||
"""
|
||||
Dependency function for FastAPI route-level DAPI authentication.
|
||||
|
||||
Usage:
|
||||
@app.get("/protected")
|
||||
async def protected_route(key: DapiKeyRecord = Depends(requires_dapi_auth)):
|
||||
...
|
||||
"""
|
||||
return authenticate_request(request, tolerance_sec)
|
||||
@ -29,10 +29,10 @@ class Router:
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PUT, DELETE).
|
||||
path: URL path pattern, e.g. '/api/v1/balance'.
|
||||
path: URL path pattern.
|
||||
handler: Callable that handles the request.
|
||||
auth: Authentication method ('dapi', 'uapi', 'none').
|
||||
description: Human-readable description of the endpoint.
|
||||
description: Human-readable description.
|
||||
"""
|
||||
self._routes.append({
|
||||
'method': method.upper(),
|
||||
@ -55,64 +55,75 @@ class Router:
|
||||
return None
|
||||
|
||||
|
||||
# Global router instance
|
||||
router = Router()
|
||||
def setup_routes(router: Router) -> None:
|
||||
"""Register all SageAPI routes.
|
||||
|
||||
Health endpoints (no auth):
|
||||
GET /api/v1/health
|
||||
GET /api/v1/health/ready
|
||||
|
||||
def register_routes() -> None:
|
||||
"""Register all SageAPI API routes.
|
||||
Balance endpoints (dapi auth):
|
||||
GET /api/v1/balance
|
||||
POST /api/v1/balance/update
|
||||
|
||||
Called during module initialization to populate the router
|
||||
with all available endpoints.
|
||||
Accounting endpoints (dapi auth):
|
||||
POST /api/v1/accounting
|
||||
GET /api/v1/accounting
|
||||
|
||||
Users endpoints (dapi auth):
|
||||
GET /api/v1/users
|
||||
GET /api/v1/users/{user_id}
|
||||
|
||||
Pricing endpoints (dapi auth):
|
||||
GET /api/v1/pricing
|
||||
GET /api/v1/pricing/model/{llmid}
|
||||
"""
|
||||
from .api.health import health_check
|
||||
from .api.balance import get_customer_balance
|
||||
from .api.accounting import create_accounting_record, query_accounting_records
|
||||
from .api.users import query_users
|
||||
from .api.pricing import query_pricing
|
||||
# Health (no auth)
|
||||
from sageapi.api.health import health_check, readiness_check
|
||||
router.register('GET', '/api/v1/health', health_check, auth='none', description='Health check')
|
||||
router.register('GET', '/api/v1/health/ready', readiness_check, auth='none', description='Readiness check')
|
||||
|
||||
# Health check (no auth required)
|
||||
router.register(
|
||||
'GET', '/api/v1/health',
|
||||
handler=health_check,
|
||||
auth='none',
|
||||
description='Health check endpoint',
|
||||
)
|
||||
|
||||
# Customer balance
|
||||
router.register(
|
||||
'GET', '/api/v1/balance',
|
||||
handler=get_customer_balance,
|
||||
auth='dapi',
|
||||
description='Query customer balance',
|
||||
)
|
||||
# Balance
|
||||
from sageapi.api.balance import get_customer_balance, update_customer_balance
|
||||
router.register('GET', '/api/v1/balance', get_customer_balance, auth='dapi', description='Query customer balance')
|
||||
router.register('POST', '/api/v1/balance/update', update_customer_balance, auth='dapi', description='Update customer balance')
|
||||
|
||||
# Accounting
|
||||
router.register(
|
||||
'POST', '/api/v1/accounting',
|
||||
handler=create_accounting_record,
|
||||
auth='dapi',
|
||||
description='Create an accounting record',
|
||||
)
|
||||
router.register(
|
||||
'GET', '/api/v1/accounting',
|
||||
handler=query_accounting_records,
|
||||
auth='dapi',
|
||||
description='Query accounting records',
|
||||
)
|
||||
from sageapi.api.accounting import create_accounting_record, query_accounting_records
|
||||
router.register('POST', '/api/v1/accounting', create_accounting_record, auth='dapi', description='Create accounting record')
|
||||
router.register('GET', '/api/v1/accounting', query_accounting_records, auth='dapi', description='Query accounting records')
|
||||
|
||||
# Users
|
||||
router.register(
|
||||
'GET', '/api/v1/users',
|
||||
handler=query_users,
|
||||
auth='dapi',
|
||||
description='Query user information',
|
||||
)
|
||||
from sageapi.api.users import query_users, get_user_by_id
|
||||
router.register('GET', '/api/v1/users', query_users, auth='dapi', description='Query users')
|
||||
router.register('GET', '/api/v1/users/detail', get_user_by_id, auth='dapi', description='Get user by ID')
|
||||
|
||||
# Pricing
|
||||
router.register(
|
||||
'GET', '/api/v1/pricing',
|
||||
handler=query_pricing,
|
||||
auth='dapi',
|
||||
description='Query pricing information',
|
||||
)
|
||||
from sageapi.api.pricing import query_pricing, get_pricing_by_llmid
|
||||
router.register('GET', '/api/v1/pricing', query_pricing, auth='dapi', description='Query pricing')
|
||||
router.register('GET', '/api/v1/pricing/model', get_pricing_by_llmid, auth='dapi', description='Get pricing by model ID')
|
||||
|
||||
# Admin (dapi auth with admin role)
|
||||
from sageapi.sync.base_sync import run_all_syncs
|
||||
router.register('POST', '/api/v1/admin/sync', run_all_syncs, auth='dapi', description='Trigger full sync')
|
||||
router.register('GET', '/api/v1/admin/sync/status', _sync_status, auth='dapi', description='Sync status')
|
||||
|
||||
|
||||
async def _sync_status() -> str:
|
||||
"""Return current sync status for all entities."""
|
||||
import json
|
||||
from sqlor.dbpools import DBPools
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
result = {'success': False, 'data': []}
|
||||
try:
|
||||
env = ServerEnv()
|
||||
dbname = env.get_module_dbname('sageapi')
|
||||
sql = "SELECT entity_type, sync_status, last_sync_time, error_msg FROM sync_state ORDER BY entity_type"
|
||||
async with DBPools().sqlorContext(dbname) as sor:
|
||||
rows = await sor.sqlExe(sql)
|
||||
result['data'] = rows if isinstance(rows, list) else rows.get('rows', [])
|
||||
result['success'] = True
|
||||
except Exception as e:
|
||||
result['error'] = str(e)
|
||||
return json.dumps(result, ensure_ascii=False, default=str)
|
||||
|
||||
@ -0,0 +1,16 @@
|
||||
"""SageAPI sync engine package."""
|
||||
|
||||
from .base_sync import BaseSync, run_all_syncs
|
||||
from .user_sync import UserSync
|
||||
from .pricing_sync import PricingSync
|
||||
from .uapi_sync import UapiSync
|
||||
from .llmage_sync import LlmageSync
|
||||
|
||||
__all__ = [
|
||||
'BaseSync',
|
||||
'run_all_syncs',
|
||||
'UserSync',
|
||||
'PricingSync',
|
||||
'UapiSync',
|
||||
'LlmageSync',
|
||||
]
|
||||
BIN
sageapi/sync/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
sageapi/sync/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/sync/__pycache__/base_sync.cpython-310.pyc
Normal file
BIN
sageapi/sync/__pycache__/base_sync.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/sync/__pycache__/llmage_sync.cpython-310.pyc
Normal file
BIN
sageapi/sync/__pycache__/llmage_sync.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/sync/__pycache__/pricing_sync.cpython-310.pyc
Normal file
BIN
sageapi/sync/__pycache__/pricing_sync.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/sync/__pycache__/uapi_sync.cpython-310.pyc
Normal file
BIN
sageapi/sync/__pycache__/uapi_sync.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/sync/__pycache__/user_sync.cpython-310.pyc
Normal file
BIN
sageapi/sync/__pycache__/user_sync.cpython-310.pyc
Normal file
Binary file not shown.
@ -1,125 +1,248 @@
|
||||
"""Base synchronization class for SageAPI.
|
||||
|
||||
Provides the foundation for all data sync workers. Handles common
|
||||
concerns: checkpoint management, retry logic, batch processing,
|
||||
and error reporting.
|
||||
"""
|
||||
BaseSync - Abstract base class for all Sage data sync modules.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
Provides:
|
||||
- Checkpoint management (read/write sync_state table)
|
||||
- Batch processing with configurable size
|
||||
- Retry logic with exponential backoff
|
||||
- Common sync flow orchestration
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from appPublic.log import debug, error, info
|
||||
from sqlor.dbpools import get_sor_context
|
||||
from ahserver.serverenv import ServerEnv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSync(ABC):
|
||||
"""Abstract base class for data synchronization workers.
|
||||
"""
|
||||
Abstract base class for incremental sync from Sage DB to local cache.
|
||||
|
||||
Each concrete sync subclass implements the data fetch and
|
||||
persist logic for a specific upstream data source.
|
||||
Subclasses must implement:
|
||||
- fetch_incremental(sor, since_timestamp): fetch delta from Sage DB
|
||||
- persist(sor, records): UPSERT records into local cache table
|
||||
- get_latest_timestamp(records): extract max updated_at from records
|
||||
"""
|
||||
|
||||
def __init__(self, sync_name: str, batch_size: int = 500) -> None:
|
||||
self.sync_name = sync_name
|
||||
self.batch_size = batch_size
|
||||
self._last_checkpoint: dict[str, Any] = {}
|
||||
# --- subclass overrides ---
|
||||
MODULE_NAME: str = "" # Sage module name (e.g. 'users', 'pricing')
|
||||
SOURCE_DBNAME: str = "sage" # db alias for Sage source DB
|
||||
CACHE_DBNAME: str = "sageapi" # db alias for local cache DB
|
||||
BATCH_SIZE: int = 500 # batch size for persist
|
||||
MAX_RETRIES: int = 3 # max retry attempts per batch
|
||||
RETRY_DELAY: float = 1.0 # initial retry delay in seconds
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Fetch incremental data from the upstream source.
|
||||
# sync_state table key
|
||||
STATE_KEY: str = ""
|
||||
|
||||
Args:
|
||||
since_timestamp: Only fetch records modified after this timestamp.
|
||||
None means full sync.
|
||||
def __init__(self, env: Optional[ServerEnv] = None):
|
||||
self.env = env or ServerEnv()
|
||||
|
||||
Returns:
|
||||
List of records to be persisted.
|
||||
# ------------------------------------------------------------------ #
|
||||
# Checkpoint helpers – sync_state table lives in the CACHE DB #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _read_checkpoint(self) -> Optional[str]:
|
||||
"""Read last sync timestamp from sync_state table."""
|
||||
async with get_sor_context(self.env, self.CACHE_DBNAME) as sor:
|
||||
recs = await sor.R('sync_state', {'state_key': self.STATE_KEY})
|
||||
if recs and len(recs) > 0:
|
||||
return recs[0].get('last_sync_ts')
|
||||
return None
|
||||
|
||||
async def _write_checkpoint(self, timestamp: str) -> None:
|
||||
"""Write new sync timestamp into sync_state table."""
|
||||
async with get_sor_context(self.env, self.CACHE_DBNAME) as sor:
|
||||
now_ts = str(int(time.time()))
|
||||
existing = await sor.R('sync_state', {'state_key': self.STATE_KEY})
|
||||
if existing and len(existing) > 0:
|
||||
await sor.U('sync_state', {
|
||||
'state_key': self.STATE_KEY,
|
||||
'last_sync_ts': timestamp,
|
||||
'updated_at': now_ts,
|
||||
})
|
||||
else:
|
||||
await sor.C('sync_state', {
|
||||
'state_key': self.STATE_KEY,
|
||||
'last_sync_ts': timestamp,
|
||||
'created_at': now_ts,
|
||||
'updated_at': now_ts,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Retry wrapper #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _with_retry(self, coro_func, *args, **kwargs) -> Any:
|
||||
"""Execute an async function with exponential-backoff retry."""
|
||||
last_exc = None
|
||||
delay = self.RETRY_DELAY
|
||||
for attempt in range(1, self.MAX_RETRIES + 1):
|
||||
try:
|
||||
return await coro_func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exc = e
|
||||
logger.warning(
|
||||
"[%s] attempt %d/%d failed: %s",
|
||||
self.__class__.__name__, attempt, self.MAX_RETRIES, e,
|
||||
)
|
||||
if attempt < self.MAX_RETRIES:
|
||||
await self._sleep(delay)
|
||||
delay *= 2
|
||||
raise last_exc
|
||||
|
||||
@staticmethod
|
||||
async def _sleep(seconds: float) -> None:
|
||||
import asyncio
|
||||
await asyncio.sleep(seconds)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Batch persist #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _persist_batch(self, sor, records: List[Dict]) -> int:
|
||||
"""Persist a single batch of records with retry."""
|
||||
async def _do():
|
||||
await self.persist(sor, records)
|
||||
await self._with_retry(_do)
|
||||
return len(records)
|
||||
|
||||
async def persist_in_batches(self, sor, records: List[Dict]) -> int:
|
||||
"""Split records into batches and persist each with retry."""
|
||||
total = 0
|
||||
for i in range(0, len(records), self.BATCH_SIZE):
|
||||
batch = records[i:i + self.BATCH_SIZE]
|
||||
cnt = await self._persist_batch(sor, batch)
|
||||
total += cnt
|
||||
logger.info(
|
||||
"[%s] persisted batch %d/%d (%d records)",
|
||||
self.__class__.__name__,
|
||||
i // self.BATCH_SIZE + 1,
|
||||
(len(records) + self.BATCH_SIZE - 1) // self.BATCH_SIZE,
|
||||
cnt,
|
||||
)
|
||||
return total
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Main sync flow #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def sync(self) -> Dict[str, Any]:
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def persist(self, records: list[dict[str, Any]]) -> int:
|
||||
"""Persist fetched records to the local database.
|
||||
|
||||
Args:
|
||||
records: List of records to upsert.
|
||||
|
||||
Returns:
|
||||
Number of records successfully persisted.
|
||||
Full incremental sync flow:
|
||||
1. Read checkpoint (last sync timestamp)
|
||||
2. Fetch incremental records from Sage DB
|
||||
3. Persist to local cache in batches
|
||||
4. Update checkpoint
|
||||
5. Return summary dict
|
||||
"""
|
||||
...
|
||||
cls_name = self.__class__.__name__
|
||||
logger.info("[%s] sync started", cls_name)
|
||||
|
||||
@abstractmethod
|
||||
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract the latest modification timestamp from a batch of records.
|
||||
# 1. Read checkpoint
|
||||
since_ts = await self._read_checkpoint()
|
||||
logger.info("[%s] checkpoint: %s", cls_name, since_ts or "None (full sync)")
|
||||
|
||||
Used to advance the sync checkpoint after successful persist.
|
||||
"""
|
||||
...
|
||||
# 2. Fetch incremental from Sage source DB
|
||||
async with get_sor_context(self.env, self.SOURCE_DBNAME) as sor:
|
||||
records = await self.fetch_incremental(sor, since_ts)
|
||||
|
||||
async def _load_checkpoint(self) -> str | None:
|
||||
"""Load the last successful sync checkpoint timestamp.
|
||||
|
||||
TODO: Implement checkpoint persistence (sync_state table).
|
||||
"""
|
||||
checkpoint = self._last_checkpoint.get(self.sync_name)
|
||||
debug(f'Sync {self.sync_name}: loaded checkpoint = {checkpoint}')
|
||||
return checkpoint
|
||||
|
||||
async def _save_checkpoint(self, timestamp: str) -> None:
|
||||
"""Save the sync checkpoint after a successful run.
|
||||
|
||||
TODO: Implement checkpoint persistence (sync_state table).
|
||||
"""
|
||||
self._last_checkpoint[self.sync_name] = timestamp
|
||||
debug(f'Sync {self.sync_name}: saved checkpoint = {timestamp}')
|
||||
|
||||
async def run(self) -> dict[str, Any]:
|
||||
"""Execute a full sync cycle.
|
||||
|
||||
Returns:
|
||||
dict with keys: success, records_fetched, records_persisted,
|
||||
error (if any), duration_seconds
|
||||
"""
|
||||
start = time.time()
|
||||
result: dict[str, Any] = {
|
||||
'sync_name': self.sync_name,
|
||||
'success': False,
|
||||
'records_fetched': 0,
|
||||
'records_persisted': 0,
|
||||
'error': None,
|
||||
'duration_seconds': 0.0,
|
||||
if not records:
|
||||
logger.info("[%s] no new records, sync done", cls_name)
|
||||
return {
|
||||
'module': self.MODULE_NAME,
|
||||
'fetched': 0,
|
||||
'persisted': 0,
|
||||
'new_checkpoint': since_ts,
|
||||
}
|
||||
|
||||
try:
|
||||
checkpoint = await self._load_checkpoint()
|
||||
info(f'Sync {self.sync_name}: starting (checkpoint={checkpoint})')
|
||||
|
||||
records = await self.fetch_incremental(since_timestamp=checkpoint)
|
||||
result['records_fetched'] = len(records)
|
||||
|
||||
if records:
|
||||
persisted = await self.persist(records)
|
||||
result['records_persisted'] = persisted
|
||||
|
||||
latest_ts = self.get_latest_timestamp(records)
|
||||
if latest_ts:
|
||||
await self._save_checkpoint(latest_ts)
|
||||
|
||||
result['success'] = True
|
||||
info(
|
||||
f'Sync {self.sync_name}: completed — '
|
||||
f'fetched={result["records_fetched"]}, '
|
||||
f'persisted={result["records_persisted"]}'
|
||||
# 3. Extract latest timestamp
|
||||
new_checkpoint = self.get_latest_timestamp(records)
|
||||
logger.info(
|
||||
"[%s] fetched %d records, latest_ts=%s",
|
||||
cls_name, len(records), new_checkpoint,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error(f'Sync {self.sync_name}: failed with error: {e}')
|
||||
result['error'] = str(e)
|
||||
# 4. Persist to cache DB
|
||||
async with get_sor_context(self.env, self.CACHE_DBNAME) as cache_sor:
|
||||
persisted = await self.persist_in_batches(cache_sor, records)
|
||||
|
||||
finally:
|
||||
result['duration_seconds'] = round(time.time() - start, 3)
|
||||
# 5. Update checkpoint
|
||||
if new_checkpoint:
|
||||
await self._write_checkpoint(new_checkpoint)
|
||||
|
||||
result = {
|
||||
'module': self.MODULE_NAME,
|
||||
'fetched': len(records),
|
||||
'persisted': persisted,
|
||||
'new_checkpoint': new_checkpoint,
|
||||
}
|
||||
logger.info("[%s] sync completed: %s", cls_name, result)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Abstract methods – subclasses MUST implement #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
|
||||
"""
|
||||
Fetch incremental records from Sage DB since the given timestamp.
|
||||
|
||||
Args:
|
||||
sor: sqlor context for the SOURCE DB
|
||||
since_timestamp: ISO or unix timestamp string, or None for full sync
|
||||
|
||||
Returns:
|
||||
List of record dicts
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def persist(self, sor, records: List[Dict]) -> None:
|
||||
"""
|
||||
UPSERT records into the local cache table.
|
||||
|
||||
Args:
|
||||
sor: sqlor context for the CACHE DB
|
||||
records: list of dicts to upsert
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
|
||||
"""
|
||||
Extract the maximum updated/modified timestamp from records.
|
||||
|
||||
Args:
|
||||
records: list of record dicts
|
||||
|
||||
Returns:
|
||||
Timestamp string or None if no records
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
async def run_all_syncs() -> str:
|
||||
"""Trigger sync for all entity types. Called by admin API or cron."""
|
||||
import json
|
||||
from sageapi.sync.user_sync import UserSync
|
||||
from sageapi.sync.pricing_sync import PricingSync
|
||||
from sageapi.sync.uapi_sync import UapiSync
|
||||
from sageapi.sync.llmage_sync import LlmageSync
|
||||
|
||||
results = []
|
||||
for sync_cls in [UserSync, PricingSync, UapiSync, LlmageSync]:
|
||||
try:
|
||||
syncer = sync_cls()
|
||||
result = await syncer.sync()
|
||||
results.append({'module': sync_cls.MODULE_NAME, 'status': 'ok', **result})
|
||||
except Exception as e:
|
||||
logger.exception("[%s] sync failed", sync_cls.__name__)
|
||||
results.append({'module': sync_cls.MODULE_NAME, 'status': 'error', 'error': str(e)})
|
||||
|
||||
return json.dumps({'success': True, 'results': results}, ensure_ascii=False, default=str)
|
||||
|
||||
@ -1,65 +1,142 @@
|
||||
"""LLM image data synchronization for SageAPI.
|
||||
|
||||
Syncs LLM catalog and provider data from the upstream Sage
|
||||
llmage module into the local llmage_cache table.
|
||||
"""
|
||||
LlmageSync - Sync llm / llmcatelog / llm_api_map from Sage DB to llmage_cache.
|
||||
|
||||
from __future__ import annotations
|
||||
Source tables (Sage DB):
|
||||
- llm
|
||||
- llmcatelog
|
||||
- llm_api_map
|
||||
|
||||
from typing import Any
|
||||
Target table (cache DB):
|
||||
- llmage_cache
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from appPublic.log import debug, info
|
||||
from .base_sync import BaseSync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlmageSync(BaseSync):
|
||||
"""Incremental sync for llmage data from Sage upstream."""
|
||||
MODULE_NAME = "llmage"
|
||||
SOURCE_DBNAME = "sage"
|
||||
CACHE_DBNAME = "sageapi"
|
||||
STATE_KEY = "sync_llmage"
|
||||
BATCH_SIZE = 500
|
||||
|
||||
def __init__(self, batch_size: int = 500) -> None:
|
||||
super().__init__(sync_name='llmage', batch_size=batch_size)
|
||||
|
||||
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Fetch llmage data updated since the last sync checkpoint.
|
||||
|
||||
TODO: Implement upstream API call to Sage /api/llmage endpoint.
|
||||
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
|
||||
"""
|
||||
debug(f'LlmageSync: fetching incremental data since {since_timestamp}')
|
||||
# Placeholder: call upstream Sage API
|
||||
return []
|
||||
|
||||
async def persist(self, records: list[dict[str, Any]]) -> int:
|
||||
"""Upsert llmage records into llmage_cache table.
|
||||
|
||||
TODO: Implement database upsert logic.
|
||||
Fetch incremental data from llm, llmcatelog, and llm_api_map tables.
|
||||
Joins LLM model info with catalog and API mapping data.
|
||||
"""
|
||||
if not records:
|
||||
return 0
|
||||
info(f'LlmageSync: persisting {len(records)} llmage records')
|
||||
# Placeholder: upsert into llmage_cache
|
||||
return len(records)
|
||||
if since_timestamp:
|
||||
where_clause = f"WHERE l.updated_at > '{since_timestamp}' OR lc.updated_at > '{since_timestamp}' OR lam.updated_at > '{since_timestamp}'"
|
||||
else:
|
||||
where_clause = ""
|
||||
|
||||
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract the maximum updated_at from the record batch."""
|
||||
sql = f"""
|
||||
SELECT
|
||||
l.id AS llm_id,
|
||||
l.model_name,
|
||||
l.model_version,
|
||||
l.provider,
|
||||
l.model_type,
|
||||
l.status AS llm_status,
|
||||
l.description AS llm_description,
|
||||
l.created_at AS llm_created_at,
|
||||
l.updated_at AS llm_updated_at,
|
||||
lc.id AS catelog_id,
|
||||
lc.catelog_name,
|
||||
lc.catelog_code,
|
||||
lc.sort_order AS catelog_sort,
|
||||
lc.status AS catelog_status,
|
||||
lc.updated_at AS catelog_updated_at,
|
||||
lam.id AS api_map_id,
|
||||
lam.api_name,
|
||||
lam.api_endpoint,
|
||||
lam.api_version,
|
||||
lam.auth_type,
|
||||
lam.rate_limit,
|
||||
lam.status AS api_map_status,
|
||||
lam.updated_at AS api_map_updated_at
|
||||
FROM {sor.dbname}.llm l
|
||||
LEFT JOIN {sor.dbname}.llmcatelog lc ON l.catelog_id = lc.id
|
||||
LEFT JOIN {sor.dbname}.llm_api_map lam ON l.id = lam.llm_id
|
||||
{where_clause}
|
||||
ORDER BY COALESCE(l.updated_at, l.created_at) ASC
|
||||
"""
|
||||
|
||||
records = await sor.sqlExe(sql, {})
|
||||
return [dict(r) for r in records] if records else []
|
||||
|
||||
async def persist(self, sor, records: List[Dict]) -> None:
|
||||
"""
|
||||
UPSERT into llmage_cache using INSERT ... ON DUPLICATE KEY UPDATE.
|
||||
The composite key is (llm_id, catelog_id, api_map_id).
|
||||
"""
|
||||
import time
|
||||
synced_at = str(int(time.time()))
|
||||
|
||||
for rec in records:
|
||||
rec['synced_at'] = synced_at
|
||||
insert_sql = """
|
||||
INSERT INTO llmage_cache (
|
||||
llm_id, model_name, model_version, provider, model_type,
|
||||
llm_status, llm_description, llm_created_at, llm_updated_at,
|
||||
catelog_id, catelog_name, catelog_code, catelog_sort,
|
||||
catelog_status, catelog_updated_at,
|
||||
api_map_id, api_name, api_endpoint, api_version,
|
||||
auth_type, rate_limit, api_map_status, api_map_updated_at,
|
||||
synced_at
|
||||
) VALUES (
|
||||
${llm_id}$, ${model_name}$, ${model_version}$, ${provider}$, ${model_type}$,
|
||||
${llm_status}$, ${llm_description}$, ${llm_created_at}$, ${llm_updated_at}$,
|
||||
${catelog_id}$, ${catelog_name}$, ${catelog_code}$, ${catelog_sort}$,
|
||||
${catelog_status}$, ${catelog_updated_at}$,
|
||||
${api_map_id}$, ${api_name}$, ${api_endpoint}$, ${api_version}$,
|
||||
${auth_type}$, ${rate_limit}$, ${api_map_status}$, ${api_map_updated_at}$,
|
||||
${synced_at}$
|
||||
)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
model_name = VALUES(model_name),
|
||||
model_version = VALUES(model_version),
|
||||
provider = VALUES(provider),
|
||||
model_type = VALUES(model_type),
|
||||
llm_status = VALUES(llm_status),
|
||||
llm_description = VALUES(llm_description),
|
||||
llm_created_at = VALUES(llm_created_at),
|
||||
llm_updated_at = VALUES(llm_updated_at),
|
||||
catelog_name = VALUES(catelog_name),
|
||||
catelog_code = VALUES(catelog_code),
|
||||
catelog_sort = VALUES(catelog_sort),
|
||||
catelog_status = VALUES(catelog_status),
|
||||
catelog_updated_at= VALUES(catelog_updated_at),
|
||||
api_name = VALUES(api_name),
|
||||
api_endpoint = VALUES(api_endpoint),
|
||||
api_version = VALUES(api_version),
|
||||
auth_type = VALUES(auth_type),
|
||||
rate_limit = VALUES(rate_limit),
|
||||
api_map_status = VALUES(api_map_status),
|
||||
api_map_updated_at= VALUES(api_map_updated_at),
|
||||
synced_at = VALUES(synced_at)
|
||||
"""
|
||||
try:
|
||||
await sor.execute(insert_sql, rec)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[%s] persist failed for llm_id=%s: %s",
|
||||
self.__class__.__name__, rec.get('llm_id'), e,
|
||||
)
|
||||
raise
|
||||
|
||||
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
|
||||
"""Extract the maximum updated_at from llm, catelog, or api_map records."""
|
||||
if not records:
|
||||
return None
|
||||
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
|
||||
return max(timestamps) if timestamps else None
|
||||
|
||||
|
||||
_llmage_sync_instance: LlmageSync | None = None
|
||||
|
||||
|
||||
def get_llmage_sync() -> LlmageSync:
|
||||
"""Get or create the LlmageSync singleton."""
|
||||
global _llmage_sync_instance
|
||||
if _llmage_sync_instance is None:
|
||||
_llmage_sync_instance = LlmageSync()
|
||||
return _llmage_sync_instance
|
||||
|
||||
|
||||
async def sync_llmage(since_timestamp: str | None = None) -> dict[str, Any]:
|
||||
"""Run a llmage data sync cycle."""
|
||||
syncer = get_llmage_sync()
|
||||
if since_timestamp:
|
||||
await syncer._save_checkpoint(since_timestamp)
|
||||
return await syncer.run()
|
||||
latest = None
|
||||
for r in records:
|
||||
for key in ('llm_updated_at', 'catelog_updated_at', 'api_map_updated_at'):
|
||||
ts = r.get(key)
|
||||
if ts and (latest is None or str(ts) > str(latest)):
|
||||
latest = str(ts)
|
||||
return latest
|
||||
|
||||
@ -1,65 +1,123 @@
|
||||
"""Pricing data synchronization for SageAPI.
|
||||
|
||||
Syncs pricing program and timing data from the upstream Sage
|
||||
pricing module into the local pricing_cache table.
|
||||
"""
|
||||
PricingSync - Sync pricing_program / pricing_program_timing from Sage DB to pricing_cache.
|
||||
|
||||
from __future__ import annotations
|
||||
Source tables (Sage DB):
|
||||
- pricing_program
|
||||
- pricing_program_timing
|
||||
|
||||
from typing import Any
|
||||
Target table (cache DB):
|
||||
- pricing_cache
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from appPublic.log import debug, info
|
||||
from .base_sync import BaseSync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PricingSync(BaseSync):
|
||||
"""Incremental sync for pricing data from Sage upstream."""
|
||||
MODULE_NAME = "pricing"
|
||||
SOURCE_DBNAME = "sage"
|
||||
CACHE_DBNAME = "sageapi"
|
||||
STATE_KEY = "sync_pricing"
|
||||
BATCH_SIZE = 500
|
||||
|
||||
def __init__(self, batch_size: int = 500) -> None:
|
||||
super().__init__(sync_name='pricing', batch_size=batch_size)
|
||||
|
||||
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Fetch pricing data updated since the last sync checkpoint.
|
||||
|
||||
TODO: Implement upstream API call to Sage /api/pricing endpoint.
|
||||
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
|
||||
"""
|
||||
debug(f'PricingSync: fetching incremental data since {since_timestamp}')
|
||||
# Placeholder: call upstream Sage API
|
||||
return []
|
||||
|
||||
async def persist(self, records: list[dict[str, Any]]) -> int:
|
||||
"""Upsert pricing records into pricing_cache table.
|
||||
|
||||
TODO: Implement database upsert logic.
|
||||
Fetch incremental data from pricing_program and pricing_program_timing.
|
||||
Joins program info with timing/schedule data.
|
||||
"""
|
||||
if not records:
|
||||
return 0
|
||||
info(f'PricingSync: persisting {len(records)} pricing records')
|
||||
# Placeholder: upsert into pricing_cache
|
||||
return len(records)
|
||||
if since_timestamp:
|
||||
where_clause = f"WHERE pp.updated_at > '{since_timestamp}' OR ppt.updated_at > '{since_timestamp}'"
|
||||
else:
|
||||
where_clause = ""
|
||||
|
||||
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract the maximum updated_at from the record batch."""
|
||||
sql = f"""
|
||||
SELECT
|
||||
pp.id AS program_id,
|
||||
pp.program_name,
|
||||
pp.program_code,
|
||||
pp.program_type,
|
||||
pp.status AS program_status,
|
||||
pp.description,
|
||||
pp.created_at AS program_created_at,
|
||||
pp.updated_at AS program_updated_at,
|
||||
ppt.id AS timing_id,
|
||||
ppt.start_time,
|
||||
ppt.end_time,
|
||||
ppt.duration,
|
||||
ppt.repeat_rule,
|
||||
ppt.timezone,
|
||||
ppt.status AS timing_status,
|
||||
ppt.updated_at AS timing_updated_at
|
||||
FROM {sor.dbname}.pricing_program pp
|
||||
LEFT JOIN {sor.dbname}.pricing_program_timing ppt
|
||||
ON pp.id = ppt.program_id
|
||||
{where_clause}
|
||||
ORDER BY COALESCE(pp.updated_at, pp.created_at) ASC
|
||||
"""
|
||||
|
||||
records = await sor.sqlExe(sql, {})
|
||||
return [dict(r) for r in records] if records else []
|
||||
|
||||
async def persist(self, sor, records: List[Dict]) -> None:
|
||||
"""
|
||||
UPSERT into pricing_cache using INSERT ... ON DUPLICATE KEY UPDATE.
|
||||
The composite key is (program_id, timing_id).
|
||||
"""
|
||||
import time
|
||||
synced_at = str(int(time.time()))
|
||||
|
||||
for rec in records:
|
||||
rec['synced_at'] = synced_at
|
||||
insert_sql = """
|
||||
INSERT INTO pricing_cache (
|
||||
program_id, program_name, program_code, program_type,
|
||||
program_status, description, program_created_at, program_updated_at,
|
||||
timing_id, start_time, end_time, duration,
|
||||
repeat_rule, timezone, timing_status, timing_updated_at,
|
||||
synced_at
|
||||
) VALUES (
|
||||
${program_id}$, ${program_name}$, ${program_code}$, ${program_type}$,
|
||||
${program_status}$, ${description}$, ${program_created_at}$, ${program_updated_at}$,
|
||||
${timing_id}$, ${start_time}$, ${end_time}$, ${duration}$,
|
||||
${repeat_rule}$, ${timezone}$, ${timing_status}$, ${timing_updated_at}$,
|
||||
${synced_at}$
|
||||
)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
program_name = VALUES(program_name),
|
||||
program_code = VALUES(program_code),
|
||||
program_type = VALUES(program_type),
|
||||
program_status = VALUES(program_status),
|
||||
description = VALUES(description),
|
||||
program_created_at = VALUES(program_created_at),
|
||||
program_updated_at = VALUES(program_updated_at),
|
||||
start_time = VALUES(start_time),
|
||||
end_time = VALUES(end_time),
|
||||
duration = VALUES(duration),
|
||||
repeat_rule = VALUES(repeat_rule),
|
||||
timezone = VALUES(timezone),
|
||||
timing_status = VALUES(timing_status),
|
||||
timing_updated_at = VALUES(timing_updated_at),
|
||||
synced_at = VALUES(synced_at)
|
||||
"""
|
||||
try:
|
||||
await sor.execute(insert_sql, rec)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[%s] persist failed for program_id=%s: %s",
|
||||
self.__class__.__name__, rec.get('program_id'), e,
|
||||
)
|
||||
raise
|
||||
|
||||
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
|
||||
"""Extract the maximum updated_at from program or timing records."""
|
||||
if not records:
|
||||
return None
|
||||
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
|
||||
return max(timestamps) if timestamps else None
|
||||
|
||||
|
||||
_pricing_sync_instance: PricingSync | None = None
|
||||
|
||||
|
||||
def get_pricing_sync() -> PricingSync:
|
||||
"""Get or create the PricingSync singleton."""
|
||||
global _pricing_sync_instance
|
||||
if _pricing_sync_instance is None:
|
||||
_pricing_sync_instance = PricingSync()
|
||||
return _pricing_sync_instance
|
||||
|
||||
|
||||
async def sync_pricing(since_timestamp: str | None = None) -> dict[str, Any]:
|
||||
"""Run a pricing data sync cycle."""
|
||||
syncer = get_pricing_sync()
|
||||
if since_timestamp:
|
||||
await syncer._save_checkpoint(since_timestamp)
|
||||
return await syncer.run()
|
||||
latest = None
|
||||
for r in records:
|
||||
ts = r.get('program_updated_at') or r.get('timing_updated_at')
|
||||
if ts and (latest is None or str(ts) > str(latest)):
|
||||
latest = str(ts)
|
||||
return latest
|
||||
|
||||
@ -1,65 +1,129 @@
|
||||
"""UAPI data synchronization for SageAPI.
|
||||
|
||||
Syncs uapi application and caller configuration from the upstream
|
||||
Sage uapi module into the local uapi_cache table.
|
||||
"""
|
||||
UAPISync - Sync uapi / upapp from Sage DB to uapi_cache.
|
||||
|
||||
from __future__ import annotations
|
||||
Source tables (Sage DB):
|
||||
- uapi
|
||||
- upapp
|
||||
|
||||
from typing import Any
|
||||
Target table (cache DB):
|
||||
- uapi_cache
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from appPublic.log import debug, info
|
||||
from .base_sync import BaseSync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UAPISync(BaseSync):
|
||||
"""Incremental sync for uapi data from Sage upstream."""
|
||||
MODULE_NAME = "uapi"
|
||||
SOURCE_DBNAME = "sage"
|
||||
CACHE_DBNAME = "sageapi"
|
||||
STATE_KEY = "sync_uapi"
|
||||
BATCH_SIZE = 500
|
||||
|
||||
def __init__(self, batch_size: int = 500) -> None:
|
||||
super().__init__(sync_name='uapi', batch_size=batch_size)
|
||||
|
||||
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Fetch uapi data updated since the last sync checkpoint.
|
||||
|
||||
TODO: Implement upstream API call to Sage /api/uapi endpoint.
|
||||
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
|
||||
"""
|
||||
debug(f'UAPISync: fetching incremental data since {since_timestamp}')
|
||||
# Placeholder: call upstream Sage API
|
||||
return []
|
||||
|
||||
async def persist(self, records: list[dict[str, Any]]) -> int:
|
||||
"""Upsert uapi records into uapi_cache table.
|
||||
|
||||
TODO: Implement database upsert logic.
|
||||
Fetch incremental data from uapi and upapp tables.
|
||||
Joins API definitions with app registration data.
|
||||
"""
|
||||
if not records:
|
||||
return 0
|
||||
info(f'UAPISync: persisting {len(records)} uapi records')
|
||||
# Placeholder: upsert into uapi_cache
|
||||
return len(records)
|
||||
if since_timestamp:
|
||||
where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR up.updated_at > '{since_timestamp}'"
|
||||
else:
|
||||
where_clause = ""
|
||||
|
||||
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract the maximum updated_at from the record batch."""
|
||||
sql = f"""
|
||||
SELECT
|
||||
u.id AS uapi_id,
|
||||
u.api_name,
|
||||
u.api_path,
|
||||
u.api_method,
|
||||
u.api_version,
|
||||
u.api_desc,
|
||||
u.status AS uapi_status,
|
||||
u.auth_required,
|
||||
u.created_at AS uapi_created_at,
|
||||
u.updated_at AS uapi_updated_at,
|
||||
up.id AS upapp_id,
|
||||
up.app_name,
|
||||
up.app_code,
|
||||
up.app_type,
|
||||
up.app_desc,
|
||||
up.app_owner,
|
||||
up.status AS upapp_status,
|
||||
up.updated_at AS upapp_updated_at
|
||||
FROM {sor.dbname}.uapi u
|
||||
LEFT JOIN {sor.dbname}.upapp up ON u.upapp_id = up.id
|
||||
{where_clause}
|
||||
ORDER BY COALESCE(u.updated_at, u.created_at) ASC
|
||||
"""
|
||||
|
||||
records = await sor.sqlExe(sql, {})
|
||||
return [dict(r) for r in records] if records else []
|
||||
|
||||
async def persist(self, sor, records: List[Dict]) -> None:
|
||||
"""
|
||||
UPSERT into uapi_cache using INSERT ... ON DUPLICATE KEY UPDATE.
|
||||
The composite key is (uapi_id, upapp_id).
|
||||
"""
|
||||
import time
|
||||
synced_at = str(int(time.time()))
|
||||
|
||||
for rec in records:
|
||||
rec['synced_at'] = synced_at
|
||||
insert_sql = """
|
||||
INSERT INTO uapi_cache (
|
||||
uapi_id, api_name, api_path, api_method, api_version,
|
||||
api_desc, uapi_status, auth_required,
|
||||
uapi_created_at, uapi_updated_at,
|
||||
upapp_id, app_name, app_code, app_type,
|
||||
app_desc, app_owner, upapp_status, upapp_updated_at,
|
||||
synced_at
|
||||
) VALUES (
|
||||
${uapi_id}$, ${api_name}$, ${api_path}$, ${api_method}$, ${api_version}$,
|
||||
${api_desc}$, ${uapi_status}$, ${auth_required}$,
|
||||
${uapi_created_at}$, ${uapi_updated_at}$,
|
||||
${upapp_id}$, ${app_name}$, ${app_code}$, ${app_type}$,
|
||||
${app_desc}$, ${app_owner}$, ${upapp_status}$, ${upapp_updated_at}$,
|
||||
${synced_at}$
|
||||
)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
api_name = VALUES(api_name),
|
||||
api_path = VALUES(api_path),
|
||||
api_method = VALUES(api_method),
|
||||
api_version = VALUES(api_version),
|
||||
api_desc = VALUES(api_desc),
|
||||
uapi_status = VALUES(uapi_status),
|
||||
auth_required = VALUES(auth_required),
|
||||
uapi_created_at = VALUES(uapi_created_at),
|
||||
uapi_updated_at = VALUES(uapi_updated_at),
|
||||
app_name = VALUES(app_name),
|
||||
app_code = VALUES(app_code),
|
||||
app_type = VALUES(app_type),
|
||||
app_desc = VALUES(app_desc),
|
||||
app_owner = VALUES(app_owner),
|
||||
upapp_status = VALUES(upapp_status),
|
||||
upapp_updated_at = VALUES(upapp_updated_at),
|
||||
synced_at = VALUES(synced_at)
|
||||
"""
|
||||
try:
|
||||
await sor.execute(insert_sql, rec)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[%s] persist failed for uapi_id=%s: %s",
|
||||
self.__class__.__name__, rec.get('uapi_id'), e,
|
||||
)
|
||||
raise
|
||||
|
||||
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
|
||||
"""Extract the maximum updated_at from uapi or upapp records."""
|
||||
if not records:
|
||||
return None
|
||||
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
|
||||
return max(timestamps) if timestamps else None
|
||||
|
||||
|
||||
_uapi_sync_instance: UAPISync | None = None
|
||||
|
||||
|
||||
def get_uapi_sync() -> UAPISync:
|
||||
"""Get or create the UAPISync singleton."""
|
||||
global _uapi_sync_instance
|
||||
if _uapi_sync_instance is None:
|
||||
_uapi_sync_instance = UAPISync()
|
||||
return _uapi_sync_instance
|
||||
|
||||
|
||||
async def sync_uapi(since_timestamp: str | None = None) -> dict[str, Any]:
|
||||
"""Run a uapi data sync cycle."""
|
||||
syncer = get_uapi_sync()
|
||||
if since_timestamp:
|
||||
await syncer._save_checkpoint(since_timestamp)
|
||||
return await syncer.run()
|
||||
latest = None
|
||||
for r in records:
|
||||
for key in ('uapi_updated_at', 'upapp_updated_at'):
|
||||
ts = r.get(key)
|
||||
if ts and (latest is None or str(ts) > str(latest)):
|
||||
latest = str(ts)
|
||||
return latest
|
||||
|
||||
@ -1,76 +1,143 @@
|
||||
"""User data synchronization for SageAPI.
|
||||
|
||||
Syncs user data from the upstream Sage system into the local
|
||||
users_cache table. Uses incremental sync based on updated_at
|
||||
timestamp to minimize data transfer.
|
||||
"""
|
||||
UserSync - Sync users / organi / organization from Sage DB to users_cache.
|
||||
|
||||
from __future__ import annotations
|
||||
Source tables (Sage DB):
|
||||
- users
|
||||
- organi
|
||||
- organization
|
||||
|
||||
from typing import Any
|
||||
Target table (cache DB):
|
||||
- users_cache
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from appPublic.log import debug, info
|
||||
from .base_sync import BaseSync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserSync(BaseSync):
|
||||
"""Incremental sync for user data from Sage upstream."""
|
||||
MODULE_NAME = "users"
|
||||
SOURCE_DBNAME = "sage"
|
||||
CACHE_DBNAME = "sageapi"
|
||||
STATE_KEY = "sync_users"
|
||||
BATCH_SIZE = 500
|
||||
|
||||
def __init__(self, batch_size: int = 500) -> None:
|
||||
super().__init__(sync_name='users', batch_size=batch_size)
|
||||
|
||||
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Fetch users updated since the last sync checkpoint.
|
||||
|
||||
TODO: Implement upstream API call to Sage /api/users endpoint.
|
||||
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
|
||||
"""
|
||||
debug(f'UserSync: fetching incremental data since {since_timestamp}')
|
||||
# Placeholder: call upstream Sage API
|
||||
# GET /api/users?updated_at_gt={since_timestamp}&limit={batch_size}
|
||||
return []
|
||||
|
||||
async def persist(self, records: list[dict[str, Any]]) -> int:
|
||||
"""Upsert user records into users_cache table.
|
||||
|
||||
TODO: Implement database upsert logic.
|
||||
Fetch incremental data from users, organi, organization tables.
|
||||
Uses LEFT JOIN to combine user data with organization info.
|
||||
"""
|
||||
if not records:
|
||||
return 0
|
||||
info(f'UserSync: persisting {len(records)} user records')
|
||||
# Placeholder: upsert into users_cache
|
||||
return len(records)
|
||||
if since_timestamp:
|
||||
where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR o.updated_at > '{since_timestamp}'"
|
||||
else:
|
||||
where_clause = ""
|
||||
|
||||
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract the maximum updated_at from the record batch."""
|
||||
sql = f"""
|
||||
SELECT
|
||||
u.id AS user_id,
|
||||
u.username,
|
||||
u.email,
|
||||
u.status AS user_status,
|
||||
u.created_at AS user_created_at,
|
||||
u.updated_at AS user_updated_at,
|
||||
oi.organi_id AS organi_id,
|
||||
oi.organi_name AS organi_name,
|
||||
oi.parent_id AS organi_parent_id,
|
||||
org.id AS org_id,
|
||||
org.org_name AS org_name,
|
||||
org.org_type AS org_type,
|
||||
org.status AS org_status,
|
||||
org.updated_at AS org_updated_at
|
||||
FROM {sor.dbname}.users u
|
||||
LEFT JOIN {sor.dbname}.organi oi ON u.organi_id = oi.id
|
||||
LEFT JOIN {sor.dbname}.organization org ON oi.organization_id = org.id
|
||||
{where_clause}
|
||||
ORDER BY COALESCE(u.updated_at, u.created_at) ASC
|
||||
"""
|
||||
|
||||
records = await sor.sqlExe(sql, {})
|
||||
return [dict(r) for r in records] if records else []
|
||||
|
||||
async def persist(self, sor, records: List[Dict]) -> None:
|
||||
"""
|
||||
UPSERT into users_cache.
|
||||
For each record: try UPDATE first; if no rows affected, INSERT.
|
||||
"""
|
||||
for rec in records:
|
||||
# Try update first
|
||||
update_sql = """
|
||||
UPDATE users_cache SET
|
||||
username = ${username}$,
|
||||
email = ${email}$,
|
||||
user_status = ${user_status}$,
|
||||
user_created_at = ${user_created_at}$,
|
||||
user_updated_at = ${user_updated_at}$,
|
||||
organi_id = ${organi_id}$,
|
||||
organi_name = ${organi_name}$,
|
||||
organi_parent_id= ${organi_parent_id}$,
|
||||
org_id = ${org_id}$,
|
||||
org_name = ${org_name}$,
|
||||
org_type = ${org_type}$,
|
||||
org_status = ${org_status}$,
|
||||
org_updated_at = ${org_updated_at}$,
|
||||
synced_at = ${synced_at}$
|
||||
WHERE user_id = ${user_id}$
|
||||
"""
|
||||
rec['synced_at'] = str(int(__import__('time').time()))
|
||||
await sor.execute(update_sql, rec)
|
||||
|
||||
# If no rows updated (sqlor returns affected count), insert
|
||||
# sqlor's execute returns the cursor; we check rowcount via sqlExe
|
||||
# A simpler approach: use INSERT ... ON DUPLICATE KEY UPDATE
|
||||
insert_sql = """
|
||||
INSERT INTO users_cache (
|
||||
user_id, username, email, user_status,
|
||||
user_created_at, user_updated_at,
|
||||
organi_id, organi_name, organi_parent_id,
|
||||
org_id, org_name, org_type, org_status,
|
||||
org_updated_at, synced_at
|
||||
) VALUES (
|
||||
${user_id}$, ${username}$, ${email}$, ${user_status}$,
|
||||
${user_created_at}$, ${user_updated_at}$,
|
||||
${organi_id}$, ${organi_name}$, ${organi_parent_id}$,
|
||||
${org_id}$, ${org_name}$, ${org_type}$, ${org_status}$,
|
||||
${org_updated_at}$, ${synced_at}$
|
||||
)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
username = VALUES(username),
|
||||
email = VALUES(email),
|
||||
user_status = VALUES(user_status),
|
||||
user_created_at = VALUES(user_created_at),
|
||||
user_updated_at = VALUES(user_updated_at),
|
||||
organi_id = VALUES(organi_id),
|
||||
organi_name = VALUES(organi_name),
|
||||
organi_parent_id= VALUES(organi_parent_id),
|
||||
org_id = VALUES(org_id),
|
||||
org_name = VALUES(org_name),
|
||||
org_type = VALUES(org_type),
|
||||
org_status = VALUES(org_status),
|
||||
org_updated_at = VALUES(org_updated_at),
|
||||
synced_at = VALUES(synced_at)
|
||||
"""
|
||||
# Execute INSERT with ON DUPLICATE KEY UPDATE as safety net
|
||||
# The UPDATE above handles most cases; this catches any race conditions
|
||||
# Skip if the user_id is None
|
||||
if rec.get('user_id') is not None:
|
||||
try:
|
||||
await sor.execute(insert_sql, rec)
|
||||
except Exception:
|
||||
# Duplicate key is expected if UPDATE already succeeded
|
||||
pass
|
||||
|
||||
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
|
||||
"""Extract the maximum updated_at from all records."""
|
||||
if not records:
|
||||
return None
|
||||
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
|
||||
return max(timestamps) if timestamps else None
|
||||
|
||||
|
||||
# Module-level singleton instance
|
||||
_user_sync_instance: UserSync | None = None
|
||||
|
||||
|
||||
def get_user_sync() -> UserSync:
|
||||
"""Get or create the UserSync singleton."""
|
||||
global _user_sync_instance
|
||||
if _user_sync_instance is None:
|
||||
_user_sync_instance = UserSync()
|
||||
return _user_sync_instance
|
||||
|
||||
|
||||
async def sync_users(since_timestamp: str | None = None) -> dict[str, Any]:
|
||||
"""Run a user data sync cycle.
|
||||
|
||||
Args:
|
||||
since_timestamp: Optional override for the checkpoint timestamp.
|
||||
|
||||
Returns:
|
||||
Sync result dict from BaseSync.run().
|
||||
"""
|
||||
syncer = get_user_sync()
|
||||
if since_timestamp:
|
||||
# Override checkpoint for this run
|
||||
await syncer._save_checkpoint(since_timestamp)
|
||||
return await syncer.run()
|
||||
latest = None
|
||||
for r in records:
|
||||
ts = r.get('user_updated_at') or r.get('org_updated_at')
|
||||
if ts and (latest is None or str(ts) > str(latest)):
|
||||
latest = str(ts)
|
||||
return latest
|
||||
|
||||
BIN
sageapi/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
sageapi/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/utils/__pycache__/crypto.cpython-310.pyc
Normal file
BIN
sageapi/utils/__pycache__/crypto.cpython-310.pyc
Normal file
Binary file not shown.
BIN
sageapi/utils/__pycache__/http_client.cpython-310.pyc
Normal file
BIN
sageapi/utils/__pycache__/http_client.cpython-310.pyc
Normal file
Binary file not shown.
@ -1,115 +1,460 @@
|
||||
"""HTTP client for upstream Sage API calls.
|
||||
"""
|
||||
SageHttpClient - Async HTTP client using aiohttp with DAPI signature support.
|
||||
|
||||
Provides a reusable async HTTP client with connection pooling,
|
||||
retry logic, and automatic DAPI/UAPI authentication headers.
|
||||
Used to call external APIs (e.g., LLM provider APIs). Not for calling Sage local interfaces.
|
||||
|
||||
Features:
|
||||
- Automatic DAPI signature header injection
|
||||
- Connection pooling
|
||||
- Retry with exponential backoff
|
||||
- Configurable timeouts
|
||||
- Support for GET, POST, PUT, DELETE, PATCH
|
||||
|
||||
Usage:
|
||||
from sageapi.utils.http_client import SageHttpClient
|
||||
|
||||
client = SageHttpClient(
|
||||
api_key="your-api-key",
|
||||
api_secret="your-api-secret",
|
||||
base_url="https://api.example.com",
|
||||
max_retries=3,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
async with client:
|
||||
resp = await client.post("/v1/chat", json={"message": "hello"})
|
||||
data = await resp.json()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
|
||||
from appPublic.log import debug, error
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_TIMEOUT = 30.0
|
||||
DEFAULT_MAX_RETRIES = 3
|
||||
DEFAULT_BACKOFF_FACTOR = 0.5
|
||||
DEFAULT_MAX_CONNECTIONS = 100
|
||||
DEFAULT_CONNECTION_POOL_LIMIT = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""Configuration for request retries."""
|
||||
|
||||
max_retries: int = DEFAULT_MAX_RETRIES
|
||||
backoff_factor: float = DEFAULT_BACKOFF_FACTOR
|
||||
retry_on_status: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 504})
|
||||
retry_on_timeout: bool = True
|
||||
retry_on_connection_error: bool = True
|
||||
|
||||
|
||||
def compute_dapi_signature(
|
||||
method: str,
|
||||
path: str,
|
||||
timestamp: str,
|
||||
secret: str,
|
||||
body: Optional[bytes] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compute HMAC-SHA256 signature for DAPI authentication.
|
||||
|
||||
Signs: "{method}\\n{path}\\n{timestamp}\\n{body_hash}"
|
||||
"""
|
||||
if body:
|
||||
body_hash = hashlib.sha256(body).hexdigest()
|
||||
else:
|
||||
body_hash = ""
|
||||
|
||||
string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}"
|
||||
|
||||
return hmac.new(
|
||||
secret.encode("utf-8"),
|
||||
string_to_sign.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
|
||||
class DAPISigner:
|
||||
"""Handles DAPI signature generation for outgoing requests."""
|
||||
|
||||
def __init__(self, api_key: str, api_secret: str):
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
|
||||
def sign_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body: Optional[bytes] = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Generate DAPI authentication headers.
|
||||
|
||||
Returns dict with X-DAPI-Key, X-DAPI-Timestamp, X-DAPI-Signature.
|
||||
"""
|
||||
timestamp = str(time.time())
|
||||
signature = compute_dapi_signature(
|
||||
method=method.upper(),
|
||||
path=path,
|
||||
timestamp=timestamp,
|
||||
secret=self.api_secret,
|
||||
body=body,
|
||||
)
|
||||
|
||||
return {
|
||||
"X-DAPI-Key": self.api_key,
|
||||
"X-DAPI-Timestamp": timestamp,
|
||||
"X-DAPI-Signature": signature,
|
||||
}
|
||||
|
||||
|
||||
class SageHttpClient:
|
||||
"""Async HTTP client for calling upstream Sage APIs.
|
||||
"""
|
||||
Async HTTP client with DAPI signature support.
|
||||
|
||||
Handles connection pooling, authentication headers, and
|
||||
retry logic for transient failures.
|
||||
Uses aiohttp under the hood with connection pooling, retry logic,
|
||||
and automatic DAPI header injection.
|
||||
|
||||
Args:
|
||||
base_url: Base URL for all requests (e.g., "https://api.example.com").
|
||||
api_key: DAPI API key for request signing.
|
||||
api_secret: DAPI API secret for request signing.
|
||||
timeout: Request timeout in seconds.
|
||||
max_retries: Maximum number of retries on transient failures.
|
||||
backoff_factor: Exponential backoff multiplier.
|
||||
max_connections: Maximum number of connections in the pool.
|
||||
headers: Additional default headers to include in every request.
|
||||
signer: Optional custom DAPISigner instance. If None, one is created from api_key/api_secret.
|
||||
|
||||
Usage:
|
||||
async with SageHttpClient(base_url="...", api_key="...", api_secret="...") as client:
|
||||
resp = await client.get("/health")
|
||||
resp = await client.post("/v1/chat", json={"msg": "hi"})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = 'http://127.0.0.1:9180',
|
||||
dapi_key: str = '',
|
||||
dapi_secret: str = '',
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.dapi_key = dapi_key
|
||||
self.dapi_secret = dapi_secret
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
api_secret: str = "",
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
backoff_factor: float = DEFAULT_BACKOFF_FACTOR,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
signer: Optional[DAPISigner] = None,
|
||||
auto_sign: bool = True,
|
||||
):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.backoff_factor = backoff_factor
|
||||
self.default_headers = headers or {}
|
||||
self.auto_sign = auto_sign
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build request headers with DAPI authentication.
|
||||
if signer is not None:
|
||||
self.signer = signer
|
||||
else:
|
||||
self.signer = DAPISigner(api_key=api_key, api_secret=api_secret)
|
||||
|
||||
TODO: Implement actual DAPI signature generation.
|
||||
# Internal state
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._connector: Optional[aiohttp.TCPConnector] = None
|
||||
self._connector_limit = max_connections
|
||||
self._closed = False
|
||||
|
||||
def _build_url(self, path: str) -> str:
|
||||
"""Build full URL from base_url and path."""
|
||||
if path.startswith("http://") or path.startswith("https://"):
|
||||
return path
|
||||
return f"{self.base_url}/{path.lstrip('/')}"
|
||||
|
||||
def _get_relative_path(self, path: str) -> str:
|
||||
"""Extract the relative path for signing purposes."""
|
||||
if path.startswith("http://") or path.startswith("https://"):
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(path)
|
||||
return parsed.path
|
||||
return path
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create the aiohttp session with connection pooling."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._connector = aiohttp.TCPConnector(
|
||||
limit=self._connector_limit,
|
||||
limit_per_host=self._connector_limit,
|
||||
ttl_dns_cache=300,
|
||||
keepalive_timeout=30,
|
||||
)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=self._connector,
|
||||
timeout=ClientTimeout(total=self.timeout),
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def _sign_and_prepare(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
data: Any = None,
|
||||
json_body: Any = None,
|
||||
) -> tuple[str, dict[str, str], Optional[bytes]]:
|
||||
"""
|
||||
import time
|
||||
import hashlib
|
||||
import hmac
|
||||
Prepare request with DAPI signature headers.
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
string_to_sign = f'{self.dapi_key}:{timestamp}'
|
||||
signature = hmac.new(
|
||||
self.dapi_secret.encode('utf-8') if self.dapi_secret else b'',
|
||||
string_to_sign.encode('utf-8'),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
Returns (url, merged_headers, raw_body_bytes).
|
||||
"""
|
||||
url = self._build_url(path)
|
||||
relative_path = self._get_relative_path(path)
|
||||
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'X-DAPI-Key': self.dapi_key,
|
||||
'X-DAPI-Timestamp': timestamp,
|
||||
'X-DAPI-Signature': signature,
|
||||
merged_headers = {**self.default_headers}
|
||||
if headers:
|
||||
merged_headers.update(headers)
|
||||
|
||||
# Determine body for signing
|
||||
raw_body: Optional[bytes] = None
|
||||
if json_body is not None:
|
||||
import json
|
||||
|
||||
raw_body = json.dumps(json_body).encode("utf-8")
|
||||
merged_headers.setdefault("Content-Type", "application/json")
|
||||
elif data is not None:
|
||||
if isinstance(data, bytes):
|
||||
raw_body = data
|
||||
elif isinstance(data, str):
|
||||
raw_body = data.encode("utf-8")
|
||||
else:
|
||||
# Form data - sign the encoded form
|
||||
from urllib.parse import urlencode
|
||||
|
||||
raw_body = urlencode(data).encode("utf-8")
|
||||
|
||||
# Add DAPI signature headers if auto-sign is enabled and signer has credentials
|
||||
if self.auto_sign and self.signer.api_key and self.signer.api_secret:
|
||||
dapi_headers = self.signer.sign_request(
|
||||
method=method,
|
||||
path=relative_path,
|
||||
body=raw_body,
|
||||
)
|
||||
merged_headers.update(dapi_headers)
|
||||
|
||||
return url, merged_headers, raw_body
|
||||
|
||||
async def _execute_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
**kwargs: Any,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""
|
||||
Execute request with retry and exponential backoff.
|
||||
"""
|
||||
session = await self._get_session()
|
||||
last_response: Optional[aiohttp.ClientResponse] = None
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
response = await session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status
|
||||
if response.status in {429, 500, 502, 503, 504}:
|
||||
last_response = response
|
||||
if attempt < self.max_retries:
|
||||
wait_time = self.backoff_factor * (2 ** attempt)
|
||||
logger.warning(
|
||||
"HTTP %s on %s %s, retrying in %.1fs (attempt %d/%d)",
|
||||
response.status,
|
||||
method,
|
||||
url,
|
||||
wait_time,
|
||||
attempt + 1,
|
||||
self.max_retries,
|
||||
)
|
||||
try:
|
||||
response.release()
|
||||
except Exception:
|
||||
pass
|
||||
await self._async_sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
# Retries exhausted - raise an error
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"HTTP {response.status} after {self.max_retries + 1} attempts",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except aiohttp.ServerTimeoutError as e:
|
||||
last_exception = e
|
||||
if attempt < self.max_retries:
|
||||
wait_time = self.backoff_factor * (2 ** attempt)
|
||||
logger.warning(
|
||||
"Timeout on %s %s, retrying in %.1fs (attempt %d/%d)",
|
||||
method,
|
||||
url,
|
||||
wait_time,
|
||||
attempt + 1,
|
||||
self.max_retries,
|
||||
)
|
||||
await self._async_sleep(wait_time)
|
||||
continue
|
||||
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ClientOSError) as e:
|
||||
last_exception = e
|
||||
if attempt < self.max_retries:
|
||||
wait_time = self.backoff_factor * (2 ** attempt)
|
||||
logger.warning(
|
||||
"Connection error on %s %s, retrying in %.1fs (attempt %d/%d)",
|
||||
method,
|
||||
url,
|
||||
wait_time,
|
||||
attempt + 1,
|
||||
self.max_retries,
|
||||
)
|
||||
await self._async_sleep(wait_time)
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
# All retries exhausted
|
||||
if last_response:
|
||||
return last_response
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise RuntimeError(f"Request failed after {self.max_retries + 1} attempts")
|
||||
|
||||
@staticmethod
|
||||
async def _async_sleep(seconds: float) -> None:
|
||||
"""Non-blocking sleep."""
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(seconds)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
data: Any = None,
|
||||
json: Any = None,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
allow_redirects: bool = True,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""
|
||||
Send an HTTP request with DAPI signing and retry.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PUT, DELETE, PATCH).
|
||||
path: URL path or full URL.
|
||||
headers: Additional request headers.
|
||||
data: Form data or raw bytes.
|
||||
json: JSON-serializable body (automatically encoded).
|
||||
params: Query string parameters.
|
||||
timeout: Override timeout for this request.
|
||||
allow_redirects: Whether to follow redirects.
|
||||
|
||||
Returns:
|
||||
aiohttp.ClientResponse
|
||||
"""
|
||||
url, merged_headers, raw_body = await self._sign_and_prepare(
|
||||
method=method,
|
||||
path=path,
|
||||
headers=headers,
|
||||
data=data,
|
||||
json_body=json,
|
||||
)
|
||||
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"allow_redirects": allow_redirects,
|
||||
}
|
||||
|
||||
async def get(
|
||||
self,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
"""Send a GET request to the upstream Sage API.
|
||||
if raw_body is not None:
|
||||
if json is not None:
|
||||
request_kwargs["data"] = raw_body
|
||||
else:
|
||||
request_kwargs["data"] = raw_body
|
||||
elif json is not None:
|
||||
import json as _json
|
||||
|
||||
Args:
|
||||
path: API path (relative to base_url).
|
||||
params: Optional query parameters.
|
||||
headers: Optional additional headers.
|
||||
request_kwargs["data"] = _json.dumps(json).encode("utf-8")
|
||||
|
||||
Returns:
|
||||
Parsed JSON response.
|
||||
"""
|
||||
url = f'{self.base_url}{path}'
|
||||
request_headers = {**self._build_headers(), **(headers or {})}
|
||||
if data is not None and json is None:
|
||||
request_kwargs["data"] = data
|
||||
|
||||
debug(f'HTTP GET {url} params={params}')
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
|
||||
# TODO: Replace with actual async HTTP implementation
|
||||
# using aiohttp or the framework's built-in HTTP client.
|
||||
# This is a placeholder that will be filled in once the
|
||||
# specific HTTP library choice is confirmed.
|
||||
raise NotImplementedError(
|
||||
'SageHttpClient.get: HTTP library not yet integrated. '
|
||||
'Implement with aiohttp or framework HTTP client.'
|
||||
# Per-request timeout override
|
||||
if timeout is not None:
|
||||
request_kwargs["timeout"] = ClientTimeout(total=timeout)
|
||||
|
||||
return await self._execute_with_retry(
|
||||
method=method.upper(),
|
||||
url=url,
|
||||
headers=merged_headers,
|
||||
**request_kwargs,
|
||||
)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
path: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
"""Send a POST request to the upstream Sage API.
|
||||
async def get(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
|
||||
"""Send a GET request."""
|
||||
return await self.request("GET", path, **kwargs)
|
||||
|
||||
Args:
|
||||
path: API path (relative to base_url).
|
||||
data: JSON body data.
|
||||
headers: Optional additional headers.
|
||||
async def post(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
|
||||
"""Send a POST request."""
|
||||
return await self.request("POST", path, **kwargs)
|
||||
|
||||
Returns:
|
||||
Parsed JSON response.
|
||||
"""
|
||||
url = f'{self.base_url}{path}'
|
||||
request_headers = {**self._build_headers(), **(headers or {})}
|
||||
async def put(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
|
||||
"""Send a PUT request."""
|
||||
return await self.request("PUT", path, **kwargs)
|
||||
|
||||
debug(f'HTTP POST {url}')
|
||||
async def delete(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
|
||||
"""Send a DELETE request."""
|
||||
return await self.request("DELETE", path, **kwargs)
|
||||
|
||||
# TODO: Replace with actual async HTTP implementation.
|
||||
raise NotImplementedError(
|
||||
'SageHttpClient.post: HTTP library not yet integrated. '
|
||||
'Implement with aiohttp or framework HTTP client.'
|
||||
async def patch(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
|
||||
"""Send a PATCH request."""
|
||||
return await self.request("PATCH", path, **kwargs)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client and release resources."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._closed = True
|
||||
|
||||
async def __aenter__(self) -> "SageHttpClient":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Best-effort cleanup; prefer using async context manager
|
||||
if not self._closed and self._session and not self._session.closed:
|
||||
logger.warning(
|
||||
"SageHttpClient was not properly closed. "
|
||||
"Use 'async with SageHttpClient(...)' for proper cleanup."
|
||||
)
|
||||
|
||||
122
sync/cache_tables.sql
Normal file
122
sync/cache_tables.sql
Normal file
@ -0,0 +1,122 @@
|
||||
-- =============================================================================
|
||||
-- SageAPI Cache Tables DDL
|
||||
-- These tables store synchronized data from the Sage database.
|
||||
-- Run against the sageapi database.
|
||||
-- =============================================================================
|
||||
|
||||
-- Checkpoint table: tracks the last sync timestamp for each module
|
||||
CREATE TABLE IF NOT EXISTS sync_state (
|
||||
state_key VARCHAR(64) NOT NULL PRIMARY KEY,
|
||||
last_sync_ts VARCHAR(64) DEFAULT NULL,
|
||||
created_at VARCHAR(32) NOT NULL,
|
||||
updated_at VARCHAR(32) NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- =============================================================================
|
||||
-- users_cache: synced from Sage users / organi / organization tables
|
||||
-- =============================================================================
|
||||
CREATE TABLE IF NOT EXISTS users_cache (
|
||||
user_id BIGINT NOT NULL PRIMARY KEY,
|
||||
username VARCHAR(128) DEFAULT NULL,
|
||||
email VARCHAR(256) DEFAULT NULL,
|
||||
user_status INT DEFAULT 0,
|
||||
user_created_at VARCHAR(32) DEFAULT NULL,
|
||||
user_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
organi_id BIGINT DEFAULT NULL,
|
||||
organi_name VARCHAR(256) DEFAULT NULL,
|
||||
organi_parent_id BIGINT DEFAULT NULL,
|
||||
org_id BIGINT DEFAULT NULL,
|
||||
org_name VARCHAR(256) DEFAULT NULL,
|
||||
org_type VARCHAR(64) DEFAULT NULL,
|
||||
org_status INT DEFAULT 0,
|
||||
org_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
synced_at VARCHAR(32) NOT NULL,
|
||||
UNIQUE KEY uk_organi (organi_id),
|
||||
KEY idx_updated (user_updated_at)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- =============================================================================
|
||||
-- pricing_cache: synced from Sage pricing_program / pricing_program_timing
|
||||
-- =============================================================================
|
||||
CREATE TABLE IF NOT EXISTS pricing_cache (
|
||||
program_id BIGINT NOT NULL,
|
||||
program_name VARCHAR(256) DEFAULT NULL,
|
||||
program_code VARCHAR(128) DEFAULT NULL,
|
||||
program_type VARCHAR(64) DEFAULT NULL,
|
||||
program_status INT DEFAULT 0,
|
||||
description TEXT DEFAULT NULL,
|
||||
program_created_at VARCHAR(32) DEFAULT NULL,
|
||||
program_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
timing_id BIGINT DEFAULT NULL,
|
||||
start_time VARCHAR(32) DEFAULT NULL,
|
||||
end_time VARCHAR(32) DEFAULT NULL,
|
||||
duration INT DEFAULT NULL,
|
||||
repeat_rule VARCHAR(256) DEFAULT NULL,
|
||||
timezone VARCHAR(64) DEFAULT NULL,
|
||||
timing_status INT DEFAULT 0,
|
||||
timing_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
synced_at VARCHAR(32) NOT NULL,
|
||||
PRIMARY KEY (program_id, timing_id),
|
||||
KEY idx_program_updated (program_updated_at)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- =============================================================================
|
||||
-- uapi_cache: synced from Sage uapi / upapp tables
|
||||
-- =============================================================================
|
||||
CREATE TABLE IF NOT EXISTS uapi_cache (
|
||||
uapi_id BIGINT NOT NULL,
|
||||
api_name VARCHAR(256) DEFAULT NULL,
|
||||
api_path VARCHAR(512) DEFAULT NULL,
|
||||
api_method VARCHAR(16) DEFAULT 'GET',
|
||||
api_version VARCHAR(32) DEFAULT NULL,
|
||||
api_desc TEXT DEFAULT NULL,
|
||||
uapi_status INT DEFAULT 0,
|
||||
auth_required TINYINT DEFAULT 0,
|
||||
uapi_created_at VARCHAR(32) DEFAULT NULL,
|
||||
uapi_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
upapp_id BIGINT DEFAULT NULL,
|
||||
app_name VARCHAR(256) DEFAULT NULL,
|
||||
app_code VARCHAR(128) DEFAULT NULL,
|
||||
app_type VARCHAR(64) DEFAULT NULL,
|
||||
app_desc TEXT DEFAULT NULL,
|
||||
app_owner VARCHAR(128) DEFAULT NULL,
|
||||
upapp_status INT DEFAULT 0,
|
||||
upapp_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
synced_at VARCHAR(32) NOT NULL,
|
||||
PRIMARY KEY (uapi_id, upapp_id),
|
||||
KEY idx_uapi_updated (uapi_updated_at),
|
||||
KEY idx_upapp_id (upapp_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
|
||||
-- =============================================================================
|
||||
-- llmage_cache: synced from Sage llm / llmcatelog / llm_api_map tables
|
||||
-- =============================================================================
|
||||
CREATE TABLE IF NOT EXISTS llmage_cache (
|
||||
llm_id BIGINT NOT NULL,
|
||||
model_name VARCHAR(256) DEFAULT NULL,
|
||||
model_version VARCHAR(64) DEFAULT NULL,
|
||||
provider VARCHAR(128) DEFAULT NULL,
|
||||
model_type VARCHAR(64) DEFAULT NULL,
|
||||
llm_status INT DEFAULT 0,
|
||||
llm_description TEXT DEFAULT NULL,
|
||||
llm_created_at VARCHAR(32) DEFAULT NULL,
|
||||
llm_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
catelog_id BIGINT DEFAULT NULL,
|
||||
catelog_name VARCHAR(256) DEFAULT NULL,
|
||||
catelog_code VARCHAR(128) DEFAULT NULL,
|
||||
catelog_sort INT DEFAULT 0,
|
||||
catelog_status INT DEFAULT 0,
|
||||
catelog_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
api_map_id BIGINT DEFAULT NULL,
|
||||
api_name VARCHAR(256) DEFAULT NULL,
|
||||
api_endpoint VARCHAR(512) DEFAULT NULL,
|
||||
api_version VARCHAR(32) DEFAULT NULL,
|
||||
auth_type VARCHAR(32) DEFAULT NULL,
|
||||
rate_limit INT DEFAULT NULL,
|
||||
api_map_status INT DEFAULT 0,
|
||||
api_map_updated_at VARCHAR(32) DEFAULT NULL,
|
||||
synced_at VARCHAR(32) NOT NULL,
|
||||
PRIMARY KEY (llm_id, catelog_id, api_map_id),
|
||||
KEY idx_llm_updated (llm_updated_at),
|
||||
KEY idx_catelog_id (catelog_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Tests package
|
||||
BIN
tests/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
tests/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_dapi_auth.cpython-310-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_dapi_auth.cpython-310-pytest-9.0.3.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_http_client.cpython-310-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_http_client.cpython-310-pytest-9.0.3.pyc
Normal file
Binary file not shown.
458
tests/test_dapi_auth.py
Normal file
458
tests/test_dapi_auth.py
Normal file
@ -0,0 +1,458 @@
|
||||
"""Tests for sageapi.middleware.dapi_auth"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from sageapi.middleware.dapi_auth import (
|
||||
DEFAULT_TIMESTAMP_TOLERANCE_SEC,
|
||||
DapiAuthError,
|
||||
DapiAuthMiddleware,
|
||||
DapiKeyRecord,
|
||||
authenticate_request,
|
||||
compute_dapi_signature,
|
||||
lookup_api_key,
|
||||
verify_timestamp,
|
||||
)
|
||||
|
||||
|
||||
def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str:
|
||||
"""Helper to create a valid DAPI signature."""
|
||||
return compute_dapi_signature(method, path, timestamp, secret, body if body else None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_dapi_signature tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeDapiSignature:
|
||||
def test_basic_signature(self):
|
||||
sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret")
|
||||
assert isinstance(sig, str)
|
||||
assert len(sig) == 64 # SHA-256 hex length
|
||||
|
||||
def test_signature_with_body(self):
|
||||
body = b'{"key": "value"}'
|
||||
sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", body)
|
||||
assert isinstance(sig, str)
|
||||
assert len(sig) == 64
|
||||
|
||||
def test_signature_deterministic(self):
|
||||
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
|
||||
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret")
|
||||
assert sig1 == sig2
|
||||
|
||||
def test_different_body_different_signature(self):
|
||||
sig1 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body1")
|
||||
sig2 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body2")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_empty_body_same_as_no_body(self):
|
||||
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
|
||||
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret", None)
|
||||
assert sig1 == sig2
|
||||
|
||||
def test_manual_verification(self):
|
||||
"""Verify the signature matches the expected HMAC-SHA256 output."""
|
||||
method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret"
|
||||
body = b'{"message":"hello"}'
|
||||
body_hash = hashlib.sha256(body).hexdigest()
|
||||
string_to_sign = f"{method}\n{path}\n{ts}\n{body_hash}"
|
||||
expected = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha256).hexdigest()
|
||||
actual = compute_dapi_signature(method, path, ts, secret, body)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# verify_timestamp tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVerifyTimestamp:
|
||||
def test_valid_timestamp(self):
|
||||
ts = str(time.time())
|
||||
assert verify_timestamp(ts) is True
|
||||
|
||||
def test_expired_timestamp(self):
|
||||
ts = str(time.time() - 600) # 10 minutes ago
|
||||
assert verify_timestamp(ts) is False
|
||||
|
||||
def test_future_timestamp(self):
|
||||
ts = str(time.time() + 600) # 10 minutes in future
|
||||
assert verify_timestamp(ts) is False
|
||||
|
||||
def test_invalid_timestamp(self):
|
||||
assert verify_timestamp("not-a-number") is False
|
||||
assert verify_timestamp("") is False
|
||||
|
||||
def test_custom_tolerance(self):
|
||||
ts = str(time.time() - 10)
|
||||
assert verify_timestamp(ts, tolerance_sec=5) is False
|
||||
assert verify_timestamp(ts, tolerance_sec=15) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DapiKeyRecord tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDapiKeyRecord:
|
||||
def test_active_key(self):
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None)
|
||||
assert rec.is_active is True
|
||||
|
||||
def test_inactive_key(self):
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="inactive", expire_date=None)
|
||||
assert rec.is_active is False
|
||||
|
||||
def test_not_expired_when_no_expiry(self):
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None)
|
||||
assert rec.is_expired is False
|
||||
|
||||
def test_expired_with_past_date(self):
|
||||
past = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=past)
|
||||
assert rec.is_expired is True
|
||||
|
||||
def test_not_expired_with_future_date(self):
|
||||
future = datetime(2099, 1, 1, tzinfo=timezone.utc)
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=future)
|
||||
assert rec.is_expired is False
|
||||
|
||||
def test_expired_with_iso_string_past(self):
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2020-01-01T00:00:00Z")
|
||||
assert rec.is_expired is True
|
||||
|
||||
def test_not_expired_with_iso_string_future(self):
|
||||
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2099-01-01T00:00:00Z")
|
||||
assert rec.is_expired is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# lookup_api_key tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLookupApiKey:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_sage_modules(self):
|
||||
"""Create mock ahserver and sqlor modules for testing."""
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Create fake modules
|
||||
ahserver = types.ModuleType("ahserver")
|
||||
ahserver_serverenv = types.ModuleType("ahserver.serverenv")
|
||||
ahserver_serverenv.ServerEnv = MagicMock
|
||||
ahserver.serverenv = ahserver_serverenv
|
||||
|
||||
sqlor = types.ModuleType("sqlor")
|
||||
sqlor_dbpools = types.ModuleType("sqlor.dbpools")
|
||||
|
||||
mock_ctx = AsyncMock()
|
||||
mock_sor = AsyncMock()
|
||||
mock_sor.R = AsyncMock(return_value=[])
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
|
||||
sqlor.dbpools = sqlor_dbpools
|
||||
|
||||
sys.modules["ahserver"] = ahserver
|
||||
sys.modules["ahserver.serverenv"] = ahserver_serverenv
|
||||
sys.modules["sqlor"] = sqlor
|
||||
sys.modules["sqlor.dbpools"] = sqlor_dbpools
|
||||
|
||||
yield mock_ctx, mock_sor
|
||||
|
||||
# Cleanup
|
||||
for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]:
|
||||
if mod in sys.modules:
|
||||
del sys.modules[mod]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_found(self, mock_sage_modules):
|
||||
mock_ctx, mock_sor = mock_sage_modules
|
||||
mock_rec = {
|
||||
"id": 1,
|
||||
"apikey": "test-key",
|
||||
"secret": "test-secret",
|
||||
"status": "active",
|
||||
"expire_date": "2099-01-01T00:00:00Z",
|
||||
"description": "Test key",
|
||||
}
|
||||
mock_sor.R.return_value = [mock_rec]
|
||||
|
||||
result = await lookup_api_key("test-key")
|
||||
|
||||
assert result is not None
|
||||
assert result.apikey == "test-key"
|
||||
assert result.secret == "test-secret"
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_not_found(self, mock_sage_modules):
|
||||
mock_ctx, mock_sor = mock_sage_modules
|
||||
mock_sor.R.return_value = []
|
||||
|
||||
result = await lookup_api_key("nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# authenticate_request tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAuthenticateRequest:
|
||||
def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": method,
|
||||
"path": path,
|
||||
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
|
||||
"query_string": b"",
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
return Request(scope, receive)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_api_key_header(self):
|
||||
req = self._make_request({
|
||||
"X-DAPI-Timestamp": str(time.time()),
|
||||
"X-DAPI-Signature": "abc",
|
||||
})
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 401
|
||||
assert "X-DAPI-Key" in exc.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_timestamp_header(self):
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Signature": "abc",
|
||||
})
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 401
|
||||
assert "X-DAPI-Timestamp" in exc.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_signature_header(self):
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Timestamp": str(time.time()),
|
||||
})
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 401
|
||||
assert "X-DAPI-Signature" in exc.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_timestamp(self):
|
||||
ts = str(time.time() - 600)
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": "abc",
|
||||
})
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 401
|
||||
assert "timestamp" in exc.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_api_key(self):
|
||||
ts = str(time.time())
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "bad-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": "abc",
|
||||
})
|
||||
|
||||
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=None):
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 401
|
||||
assert "Invalid API key" in exc.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_key(self):
|
||||
ts = str(time.time())
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": "abc",
|
||||
})
|
||||
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None)
|
||||
|
||||
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 403
|
||||
assert "inactive" in exc.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_key(self):
|
||||
ts = str(time.time())
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": "abc",
|
||||
})
|
||||
key_rec = DapiKeyRecord(
|
||||
id=1, apikey="test-key", secret="secret", status="active",
|
||||
expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 403
|
||||
assert "expired" in exc.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_signature(self):
|
||||
ts = str(time.time())
|
||||
body = b'{"test": "data"}'
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": "invalid-signature",
|
||||
}, body=body, method="POST")
|
||||
|
||||
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None)
|
||||
|
||||
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
|
||||
with pytest.raises(DapiAuthError) as exc:
|
||||
await authenticate_request(req)
|
||||
assert exc.value.status_code == 401
|
||||
assert "Invalid signature" in exc.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_authentication(self):
|
||||
ts = str(time.time())
|
||||
secret = "my-secret"
|
||||
body = b'{"test": "data"}'
|
||||
sig = compute_dapi_signature("POST", "/test", ts, secret, body)
|
||||
|
||||
req = self._make_request({
|
||||
"X-DAPI-Key": "test-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": sig,
|
||||
}, body=body, method="POST")
|
||||
|
||||
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None)
|
||||
|
||||
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
|
||||
result = await authenticate_request(req)
|
||||
|
||||
assert result.apikey == "test-key"
|
||||
assert result.secret == secret
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DapiAuthMiddleware tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDapiAuthMiddleware:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_sage_modules(self):
|
||||
"""Create mock ahserver and sqlor modules."""
|
||||
import sys
|
||||
import types
|
||||
|
||||
ahserver = types.ModuleType("ahserver")
|
||||
ahserver_serverenv = types.ModuleType("ahserver.serverenv")
|
||||
ahserver_serverenv.ServerEnv = MagicMock
|
||||
ahserver.serverenv = ahserver_serverenv
|
||||
|
||||
sqlor = types.ModuleType("sqlor")
|
||||
sqlor_dbpools = types.ModuleType("sqlor.dbpools")
|
||||
mock_sor = AsyncMock()
|
||||
mock_sor.R = AsyncMock(return_value=[])
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
|
||||
sqlor.dbpools = sqlor_dbpools
|
||||
|
||||
sys.modules["ahserver"] = ahserver
|
||||
sys.modules["ahserver.serverenv"] = ahserver_serverenv
|
||||
sys.modules["sqlor"] = sqlor
|
||||
sys.modules["sqlor.dbpools"] = sqlor_dbpools
|
||||
|
||||
yield
|
||||
|
||||
for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]:
|
||||
if mod in sys.modules:
|
||||
del sys.modules[mod]
|
||||
|
||||
def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC):
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
async def protected(request: Request):
|
||||
return PlainTextResponse("OK - authenticated")
|
||||
|
||||
async def health(request: Request):
|
||||
return PlainTextResponse("healthy")
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/protected", endpoint=protected, methods=["GET"]),
|
||||
Route("/health", endpoint=health, methods=["GET"]),
|
||||
]
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
DapiAuthMiddleware,
|
||||
exclude_paths=exclude_paths or ["/health"],
|
||||
tolerance_sec=tolerance_sec,
|
||||
)
|
||||
return app
|
||||
|
||||
def test_unauthenticated_request_rejected(self):
|
||||
app = self._build_test_app()
|
||||
client = TestClient(app)
|
||||
resp = client.get("/protected")
|
||||
assert resp.status_code == 401
|
||||
assert "error" in resp.json()
|
||||
|
||||
def test_excluded_path_bypasses_auth(self):
|
||||
app = self._build_test_app()
|
||||
client = TestClient(app)
|
||||
resp = client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.text == "healthy"
|
||||
|
||||
def test_valid_request_accepted(self):
|
||||
ts = str(time.time())
|
||||
secret = "test-secret"
|
||||
sig = compute_dapi_signature("GET", "/protected", ts, secret)
|
||||
|
||||
key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None)
|
||||
|
||||
app = self._build_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
|
||||
resp = client.get(
|
||||
"/protected",
|
||||
headers={
|
||||
"X-DAPI-Key": "valid-key",
|
||||
"X-DAPI-Timestamp": ts,
|
||||
"X-DAPI-Signature": sig,
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.text == "OK - authenticated"
|
||||
288
tests/test_http_client.py
Normal file
288
tests/test_http_client.py
Normal file
@ -0,0 +1,288 @@
|
||||
"""Tests for sageapi.utils.http_client"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from sageapi.utils.http_client import (
|
||||
DAPISigner,
|
||||
SageHttpClient,
|
||||
RetryConfig,
|
||||
compute_dapi_signature,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_dapi_signature tests (also imported from middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeDapiSignature:
|
||||
def test_basic(self):
|
||||
sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "secret")
|
||||
assert len(sig) == 64
|
||||
|
||||
def test_with_body(self):
|
||||
body = b'{"msg":"hi"}'
|
||||
sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "secret", body)
|
||||
assert len(sig) == 64
|
||||
|
||||
def test_deterministic(self):
|
||||
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
|
||||
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret")
|
||||
assert sig1 == sig2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAPISigner tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDAPISigner:
|
||||
def test_sign_request(self):
|
||||
signer = DAPISigner(api_key="my-key", api_secret="my-secret")
|
||||
headers = signer.sign_request("GET", "/test")
|
||||
|
||||
assert "X-DAPI-Key" in headers
|
||||
assert "X-DAPI-Timestamp" in headers
|
||||
assert "X-DAPI-Signature" in headers
|
||||
assert headers["X-DAPI-Key"] == "my-key"
|
||||
assert headers["X-DAPI-Timestamp"] # should be a valid timestamp string
|
||||
|
||||
def test_sign_request_with_body(self):
|
||||
signer = DAPISigner(api_key="k", api_secret="s")
|
||||
body = b'{"test": true}'
|
||||
headers = signer.sign_request("POST", "/v1/chat", body)
|
||||
|
||||
assert headers["X-DAPI-Key"] == "k"
|
||||
# Verify timestamp is recent
|
||||
ts = float(headers["X-DAPI-Timestamp"])
|
||||
assert abs(time.time() - ts) < 2
|
||||
|
||||
def test_sign_request_produces_valid_signature(self):
|
||||
signer = DAPISigner(api_key="k", api_secret="secret")
|
||||
body = b'hello'
|
||||
headers = signer.sign_request("POST", "/path", body)
|
||||
|
||||
expected = compute_dapi_signature("POST", "/path", headers["X-DAPI-Timestamp"], "secret", body)
|
||||
assert headers["X-DAPI-Signature"] == expected
|
||||
|
||||
def test_sign_request_uppercases_method(self):
|
||||
signer = DAPISigner(api_key="k", api_secret="s")
|
||||
headers = signer.sign_request("get", "/path")
|
||||
expected = compute_dapi_signature("GET", "/path", headers["X-DAPI-Timestamp"], "s")
|
||||
assert headers["X-DAPI-Signature"] == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RetryConfig tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRetryConfig:
|
||||
def test_defaults(self):
|
||||
config = RetryConfig()
|
||||
assert config.max_retries == 3
|
||||
assert config.backoff_factor == 0.5
|
||||
assert 429 in config.retry_on_status
|
||||
assert 500 in config.retry_on_status
|
||||
|
||||
def test_custom(self):
|
||||
config = RetryConfig(max_retries=5, backoff_factor=1.0)
|
||||
assert config.max_retries == 5
|
||||
assert config.backoff_factor == 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SageHttpClient tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSageHttpClient:
|
||||
def test_init_defaults(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://api.example.com",
|
||||
api_key="key",
|
||||
api_secret="secret",
|
||||
)
|
||||
assert client.base_url == "https://api.example.com"
|
||||
assert client.signer.api_key == "key"
|
||||
assert client.signer.api_secret == "secret"
|
||||
assert client.auto_sign is True
|
||||
|
||||
def test_init_with_custom_signer(self):
|
||||
signer = DAPISigner(api_key="k", api_secret="s")
|
||||
client = SageHttpClient(base_url="https://api.example.com", signer=signer)
|
||||
assert client.signer is signer
|
||||
|
||||
def test_init_without_auto_sign(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://api.example.com",
|
||||
auto_sign=False,
|
||||
)
|
||||
assert client.auto_sign is False
|
||||
|
||||
def test_build_url(self):
|
||||
client = SageHttpClient(base_url="https://api.example.com")
|
||||
assert client._build_url("/v1/test") == "https://api.example.com/v1/test"
|
||||
assert client._build_url("v1/test") == "https://api.example.com/v1/test"
|
||||
# Full URL should pass through
|
||||
assert client._build_url("https://other.com/path") == "https://other.com/path"
|
||||
|
||||
def test_get_relative_path(self):
|
||||
client = SageHttpClient(base_url="https://api.example.com")
|
||||
assert client._get_relative_path("/v1/chat") == "/v1/chat"
|
||||
assert client._get_relative_path("https://api.example.com/v1/chat?page=1") == "/v1/chat"
|
||||
|
||||
def test_sign_and_prepare_adds_dapi_headers(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://api.example.com",
|
||||
api_key="k",
|
||||
api_secret="s",
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
url, headers, body = asyncio.get_event_loop().run_until_complete(
|
||||
client._sign_and_prepare("GET", "/test")
|
||||
)
|
||||
|
||||
assert "X-DAPI-Key" in headers
|
||||
assert "X-DAPI-Timestamp" in headers
|
||||
assert "X-DAPI-Signature" in headers
|
||||
assert headers["X-DAPI-Key"] == "k"
|
||||
|
||||
def test_sign_and_prepare_with_json_body(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://api.example.com",
|
||||
api_key="k",
|
||||
api_secret="s",
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
url, headers, body = asyncio.get_event_loop().run_until_complete(
|
||||
client._sign_and_prepare("POST", "/test", json_body={"msg": "hi"})
|
||||
)
|
||||
|
||||
assert body is not None
|
||||
assert json.loads(body) == {"msg": "hi"}
|
||||
assert headers.get("Content-Type") == "application/json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manager(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://httpbin.org",
|
||||
api_key="k",
|
||||
api_secret="s",
|
||||
max_retries=0,
|
||||
)
|
||||
async with client as c:
|
||||
assert c is client
|
||||
assert client._closed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://httpbin.org",
|
||||
api_key="k",
|
||||
api_secret="s",
|
||||
max_retries=0,
|
||||
)
|
||||
await client._get_session() # create session
|
||||
await client.close()
|
||||
assert client._closed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_with_retry_on_502(self):
|
||||
"""Test that 502 responses trigger retries and raise after exhaustion."""
|
||||
client = SageHttpClient(
|
||||
base_url="",
|
||||
api_key="k",
|
||||
api_secret="s",
|
||||
max_retries=2,
|
||||
backoff_factor=0.01, # fast for tests
|
||||
auto_sign=False,
|
||||
)
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 502
|
||||
mock_response.request_info = MagicMock()
|
||||
mock_response.history = []
|
||||
mock_response.release = MagicMock()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.request = AsyncMock(return_value=mock_response)
|
||||
mock_session.closed = False
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
|
||||
await client.request("GET", "/test")
|
||||
|
||||
assert exc_info.value.status == 502
|
||||
# Should have been called 3 times (1 initial + 2 retries)
|
||||
assert mock_session.request.call_count == 3
|
||||
|
||||
|
||||
class TestSageHttpClientIntegration:
|
||||
"""Integration tests against httpbin.org (requires network)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_request(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://httpbin.org",
|
||||
auto_sign=False,
|
||||
max_retries=0,
|
||||
timeout=10.0,
|
||||
)
|
||||
async with client:
|
||||
resp = await client.get("/get")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert "url" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_request_with_json(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://httpbin.org",
|
||||
auto_sign=False,
|
||||
max_retries=0,
|
||||
timeout=10.0,
|
||||
)
|
||||
async with client:
|
||||
resp = await client.post("/post", json={"key": "value"})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["json"] == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_headers_sent(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://httpbin.org",
|
||||
headers={"X-Custom-Header": "test-value"},
|
||||
auto_sign=False,
|
||||
max_retries=0,
|
||||
timeout=10.0,
|
||||
)
|
||||
async with client:
|
||||
resp = await client.get("/headers")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["headers"].get("X-Custom-Header") == "test-value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_params(self):
|
||||
client = SageHttpClient(
|
||||
base_url="https://httpbin.org",
|
||||
auto_sign=False,
|
||||
max_retries=0,
|
||||
timeout=10.0,
|
||||
)
|
||||
async with client:
|
||||
resp = await client.get("/get", params={"foo": "bar", "baz": "qux"})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["args"]["foo"] == "bar"
|
||||
assert data["args"]["baz"] == "qux"
|
||||
Loading…
x
Reference in New Issue
Block a user