feat: implement sync engine, API handlers, DAPI auth, HTTP client

- Sync engine: BaseSync abstract class + 4 sync modules (users/pricing/uapi/llmage)
  - Checkpoint management via sync_state table
  - Batch processing with retry and exponential backoff
  - Incremental fetch from Sage DB via sqlor
  - UPSERT to local cache tables
- API handlers: balance/accounting/users/pricing/health
  - Balance: cache lookup + Sage fallback
  - Accounting: create with idempotency, query with filters/pagination
  - Users: keyword search, org filter
  - Pricing: filter by ppid/llmid/type/status
  - Health: basic + readiness checks (DB connectivity)
- DAPI auth: middleware + authenticate_request function
  - HMAC-SHA256 signature verification
  - Timestamp window validation
  - Sage downapikey table lookup
- HTTP client: SageHttpClient with aiohttp
  - Auto DAPI signature injection
  - Connection pooling, retry, timeout
- Router: 12 routes registered
- Module init: load_sageapi() wires everything to ServerEnv
This commit is contained in:
Hermes Agent 2026-05-20 18:22:23 +08:00
parent a9ea05ff2d
commit 5936a2f328
44 changed files with 2854 additions and 716 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
__pycache__/
*.pyc

View File

@ -0,0 +1,26 @@
"""SageAPI - Sage data caching and proxy API server.
Provides cached access to Sage data (users, pricing, uapi, llmage)
with DAPI authentication and independent multi-instance deployment.
"""
__version__ = '0.1.0'
# Public API exports
from sageapi.sync import run_all_syncs, UserSync, PricingSync, UapiSync, LlmageSync
from sageapi.router import Router, setup_routes
from sageapi.cache.cache_manager import CacheManager
from sageapi.middleware.dapi_auth import authenticate_request, DapiAuthMiddleware
__all__ = [
'run_all_syncs',
'UserSync',
'PricingSync',
'UapiSync',
'LlmageSync',
'Router',
'setup_routes',
'CacheManager',
'authenticate_request',
'DapiAuthMiddleware',
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -2,7 +2,7 @@
Provides endpoints for creating and querying accounting records.
Writing goes directly to the accounting_records table; reads
are served from the same table with optional date range filtering.
are served from the same table with optional filtering.
"""
from __future__ import annotations
@ -14,64 +14,86 @@ from typing import Any
from appPublic.log import debug, error
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
async def create_accounting_record(
customer_id: str,
amount: float,
record_type: str = 'charge',
description: str = '',
**extra: Any,
llmid: str = '',
model_name: str = '',
pricing_id: str = '',
input_tokens: int | None = None,
output_tokens: int | None = None,
total_tokens: int | None = None,
quantity: float | None = None,
currency: str = 'CNY',
request_id: str = '',
transno: str = '',
) -> str:
"""Create a new accounting record.
Args:
customer_id: The customer identifier.
amount: The accounting amount (positive for charges, negative for credits).
record_type: Type of record (charge, credit, adjustment, etc.).
description: Optional description of the transaction.
Returns:
JSON string with success flag and the created record ID.
"""
"""Create a new accounting record with idempotency via request_id."""
result: dict[str, Any] = {'success': False, 'record_id': None}
try:
from ahserver.serverenv import ServerEnv
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
if not dbname:
result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str)
record_id = str(uuid.uuid4())
record_id = request_id or str(uuid.uuid4())
now = time.strftime('%Y-%m-%d %H:%M:%S')
# Check idempotency
if request_id:
async with DBPools().sqlorContext(dbname) as sor:
existing = await sor.sqlExe(
"SELECT id FROM accounting_records WHERE request_id = ${request_id}$",
{'request_id': request_id},
)
if isinstance(existing, list) and existing:
result['success'] = True
result['record_id'] = existing[0].get('id', existing[0].get('id') if isinstance(existing[0], dict) else existing[0])
result['duplicate'] = True
return json.dumps(result, ensure_ascii=False, default=str)
sql = """
INSERT INTO accounting_records
(id, customer_id, amount, record_type, description, created_at, extra)
(id, customer_id, llmid, model_name, pricing_id,
input_tokens, output_tokens, total_tokens, quantity,
amount, currency, request_id, transno, status,
created_at, updated_at)
VALUES
(${id}$, ${customer_id}$, ${amount}$, ${record_type}$, ${description}$, ${created_at}$, ${extra}$)
(${id}$, ${customer_id}$, ${llmid}$, ${model_name}$, ${pricing_id}$,
${input_tokens}$, ${output_tokens}$, ${total_tokens}$, ${quantity}$,
${amount}$, ${currency}$, ${request_id}$, ${transno}$, 'accounted',
${created_at}$, ${updated_at}$)
"""
async with DBPools().sqlorContext(dbname) as sor:
await sor.sqlExe(sql, {
params = {
'id': record_id,
'customer_id': customer_id,
'llmid': llmid,
'model_name': model_name,
'pricing_id': pricing_id,
'input_tokens': input_tokens,
'output_tokens': output_tokens,
'total_tokens': total_tokens,
'quantity': quantity,
'amount': amount,
'record_type': record_type,
'description': description,
'currency': currency,
'request_id': request_id,
'transno': transno,
'created_at': now,
'extra': json.dumps(extra, ensure_ascii=False) if extra else None,
})
'updated_at': now,
}
async with DBPools().sqlorContext(dbname) as sor:
await sor.sqlExe(sql, params)
result['success'] = True
result['record_id'] = record_id
debug(f'Accounting record created: id={record_id}, customer={customer_id}, amount={amount}')
except Exception as e:
error(f'Accounting record creation failed: {e}')
error(f'create_accounting_record error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
@ -79,30 +101,19 @@ async def create_accounting_record(
async def query_accounting_records(
customer_id: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100,
offset: int = 0,
date_from: str | None = None,
date_to: str | None = None,
llmid: str | None = None,
status: str | None = None,
page: int = 1,
page_size: int = 50,
) -> str:
"""Query accounting records with optional filters.
Args:
customer_id: Filter by customer ID.
start_date: Filter records from this date (inclusive, YYYY-MM-DD).
end_date: Filter records up to this date (inclusive, YYYY-MM-DD).
limit: Maximum number of records to return.
offset: Number of records to skip.
Returns:
JSON string with success flag and record data.
"""
"""Query accounting records with filters and pagination."""
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try:
from ahserver.serverenv import ServerEnv
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
if not dbname:
result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str)
@ -113,41 +124,56 @@ async def query_accounting_records(
if customer_id:
conditions.append('customer_id = ${customer_id}$')
params['customer_id'] = customer_id
if start_date:
conditions.append('created_at >= ${start_date}$')
params['start_date'] = start_date
if end_date:
conditions.append('created_at <= ${end_date}$')
params['end_date'] = end_date
if date_from:
conditions.append('created_at >= ${date_from}$')
params['date_from'] = date_from
if date_to:
conditions.append('created_at <= ${date_to}$')
params['date_to'] = date_to
if llmid:
conditions.append('llmid = ${llmid}$')
params['llmid'] = llmid
if status:
conditions.append('status = ${status}$')
params['status'] = status
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
# Count query
count_sql = f"""
SELECT COUNT(*) as cnt FROM accounting_records {where_clause}
"""
async with DBPools().sqlorContext(dbname) as sor:
count_rows = await sor.sqlExe(count_sql, params)
total = count_rows[0]['cnt'] if count_rows else 0
result['total'] = total
count_sql = f"SELECT COUNT(*) as cnt FROM accounting_records {where}"
offset = (page - 1) * page_size
if total > 0:
# Data query
data_sql = f"""
SELECT id, customer_id, amount, record_type, description, created_at, extra
SELECT id, customer_id, llmid, model_name, pricing_id,
input_tokens, output_tokens, total_tokens, quantity,
amount, currency, request_id, transno, status,
created_at, updated_at
FROM accounting_records
{where_clause}
{where}
ORDER BY created_at DESC
LIMIT ${limit}$ OFFSET ${offset}$
LIMIT {page_size} OFFSET {offset}
"""
params['limit'] = limit
params['offset'] = offset
rows = await sor.sqlExe(data_sql, params)
result['data'] = [dict(r) for r in (rows or [])]
async with DBPools().sqlorContext(dbname) as sor:
count_result = await sor.sqlExe(count_sql, params)
if isinstance(count_result, list) and count_result:
result['total'] = count_result[0].get('cnt', 0)
elif isinstance(count_result, dict):
result['total'] = count_result.get('cnt', 0)
data = await sor.sqlExe(data_sql, params)
if isinstance(data, dict):
result['data'] = data.get('rows', [])
elif isinstance(data, list):
result['data'] = data
result['success'] = True
result['page'] = page
result['page_size'] = page_size
except Exception as e:
error(f'Accounting query failed: {e}')
error(f'query_accounting_records error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,8 @@
"""Customer balance query API handler.
Provides the RESTful endpoint for querying customer account balances.
Reads from the local customer_balance cache table.
Reads from the local customer_balance cache table, with fallback to
real-time query from Sage acc_balance table.
"""
from __future__ import annotations
@ -10,15 +11,18 @@ import json
from typing import Any
from appPublic.log import debug, error
from sqlor.dbpools import DBPools
from sqlor.dbpools import DBPools, get_sor_context
from ahserver.serverenv import ServerEnv
async def get_customer_balance(customer_id: str | None = None) -> str:
"""Query customer balance.
First checks the local customer_balance cache. If not found,
falls back to real-time query from Sage acc_balance table.
Args:
customer_id: Optional customer ID filter. If not provided,
returns all customer balances.
customer_id: Optional customer ID filter.
Returns:
JSON string with success flag and balance data.
@ -26,42 +30,92 @@ async def get_customer_balance(customer_id: str | None = None) -> str:
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try:
from ahserver.serverenv import ServerEnv
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
if not dbname:
cache_dbname = env.get_module_dbname('sageapi')
if not cache_dbname:
result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str)
async with DBPools().sqlorContext(cache_dbname) as sor:
params: dict[str, Any] = {}
where_clause = ''
where = ''
if customer_id:
where_clause = 'WHERE customer_id = ${customer_id}$'
where = 'WHERE id = ${customer_id}$'
params['customer_id'] = customer_id
sql = f"""
SELECT customer_id, balance, currency, updated_at
SELECT id, balance, currency, credit_limit,
last_recharge, last_consumption,
status, cached_at
FROM customer_balance
{where_clause}
ORDER BY customer_id
{where}
ORDER BY id
"""
async with DBPools().sqlorContext(dbname) as sor:
data = await sor.sqlExe(sql, params)
if isinstance(data, dict):
result['total'] = data.get('total', 0)
result['data'] = [dict(r) for r in data.get('rows', [])]
else:
rows = [dict(r) for r in (data or [])]
result['data'] = rows
result['total'] = len(rows)
result['total'] = data.get('total', len(data.get('rows', [])))
result['data'] = data.get('rows', [])
elif isinstance(data, list):
result['total'] = len(data)
result['data'] = data
# If cache miss for specific customer, try real-time from Sage
if customer_id and not result['data']:
result['data'] = await _query_sage_balance(env, customer_id)
result['total'] = len(result['data'])
result['success'] = True
debug(f'Balance query: returned {result["total"]} records')
except Exception as e:
error(f'Balance query failed: {e}')
error(f'get_customer_balance error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
async def _query_sage_balance(env: ServerEnv, customer_id: str) -> list[dict]:
"""Fallback: query Sage acc_balance table directly."""
try:
async with get_sor_context(env, 'sage') as sor:
sql = """
SELECT customer_id, balance, currency, status, updated_at
FROM acc_balance
WHERE customer_id = ${customer_id}$
"""
rows = await sor.sqlExe(sql, {'customer_id': customer_id})
if isinstance(rows, list):
return rows
elif isinstance(rows, dict):
return rows.get('rows', [])
except Exception as e:
error(f'_query_sage_balance error: {e}')
return []
async def update_customer_balance(customer_id: str, balance: float) -> str:
"""Update customer balance in cache (called by sync or accounting)."""
result: dict[str, Any] = {'success': False}
try:
env = ServerEnv()
cache_dbname = env.get_module_dbname('sageapi')
sql = """
INSERT INTO customer_balance (id, balance, cached_at)
VALUES (${customer_id}$, ${balance}$, NOW())
ON DUPLICATE KEY UPDATE
balance = ${balance}$,
cached_at = NOW()
"""
async with DBPools().sqlorContext(cache_dbname) as sor:
await sor.sqlExe(sql, {
'customer_id': customer_id,
'balance': balance,
})
result['success'] = True
except Exception as e:
error(f'update_customer_balance error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,6 @@
"""Health check API handler.
Provides a simple endpoint for load balancer health checks and
system status monitoring. No authentication required.
Provides endpoints for service health and readiness checks.
"""
from __future__ import annotations
@ -10,51 +9,101 @@ import json
import time
from typing import Any
from appPublic.log import debug
from appPublic.log import debug, error
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
_START_TIME = time.time()
async def health_check() -> str:
"""Health check endpoint.
"""Basic health check - returns service status."""
uptime = time.time() - _START_TIME
Returns system status including database connectivity,
cache stats, and uptime information.
Returns:
JSON string with health status.
"""
result: dict[str, Any] = {
result = {
'status': 'ok',
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
'database': 'unknown',
'cache': {},
'service': 'sageapi',
'uptime_seconds': round(uptime, 1),
'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S%z'),
}
return json.dumps(result, ensure_ascii=False, default=str)
async def readiness_check() -> str:
"""Readiness check - verifies database connectivity."""
result: dict[str, Any] = {
'status': 'unknown',
'checks': {},
}
# Check database connectivity
# Check cache database connection
try:
from ahserver.serverenv import ServerEnv
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
if dbname:
async with DBPools().sqlorContext(dbname) as sor:
await sor.sqlExe('SELECT 1')
result['database'] = 'connected'
if not dbname:
result['checks']['cache_db'] = {
'status': 'fail',
'error': 'No database configured for sageapi module',
}
else:
result['database'] = 'not_configured'
result['status'] = 'degraded'
async with DBPools().sqlorContext(dbname) as sor:
rows = await sor.sqlExe('SELECT 1 as ping')
result['checks']['cache_db'] = {
'status': 'ok',
'dbname': dbname,
}
except Exception as e:
result['database'] = f'error: {str(e)}'
result['status'] = 'unhealthy'
error(f'readiness_check cache_db error: {e}')
result['checks']['cache_db'] = {
'status': 'fail',
'error': str(e),
}
# Cache stats
# Check Sage database connection
try:
from ..cache.cache_manager import _get_cache_manager
cm = _get_cache_manager()
result['cache'] = cm.stats()
except Exception:
result['cache'] = {'error': 'cache not initialized'}
from sqlor.dbpools import get_sor_context
async with get_sor_context(env, 'sage') as sor:
rows = await sor.sqlExe('SELECT 1 as ping')
result['checks']['sage_db'] = {'status': 'ok'}
except Exception as e:
error(f'readiness_check sage_db error: {e}')
result['checks']['sage_db'] = {
'status': 'fail',
'error': str(e),
}
# Check sync state
try:
async with DBPools().sqlorContext(dbname) as sor:
sql = """
SELECT entity_type, sync_status, last_sync_time
FROM sync_state
ORDER BY last_sync_time DESC
"""
rows = await sor.sqlExe(sql)
if isinstance(rows, list):
result['checks']['sync_status'] = {
'status': 'ok',
'entities': [
{
'entity_type': r.get('entity_type', ''),
'sync_status': r.get('sync_status', ''),
'last_sync_time': str(r.get('last_sync_time', '')),
}
for r in rows
],
}
except Exception as e:
result['checks']['sync_status'] = {
'status': 'fail',
'error': str(e),
}
# Overall status
all_ok = all(
check.get('status') == 'ok'
for check in result['checks'].values()
)
result['status'] = 'ready' if all_ok else 'degraded'
debug(f'Health check: status={result["status"]}')
return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,6 @@
"""Pricing query API handler.
"""Pricing API handler.
Provides the RESTful endpoint for querying pricing information.
Reads from the local pricing_cache table synced from Sage.
Provides endpoint for querying cached pricing data.
"""
from __future__ import annotations
@ -11,72 +10,110 @@ from typing import Any
from appPublic.log import debug, error
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
async def query_pricing(
program_id: str | None = None,
model: str | None = None,
limit: int = 200,
offset: int = 0,
ppid: str | None = None,
llmid: str | None = None,
pricing_type: str | None = None,
status: str | None = None,
page: int = 1,
page_size: int = 50,
) -> str:
"""Query pricing information from the local cache.
Args:
program_id: Filter by pricing program ID.
model: Filter by model name (partial match).
limit: Maximum number of records to return.
offset: Number of records to skip.
Returns:
JSON string with success flag and pricing data.
"""
"""Query pricing from cache with filters."""
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try:
from ahserver.serverenv import ServerEnv
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
if not dbname:
result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str)
conditions = []
params: dict[str, Any] = {'limit': limit, 'offset': offset}
params: dict[str, Any] = {}
if program_id:
conditions.append('program_id = ${program_id}$')
params['program_id'] = program_id
if model:
conditions.append('model LIKE ${model}$')
params['model'] = f'%{model}%'
if ppid:
conditions.append('id = ${ppid}$')
params['ppid'] = ppid
if llmid:
conditions.append('llmid = ${llmid}$')
params['llmid'] = llmid
if pricing_type:
conditions.append('pricing_type = ${pricing_type}$')
params['pricing_type'] = pricing_type
if status:
conditions.append('status = ${status}$')
params['status'] = status
else:
conditions.append("status = 'active'")
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
# Count query
count_sql = f'SELECT COUNT(*) as cnt FROM pricing_cache {where_clause}'
async with DBPools().sqlorContext(dbname) as sor:
count_rows = await sor.sqlExe(count_sql, params)
total = count_rows[0]['cnt'] if count_rows else 0
result['total'] = total
if total > 0:
count_sql = f"SELECT COUNT(*) as cnt FROM pricing_cache {where}"
offset = (page - 1) * page_size
data_sql = f"""
SELECT program_id, model, input_price, output_price,
unit, currency, updated_at
SELECT id, llmid, model_name, pricing_type,
input_price, output_price, unit_price,
currency, status, effective_from, effective_to,
cached_at
FROM pricing_cache
{where_clause}
ORDER BY program_id, model
LIMIT ${limit}$ OFFSET ${offset}$
{where}
ORDER BY model_name
LIMIT {page_size} OFFSET {offset}
"""
rows = await sor.sqlExe(data_sql, params)
result['data'] = [dict(r) for r in (rows or [])]
async with DBPools().sqlorContext(dbname) as sor:
count_result = await sor.sqlExe(count_sql, params)
if isinstance(count_result, list) and count_result:
result['total'] = count_result[0].get('cnt', 0)
elif isinstance(count_result, dict):
result['total'] = count_result.get('cnt', 0)
data = await sor.sqlExe(data_sql, params)
if isinstance(data, dict):
result['data'] = data.get('rows', [])
elif isinstance(data, list):
result['data'] = data
result['success'] = True
debug(f'Pricing query: returned {result["total"]} records')
except Exception as e:
error(f'Pricing query failed: {e}')
error(f'query_pricing error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
async def get_pricing_by_llmid(llmid: str) -> str:
"""Get all active pricing entries for a specific model."""
result: dict[str, Any] = {'success': False, 'data': []}
try:
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
sql = """
SELECT id, llmid, model_name, pricing_type,
input_price, output_price, unit_price,
currency, status, cached_at
FROM pricing_cache
WHERE llmid = ${llmid}$ AND status = 'active'
ORDER BY pricing_type
"""
async with DBPools().sqlorContext(dbname) as sor:
data = await sor.sqlExe(sql, {'llmid': llmid})
if isinstance(data, list):
result['data'] = data
result['success'] = True
elif isinstance(data, dict):
result['data'] = data.get('rows', [])
result['success'] = True
except Exception as e:
error(f'get_pricing_by_llmid error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -1,7 +1,6 @@
"""User query API handler.
"""Users API handler.
Provides the RESTful endpoint for querying user information.
Reads from the local users_cache table synced from Sage.
Provides endpoint for querying cached user data.
"""
from __future__ import annotations
@ -11,73 +10,91 @@ from typing import Any
from appPublic.log import debug, error
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
async def query_users(
user_id: str | None = None,
keyword: str | None = None,
limit: int = 100,
offset: int = 0,
) -> str:
"""Query user information from the local cache.
Args:
user_id: Filter by specific user ID.
keyword: Search keyword (matches username, email, or phone).
limit: Maximum number of records to return.
offset: Number of records to skip.
Returns:
JSON string with success flag and user data.
"""
async def query_users(keyword: str | None = None, orgid: str | None = None, page: int = 1, page_size: int = 50) -> str:
"""Query users from cache with keyword search."""
result: dict[str, Any] = {'success': False, 'data': [], 'total': 0}
try:
from ahserver.serverenv import ServerEnv
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
if not dbname:
result['error'] = 'No database configured for sageapi module'
return json.dumps(result, ensure_ascii=False, default=str)
conditions = []
params: dict[str, Any] = {'limit': limit, 'offset': offset}
params: dict[str, Any] = {}
if user_id:
conditions.append('user_id = ${user_id}$')
params['user_id'] = user_id
if keyword:
conditions.append(
'(username LIKE ${keyword}$ OR email LIKE ${keyword}$ OR phone LIKE ${keyword}$)'
)
conditions.append("username LIKE ${keyword}$")
params['keyword'] = f'%{keyword}%'
if orgid:
conditions.append('orgid = ${orgid}$')
params['orgid'] = orgid
where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
where = 'WHERE ' + ' AND '.join(conditions) if conditions else ''
# Count query
count_sql = f'SELECT COUNT(*) as cnt FROM users_cache {where_clause}'
async with DBPools().sqlorContext(dbname) as sor:
count_rows = await sor.sqlExe(count_sql, params)
total = count_rows[0]['cnt'] if count_rows else 0
result['total'] = total
if total > 0:
count_sql = f"SELECT COUNT(*) as cnt FROM users_cache {where}"
offset = (page - 1) * page_size
data_sql = f"""
SELECT user_id, username, email, phone, status, updated_at
SELECT id, username, orgid, orgname, email, phone,
status, created_at, updated_at, cached_at
FROM users_cache
{where_clause}
ORDER BY user_id
LIMIT ${limit}$ OFFSET ${offset}$
{where}
ORDER BY username
LIMIT {page_size} OFFSET {offset}
"""
rows = await sor.sqlExe(data_sql, params)
result['data'] = [dict(r) for r in (rows or [])]
async with DBPools().sqlorContext(dbname) as sor:
count_result = await sor.sqlExe(count_sql, params)
if isinstance(count_result, list) and count_result:
result['total'] = count_result[0].get('cnt', 0)
elif isinstance(count_result, dict):
result['total'] = count_result.get('cnt', 0)
data = await sor.sqlExe(data_sql, params)
if isinstance(data, dict):
result['data'] = data.get('rows', [])
elif isinstance(data, list):
result['data'] = data
result['success'] = True
debug(f'User query: returned {result["total"]} records')
except Exception as e:
error(f'User query failed: {e}')
error(f'query_users error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)
async def get_user_by_id(user_id: str) -> str:
"""Get a single user by ID."""
result: dict[str, Any] = {'success': False, 'data': None}
try:
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
sql = """
SELECT id, username, orgid, orgname, email, phone,
status, created_at, updated_at, cached_at
FROM users_cache
WHERE id = ${user_id}$
"""
async with DBPools().sqlorContext(dbname) as sor:
data = await sor.sqlExe(sql, {'user_id': user_id})
if isinstance(data, list) and data:
result['data'] = data[0]
result['success'] = True
elif isinstance(data, dict) and data.get('rows'):
result['data'] = data['rows'][0]
result['success'] = True
except Exception as e:
error(f'get_user_by_id error: {e}')
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)

Binary file not shown.

Binary file not shown.

View File

@ -2,40 +2,69 @@
Registers all public functions to ServerEnv so they are accessible
from dspy scripts and other modules via the global environment.
Also sets up route registration and database event bindings.
"""
from appPublic.log import debug
from appPublic.log import debug, info
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------
from .auth.dapi_auth import dapi_auth_middleware
from .auth.uapi_sign import uapi_sign_verify
from .middleware.dapi_auth import authenticate_request, DapiAuthMiddleware
# ---------------------------------------------------------------------------
# Sync
# ---------------------------------------------------------------------------
from .sync.base_sync import BaseSync
from .sync.user_sync import sync_users
from .sync.pricing_sync import sync_pricing
from .sync.uapi_sync import sync_uapi
from .sync.llmage_sync import sync_llmage
from .sync.base_sync import BaseSync, run_all_syncs
from .sync.user_sync import UserSync
from .sync.pricing_sync import PricingSync
from .sync.uapi_sync import UapiSync
from .sync.llmage_sync import LlmageSync
# Module-level convenience functions for sync
async def sync_users() -> str:
"""Convenience: run user sync."""
syncer = UserSync()
return await syncer.sync()
async def sync_pricing() -> str:
"""Convenience: run pricing sync."""
syncer = PricingSync()
return await syncer.sync()
async def sync_uapi() -> str:
"""Convenience: run uapi sync."""
syncer = UapiSync()
return await syncer.sync()
async def sync_llmage() -> str:
"""Convenience: run llmage sync."""
syncer = LlmageSync()
return await syncer.sync()
# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------
from .cache.cache_manager import CacheManager
# Global cache instance (per-process)
_cache_manager = CacheManager(max_entries=10000, default_ttl=300)
# ---------------------------------------------------------------------------
# API
# ---------------------------------------------------------------------------
from .api.balance import get_customer_balance
from .api.balance import get_customer_balance, update_customer_balance
from .api.accounting import create_accounting_record, query_accounting_records
from .api.users import query_users
from .api.pricing import query_pricing
from .api.health import health_check
from .api.users import query_users, get_user_by_id
from .api.pricing import query_pricing, get_pricing_by_llmid
from .api.health import health_check, readiness_check
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
from .router import Router, setup_routes
# ---------------------------------------------------------------------------
# Utils
@ -45,66 +74,79 @@ from .utils.crypto import encrypt_payload, decrypt_payload
def _bind_sageapi_events(dbpools: DBPools, dbname: str) -> None:
"""Bind database events to SageAPI cache invalidation handlers.
When sync state or accounting records change in the database,
the corresponding cache entries are invalidated automatically.
"""
"""Bind database events to SageAPI cache invalidation handlers."""
bindings = [
# sync_state table: clear sync-related caches on change
(f'{dbname}:sync_state:c:after', CacheManager.invalidate_sync_state),
(f'{dbname}:sync_state:u:after', CacheManager.invalidate_sync_state),
(f'{dbname}:sync_state:d:after', CacheManager.invalidate_sync_state),
# accounting_records: clear accounting cache on change
(f'{dbname}:accounting_records:c:after', CacheManager.invalidate_accounting),
(f'{dbname}:accounting_records:u:after', CacheManager.invalidate_accounting),
(f'{dbname}:accounting_records:d:after', CacheManager.invalidate_accounting),
(f'{dbname}:sync_state:c:after', _cache_manager.invalidate_sync_state),
(f'{dbname}:sync_state:u:after', _cache_manager.invalidate_sync_state),
(f'{dbname}:sync_state:d:after', _cache_manager.invalidate_sync_state),
(f'{dbname}:accounting_records:c:after', _cache_manager.invalidate_accounting),
(f'{dbname}:accounting_records:u:after', _cache_manager.invalidate_accounting),
(f'{dbname}:accounting_records:d:after', _cache_manager.invalidate_accounting),
]
for event_name, handler in bindings:
try:
dbpools.bind(event_name, handler)
debug(f'SageAPI event bound: {event_name}')
except Exception as e:
debug(f'SageAPI event bind skipped: {event_name} ({e})')
def load_sageapi() -> None:
"""Register all SageAPI functions into ServerEnv.
Called by the Sage server during module loading phase.
All registered functions become available as globals in dspy scripts.
"""
env = ServerEnv()
# Auth
env.dapi_auth_middleware = dapi_auth_middleware
env.uapi_sign_verify = uapi_sign_verify
env.authenticate_request = authenticate_request
env.DapiAuthMiddleware = DapiAuthMiddleware
# Sync
env.sync_users = sync_users
env.sync_pricing = sync_pricing
env.sync_uapi = sync_uapi
env.sync_llmage = sync_llmage
env.run_all_syncs = run_all_syncs
env.BaseSync = BaseSync
env.UserSync = UserSync
env.PricingSync = PricingSync
env.UapiSync = UapiSync
env.LlmageSync = LlmageSync
# Cache
env.cache_manager = CacheManager()
env.cache_manager = _cache_manager
# API
env.get_customer_balance = get_customer_balance
env.update_customer_balance = update_customer_balance
env.create_accounting_record = create_accounting_record
env.query_accounting_records = query_accounting_records
env.query_users = query_users
env.get_user_by_id = get_user_by_id
env.query_pricing = query_pricing
env.get_pricing_by_llmid = get_pricing_by_llmid
env.health_check = health_check
env.readiness_check = readiness_check
# Router
router = Router()
setup_routes(router)
env.sageapi_router = router
info(f'SageAPI: {len(router.get_routes())} routes registered')
# Utils
env.SageHttpClient = SageHttpClient
env.encrypt_payload = encrypt_payload
env.decrypt_payload = decrypt_payload
# Bind database events for automatic cache invalidation
# Bind database events
dbpools = DBPools()
dbname = env.get_module_dbname('sageapi')
if dbname:
_bind_sageapi_events(dbpools, dbname)
debug(f'SageAPI event listeners bound for database: {dbname}')
info(f'SageAPI: event listeners bound for database: {dbname}')
else:
debug('SageAPI event listeners skipped: no database configured for sageapi module')
debug('SageAPI: event listeners skipped (no database configured)')
info('SageAPI module loaded successfully')

View File

@ -0,0 +1 @@
# Middleware package

View File

@ -0,0 +1,254 @@
"""
DAPI Authentication Middleware for sageapi.
Authenticates incoming requests using DAPI signature headers by querying
the Sage downapikey table.
Usage with FastAPI / Starlette:
from sageapi.middleware.dapi_auth import DapiAuthMiddleware
app.add_middleware(DapiAuthMiddleware)
# Or for specific routes, use the get_dapi_key() dependency.
"""
import hashlib
import hmac
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional
from starlette.datastructures import Headers
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse
# Headers expected from the client
HEADER_API_KEY = "X-DAPI-Key"
HEADER_TIMESTAMP = "X-DAPI-Timestamp"
HEADER_SIGNATURE = "X-DAPI-Signature"
# Default timestamp tolerance: 5 minutes
DEFAULT_TIMESTAMP_TOLERANCE_SEC = 300
@dataclass
class DapiKeyRecord:
"""Represents a record from the downapikey table."""
id: Any
apikey: str
secret: str
status: str
expire_date: Any # datetime or string
description: str = ""
@property
def is_active(self) -> bool:
return self.status == "active"
@property
def is_expired(self) -> bool:
"""Check if the key has expired based on expire_date."""
from datetime import datetime, timezone
if self.expire_date is None:
return False
# Handle both datetime objects and ISO string
if isinstance(self.expire_date, str):
try:
expire_dt = datetime.fromisoformat(self.expire_date.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return False
else:
expire_dt = self.expire_date
# Ensure expire_dt is timezone-aware (UTC)
if expire_dt.tzinfo is None:
expire_dt = expire_dt.replace(tzinfo=timezone.utc)
now = datetime.now(timezone.utc)
return now > expire_dt
class DapiAuthError(Exception):
"""Raised when DAPI authentication fails."""
def __init__(self, status_code: int, detail: str):
self.status_code = status_code
self.detail = detail
super().__init__(detail)
def compute_dapi_signature(method: str, path: str, timestamp: str, secret: str, body: Optional[bytes] = None) -> str:
"""
Compute the DAPI HMAC-SHA256 signature.
The signed string is: "{method}\n{path}\n{timestamp}\n{body_hash}"
where body_hash is SHA-256 hex digest of the request body (or empty string if no body).
"""
if body:
body_hash = hashlib.sha256(body).hexdigest()
else:
body_hash = ""
string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}"
signature = hmac.new(
secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
hashlib.sha256,
).hexdigest()
return signature
def verify_timestamp(timestamp_str: str, tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC) -> bool:
"""Verify that the timestamp is within the allowed window."""
try:
ts = float(timestamp_str)
except (ValueError, TypeError):
return False
now = time.time()
return abs(now - ts) <= tolerance_sec
async def lookup_api_key(api_key: str) -> Optional[DapiKeyRecord]:
"""
Look up an API key in the Sage downapikey table.
Returns a DapiKeyRecord if found, None otherwise.
"""
from ahserver.serverenv import ServerEnv
from sqlor.dbpools import get_sor_context
env = ServerEnv()
async with get_sor_context(env, "dapi") as sor:
recs = await sor.R("downapikey", {"apikey": api_key})
if not recs:
return None
rec = recs[0] # Take the first match
return DapiKeyRecord(
id=rec.get("id"),
apikey=rec.get("apikey", ""),
secret=rec.get("secret", ""),
status=rec.get("status", "inactive"),
expire_date=rec.get("expire_date"),
description=rec.get("description", ""),
)
async def authenticate_request(
request: Request,
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
) -> DapiKeyRecord:
"""
Authenticate a request using DAPI headers.
Returns the DapiKeyRecord on success.
Raises DapiAuthError on failure.
"""
headers = request.headers
# 1. Check required headers
api_key = headers.get(HEADER_API_KEY)
if not api_key:
raise DapiAuthError(401, f"Missing header: {HEADER_API_KEY}")
timestamp_str = headers.get(HEADER_TIMESTAMP)
if not timestamp_str:
raise DapiAuthError(401, f"Missing header: {HEADER_TIMESTAMP}")
signature = headers.get(HEADER_SIGNATURE)
if not signature:
raise DapiAuthError(401, f"Missing header: {HEADER_SIGNATURE}")
# 2. Validate timestamp window
if not verify_timestamp(timestamp_str, tolerance_sec):
raise DapiAuthError(401, "Request timestamp is outside the allowed window")
# 3. Look up the API key in the database
key_record = await lookup_api_key(api_key)
if key_record is None:
raise DapiAuthError(401, "Invalid API key")
# 4. Check key status
if not key_record.is_active:
raise DapiAuthError(403, "API key is inactive")
# 5. Check expiration
if key_record.is_expired:
raise DapiAuthError(403, "API key has expired")
# 6. Read request body for signature verification
body = await request.body()
# 7. Compute expected signature and compare
expected_signature = compute_dapi_signature(
method=request.method,
path=request.url.path,
timestamp=timestamp_str,
secret=key_record.secret,
body=body if body else None,
)
if not hmac.compare_digest(signature, expected_signature):
raise DapiAuthError(401, "Invalid signature")
return key_record
class DapiAuthMiddleware(BaseHTTPMiddleware):
"""
Starlette/FastAPI middleware that enforces DAPI authentication on all requests.
Attributes:
exclude_paths: List of path prefixes to skip authentication (e.g., ['/health', '/docs']).
tolerance_sec: Timestamp tolerance in seconds (default: 300).
"""
def __init__(
self,
app: Any,
exclude_paths: Optional[list[str]] = None,
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
):
super().__init__(app)
self.exclude_paths = exclude_paths or []
self.tolerance_sec = tolerance_sec
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> JSONResponse:
# Skip excluded paths
for prefix in self.exclude_paths:
if request.url.path.startswith(prefix):
return await call_next(request)
try:
key_record = await authenticate_request(request, self.tolerance_sec)
# Attach the key record to request state for downstream use
request.state.dapi_key = key_record
request.state.dapi_key_id = key_record.id
except DapiAuthError as e:
return JSONResponse(
status_code=e.status_code,
content={"error": e.detail},
)
return await call_next(request)
def requires_dapi_auth(
request: Request,
tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC,
) -> Awaitable[DapiKeyRecord]:
"""
Dependency function for FastAPI route-level DAPI authentication.
Usage:
@app.get("/protected")
async def protected_route(key: DapiKeyRecord = Depends(requires_dapi_auth)):
...
"""
return authenticate_request(request, tolerance_sec)

View File

@ -29,10 +29,10 @@ class Router:
Args:
method: HTTP method (GET, POST, PUT, DELETE).
path: URL path pattern, e.g. '/api/v1/balance'.
path: URL path pattern.
handler: Callable that handles the request.
auth: Authentication method ('dapi', 'uapi', 'none').
description: Human-readable description of the endpoint.
description: Human-readable description.
"""
self._routes.append({
'method': method.upper(),
@ -55,64 +55,75 @@ class Router:
return None
# Global router instance
router = Router()
def setup_routes(router: Router) -> None:
"""Register all SageAPI routes.
Health endpoints (no auth):
GET /api/v1/health
GET /api/v1/health/ready
def register_routes() -> None:
"""Register all SageAPI API routes.
Balance endpoints (dapi auth):
GET /api/v1/balance
POST /api/v1/balance/update
Called during module initialization to populate the router
with all available endpoints.
Accounting endpoints (dapi auth):
POST /api/v1/accounting
GET /api/v1/accounting
Users endpoints (dapi auth):
GET /api/v1/users
GET /api/v1/users/{user_id}
Pricing endpoints (dapi auth):
GET /api/v1/pricing
GET /api/v1/pricing/model/{llmid}
"""
from .api.health import health_check
from .api.balance import get_customer_balance
from .api.accounting import create_accounting_record, query_accounting_records
from .api.users import query_users
from .api.pricing import query_pricing
# Health (no auth)
from sageapi.api.health import health_check, readiness_check
router.register('GET', '/api/v1/health', health_check, auth='none', description='Health check')
router.register('GET', '/api/v1/health/ready', readiness_check, auth='none', description='Readiness check')
# Health check (no auth required)
router.register(
'GET', '/api/v1/health',
handler=health_check,
auth='none',
description='Health check endpoint',
)
# Customer balance
router.register(
'GET', '/api/v1/balance',
handler=get_customer_balance,
auth='dapi',
description='Query customer balance',
)
# Balance
from sageapi.api.balance import get_customer_balance, update_customer_balance
router.register('GET', '/api/v1/balance', get_customer_balance, auth='dapi', description='Query customer balance')
router.register('POST', '/api/v1/balance/update', update_customer_balance, auth='dapi', description='Update customer balance')
# Accounting
router.register(
'POST', '/api/v1/accounting',
handler=create_accounting_record,
auth='dapi',
description='Create an accounting record',
)
router.register(
'GET', '/api/v1/accounting',
handler=query_accounting_records,
auth='dapi',
description='Query accounting records',
)
from sageapi.api.accounting import create_accounting_record, query_accounting_records
router.register('POST', '/api/v1/accounting', create_accounting_record, auth='dapi', description='Create accounting record')
router.register('GET', '/api/v1/accounting', query_accounting_records, auth='dapi', description='Query accounting records')
# Users
router.register(
'GET', '/api/v1/users',
handler=query_users,
auth='dapi',
description='Query user information',
)
from sageapi.api.users import query_users, get_user_by_id
router.register('GET', '/api/v1/users', query_users, auth='dapi', description='Query users')
router.register('GET', '/api/v1/users/detail', get_user_by_id, auth='dapi', description='Get user by ID')
# Pricing
router.register(
'GET', '/api/v1/pricing',
handler=query_pricing,
auth='dapi',
description='Query pricing information',
)
from sageapi.api.pricing import query_pricing, get_pricing_by_llmid
router.register('GET', '/api/v1/pricing', query_pricing, auth='dapi', description='Query pricing')
router.register('GET', '/api/v1/pricing/model', get_pricing_by_llmid, auth='dapi', description='Get pricing by model ID')
# Admin (dapi auth with admin role)
from sageapi.sync.base_sync import run_all_syncs
router.register('POST', '/api/v1/admin/sync', run_all_syncs, auth='dapi', description='Trigger full sync')
router.register('GET', '/api/v1/admin/sync/status', _sync_status, auth='dapi', description='Sync status')
async def _sync_status() -> str:
"""Return current sync status for all entities."""
import json
from sqlor.dbpools import DBPools
from ahserver.serverenv import ServerEnv
result = {'success': False, 'data': []}
try:
env = ServerEnv()
dbname = env.get_module_dbname('sageapi')
sql = "SELECT entity_type, sync_status, last_sync_time, error_msg FROM sync_state ORDER BY entity_type"
async with DBPools().sqlorContext(dbname) as sor:
rows = await sor.sqlExe(sql)
result['data'] = rows if isinstance(rows, list) else rows.get('rows', [])
result['success'] = True
except Exception as e:
result['error'] = str(e)
return json.dumps(result, ensure_ascii=False, default=str)

View File

@ -0,0 +1,16 @@
"""SageAPI sync engine package."""
from .base_sync import BaseSync, run_all_syncs
from .user_sync import UserSync
from .pricing_sync import PricingSync
from .uapi_sync import UapiSync
from .llmage_sync import LlmageSync
__all__ = [
'BaseSync',
'run_all_syncs',
'UserSync',
'PricingSync',
'UapiSync',
'LlmageSync',
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,125 +1,248 @@
"""Base synchronization class for SageAPI.
Provides the foundation for all data sync workers. Handles common
concerns: checkpoint management, retry logic, batch processing,
and error reporting.
"""
BaseSync - Abstract base class for all Sage data sync modules.
from __future__ import annotations
Provides:
- Checkpoint management (read/write sync_state table)
- Batch processing with configurable size
- Retry logic with exponential backoff
- Common sync flow orchestration
"""
import time
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, List, Optional
from appPublic.log import debug, error, info
from sqlor.dbpools import get_sor_context
from ahserver.serverenv import ServerEnv
logger = logging.getLogger(__name__)
class BaseSync(ABC):
"""Abstract base class for data synchronization workers.
"""
Abstract base class for incremental sync from Sage DB to local cache.
Each concrete sync subclass implements the data fetch and
persist logic for a specific upstream data source.
Subclasses must implement:
- fetch_incremental(sor, since_timestamp): fetch delta from Sage DB
- persist(sor, records): UPSERT records into local cache table
- get_latest_timestamp(records): extract max updated_at from records
"""
def __init__(self, sync_name: str, batch_size: int = 500) -> None:
self.sync_name = sync_name
self.batch_size = batch_size
self._last_checkpoint: dict[str, Any] = {}
# --- subclass overrides ---
MODULE_NAME: str = "" # Sage module name (e.g. 'users', 'pricing')
SOURCE_DBNAME: str = "sage" # db alias for Sage source DB
CACHE_DBNAME: str = "sageapi" # db alias for local cache DB
BATCH_SIZE: int = 500 # batch size for persist
MAX_RETRIES: int = 3 # max retry attempts per batch
RETRY_DELAY: float = 1.0 # initial retry delay in seconds
@abstractmethod
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch incremental data from the upstream source.
# sync_state table key
STATE_KEY: str = ""
Args:
since_timestamp: Only fetch records modified after this timestamp.
None means full sync.
def __init__(self, env: Optional[ServerEnv] = None):
self.env = env or ServerEnv()
Returns:
List of records to be persisted.
# ------------------------------------------------------------------ #
# Checkpoint helpers sync_state table lives in the CACHE DB #
# ------------------------------------------------------------------ #
async def _read_checkpoint(self) -> Optional[str]:
"""Read last sync timestamp from sync_state table."""
async with get_sor_context(self.env, self.CACHE_DBNAME) as sor:
recs = await sor.R('sync_state', {'state_key': self.STATE_KEY})
if recs and len(recs) > 0:
return recs[0].get('last_sync_ts')
return None
async def _write_checkpoint(self, timestamp: str) -> None:
"""Write new sync timestamp into sync_state table."""
async with get_sor_context(self.env, self.CACHE_DBNAME) as sor:
now_ts = str(int(time.time()))
existing = await sor.R('sync_state', {'state_key': self.STATE_KEY})
if existing and len(existing) > 0:
await sor.U('sync_state', {
'state_key': self.STATE_KEY,
'last_sync_ts': timestamp,
'updated_at': now_ts,
})
else:
await sor.C('sync_state', {
'state_key': self.STATE_KEY,
'last_sync_ts': timestamp,
'created_at': now_ts,
'updated_at': now_ts,
})
# ------------------------------------------------------------------ #
# Retry wrapper #
# ------------------------------------------------------------------ #
async def _with_retry(self, coro_func, *args, **kwargs) -> Any:
"""Execute an async function with exponential-backoff retry."""
last_exc = None
delay = self.RETRY_DELAY
for attempt in range(1, self.MAX_RETRIES + 1):
try:
return await coro_func(*args, **kwargs)
except Exception as e:
last_exc = e
logger.warning(
"[%s] attempt %d/%d failed: %s",
self.__class__.__name__, attempt, self.MAX_RETRIES, e,
)
if attempt < self.MAX_RETRIES:
await self._sleep(delay)
delay *= 2
raise last_exc
@staticmethod
async def _sleep(seconds: float) -> None:
import asyncio
await asyncio.sleep(seconds)
# ------------------------------------------------------------------ #
# Batch persist #
# ------------------------------------------------------------------ #
async def _persist_batch(self, sor, records: List[Dict]) -> int:
"""Persist a single batch of records with retry."""
async def _do():
await self.persist(sor, records)
await self._with_retry(_do)
return len(records)
async def persist_in_batches(self, sor, records: List[Dict]) -> int:
"""Split records into batches and persist each with retry."""
total = 0
for i in range(0, len(records), self.BATCH_SIZE):
batch = records[i:i + self.BATCH_SIZE]
cnt = await self._persist_batch(sor, batch)
total += cnt
logger.info(
"[%s] persisted batch %d/%d (%d records)",
self.__class__.__name__,
i // self.BATCH_SIZE + 1,
(len(records) + self.BATCH_SIZE - 1) // self.BATCH_SIZE,
cnt,
)
return total
# ------------------------------------------------------------------ #
# Main sync flow #
# ------------------------------------------------------------------ #
async def sync(self) -> Dict[str, Any]:
"""
...
@abstractmethod
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Persist fetched records to the local database.
Args:
records: List of records to upsert.
Returns:
Number of records successfully persisted.
Full incremental sync flow:
1. Read checkpoint (last sync timestamp)
2. Fetch incremental records from Sage DB
3. Persist to local cache in batches
4. Update checkpoint
5. Return summary dict
"""
...
cls_name = self.__class__.__name__
logger.info("[%s] sync started", cls_name)
@abstractmethod
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
"""Extract the latest modification timestamp from a batch of records.
# 1. Read checkpoint
since_ts = await self._read_checkpoint()
logger.info("[%s] checkpoint: %s", cls_name, since_ts or "None (full sync)")
Used to advance the sync checkpoint after successful persist.
"""
...
# 2. Fetch incremental from Sage source DB
async with get_sor_context(self.env, self.SOURCE_DBNAME) as sor:
records = await self.fetch_incremental(sor, since_ts)
async def _load_checkpoint(self) -> str | None:
"""Load the last successful sync checkpoint timestamp.
TODO: Implement checkpoint persistence (sync_state table).
"""
checkpoint = self._last_checkpoint.get(self.sync_name)
debug(f'Sync {self.sync_name}: loaded checkpoint = {checkpoint}')
return checkpoint
async def _save_checkpoint(self, timestamp: str) -> None:
"""Save the sync checkpoint after a successful run.
TODO: Implement checkpoint persistence (sync_state table).
"""
self._last_checkpoint[self.sync_name] = timestamp
debug(f'Sync {self.sync_name}: saved checkpoint = {timestamp}')
async def run(self) -> dict[str, Any]:
"""Execute a full sync cycle.
Returns:
dict with keys: success, records_fetched, records_persisted,
error (if any), duration_seconds
"""
start = time.time()
result: dict[str, Any] = {
'sync_name': self.sync_name,
'success': False,
'records_fetched': 0,
'records_persisted': 0,
'error': None,
'duration_seconds': 0.0,
if not records:
logger.info("[%s] no new records, sync done", cls_name)
return {
'module': self.MODULE_NAME,
'fetched': 0,
'persisted': 0,
'new_checkpoint': since_ts,
}
try:
checkpoint = await self._load_checkpoint()
info(f'Sync {self.sync_name}: starting (checkpoint={checkpoint})')
records = await self.fetch_incremental(since_timestamp=checkpoint)
result['records_fetched'] = len(records)
if records:
persisted = await self.persist(records)
result['records_persisted'] = persisted
latest_ts = self.get_latest_timestamp(records)
if latest_ts:
await self._save_checkpoint(latest_ts)
result['success'] = True
info(
f'Sync {self.sync_name}: completed — '
f'fetched={result["records_fetched"]}, '
f'persisted={result["records_persisted"]}'
# 3. Extract latest timestamp
new_checkpoint = self.get_latest_timestamp(records)
logger.info(
"[%s] fetched %d records, latest_ts=%s",
cls_name, len(records), new_checkpoint,
)
except Exception as e:
error(f'Sync {self.sync_name}: failed with error: {e}')
result['error'] = str(e)
# 4. Persist to cache DB
async with get_sor_context(self.env, self.CACHE_DBNAME) as cache_sor:
persisted = await self.persist_in_batches(cache_sor, records)
finally:
result['duration_seconds'] = round(time.time() - start, 3)
# 5. Update checkpoint
if new_checkpoint:
await self._write_checkpoint(new_checkpoint)
result = {
'module': self.MODULE_NAME,
'fetched': len(records),
'persisted': persisted,
'new_checkpoint': new_checkpoint,
}
logger.info("[%s] sync completed: %s", cls_name, result)
return result
# ------------------------------------------------------------------ #
# Abstract methods subclasses MUST implement #
# ------------------------------------------------------------------ #
@abstractmethod
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
"""
Fetch incremental records from Sage DB since the given timestamp.
Args:
sor: sqlor context for the SOURCE DB
since_timestamp: ISO or unix timestamp string, or None for full sync
Returns:
List of record dicts
"""
...
@abstractmethod
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT records into the local cache table.
Args:
sor: sqlor context for the CACHE DB
records: list of dicts to upsert
"""
...
@abstractmethod
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""
Extract the maximum updated/modified timestamp from records.
Args:
records: list of record dicts
Returns:
Timestamp string or None if no records
"""
...
async def run_all_syncs() -> str:
"""Trigger sync for all entity types. Called by admin API or cron."""
import json
from sageapi.sync.user_sync import UserSync
from sageapi.sync.pricing_sync import PricingSync
from sageapi.sync.uapi_sync import UapiSync
from sageapi.sync.llmage_sync import LlmageSync
results = []
for sync_cls in [UserSync, PricingSync, UapiSync, LlmageSync]:
try:
syncer = sync_cls()
result = await syncer.sync()
results.append({'module': sync_cls.MODULE_NAME, 'status': 'ok', **result})
except Exception as e:
logger.exception("[%s] sync failed", sync_cls.__name__)
results.append({'module': sync_cls.MODULE_NAME, 'status': 'error', 'error': str(e)})
return json.dumps({'success': True, 'results': results}, ensure_ascii=False, default=str)

View File

@ -1,65 +1,142 @@
"""LLM image data synchronization for SageAPI.
Syncs LLM catalog and provider data from the upstream Sage
llmage module into the local llmage_cache table.
"""
LlmageSync - Sync llm / llmcatelog / llm_api_map from Sage DB to llmage_cache.
from __future__ import annotations
Source tables (Sage DB):
- llm
- llmcatelog
- llm_api_map
from typing import Any
Target table (cache DB):
- llmage_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class LlmageSync(BaseSync):
"""Incremental sync for llmage data from Sage upstream."""
MODULE_NAME = "llmage"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_llmage"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None:
super().__init__(sync_name='llmage', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch llmage data updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/llmage endpoint.
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
"""
debug(f'LlmageSync: fetching incremental data since {since_timestamp}')
# Placeholder: call upstream Sage API
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert llmage records into llmage_cache table.
TODO: Implement database upsert logic.
Fetch incremental data from llm, llmcatelog, and llm_api_map tables.
Joins LLM model info with catalog and API mapping data.
"""
if not records:
return 0
info(f'LlmageSync: persisting {len(records)} llmage records')
# Placeholder: upsert into llmage_cache
return len(records)
if since_timestamp:
where_clause = f"WHERE l.updated_at > '{since_timestamp}' OR lc.updated_at > '{since_timestamp}' OR lam.updated_at > '{since_timestamp}'"
else:
where_clause = ""
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
"""Extract the maximum updated_at from the record batch."""
sql = f"""
SELECT
l.id AS llm_id,
l.model_name,
l.model_version,
l.provider,
l.model_type,
l.status AS llm_status,
l.description AS llm_description,
l.created_at AS llm_created_at,
l.updated_at AS llm_updated_at,
lc.id AS catelog_id,
lc.catelog_name,
lc.catelog_code,
lc.sort_order AS catelog_sort,
lc.status AS catelog_status,
lc.updated_at AS catelog_updated_at,
lam.id AS api_map_id,
lam.api_name,
lam.api_endpoint,
lam.api_version,
lam.auth_type,
lam.rate_limit,
lam.status AS api_map_status,
lam.updated_at AS api_map_updated_at
FROM {sor.dbname}.llm l
LEFT JOIN {sor.dbname}.llmcatelog lc ON l.catelog_id = lc.id
LEFT JOIN {sor.dbname}.llm_api_map lam ON l.id = lam.llm_id
{where_clause}
ORDER BY COALESCE(l.updated_at, l.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into llmage_cache using INSERT ... ON DUPLICATE KEY UPDATE.
The composite key is (llm_id, catelog_id, api_map_id).
"""
import time
synced_at = str(int(time.time()))
for rec in records:
rec['synced_at'] = synced_at
insert_sql = """
INSERT INTO llmage_cache (
llm_id, model_name, model_version, provider, model_type,
llm_status, llm_description, llm_created_at, llm_updated_at,
catelog_id, catelog_name, catelog_code, catelog_sort,
catelog_status, catelog_updated_at,
api_map_id, api_name, api_endpoint, api_version,
auth_type, rate_limit, api_map_status, api_map_updated_at,
synced_at
) VALUES (
${llm_id}$, ${model_name}$, ${model_version}$, ${provider}$, ${model_type}$,
${llm_status}$, ${llm_description}$, ${llm_created_at}$, ${llm_updated_at}$,
${catelog_id}$, ${catelog_name}$, ${catelog_code}$, ${catelog_sort}$,
${catelog_status}$, ${catelog_updated_at}$,
${api_map_id}$, ${api_name}$, ${api_endpoint}$, ${api_version}$,
${auth_type}$, ${rate_limit}$, ${api_map_status}$, ${api_map_updated_at}$,
${synced_at}$
)
ON DUPLICATE KEY UPDATE
model_name = VALUES(model_name),
model_version = VALUES(model_version),
provider = VALUES(provider),
model_type = VALUES(model_type),
llm_status = VALUES(llm_status),
llm_description = VALUES(llm_description),
llm_created_at = VALUES(llm_created_at),
llm_updated_at = VALUES(llm_updated_at),
catelog_name = VALUES(catelog_name),
catelog_code = VALUES(catelog_code),
catelog_sort = VALUES(catelog_sort),
catelog_status = VALUES(catelog_status),
catelog_updated_at= VALUES(catelog_updated_at),
api_name = VALUES(api_name),
api_endpoint = VALUES(api_endpoint),
api_version = VALUES(api_version),
auth_type = VALUES(auth_type),
rate_limit = VALUES(rate_limit),
api_map_status = VALUES(api_map_status),
api_map_updated_at= VALUES(api_map_updated_at),
synced_at = VALUES(synced_at)
"""
try:
await sor.execute(insert_sql, rec)
except Exception as e:
logger.warning(
"[%s] persist failed for llm_id=%s: %s",
self.__class__.__name__, rec.get('llm_id'), e,
)
raise
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from llm, catelog, or api_map records."""
if not records:
return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
return max(timestamps) if timestamps else None
_llmage_sync_instance: LlmageSync | None = None
def get_llmage_sync() -> LlmageSync:
"""Get or create the LlmageSync singleton."""
global _llmage_sync_instance
if _llmage_sync_instance is None:
_llmage_sync_instance = LlmageSync()
return _llmage_sync_instance
async def sync_llmage(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a llmage data sync cycle."""
syncer = get_llmage_sync()
if since_timestamp:
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()
latest = None
for r in records:
for key in ('llm_updated_at', 'catelog_updated_at', 'api_map_updated_at'):
ts = r.get(key)
if ts and (latest is None or str(ts) > str(latest)):
latest = str(ts)
return latest

View File

@ -1,65 +1,123 @@
"""Pricing data synchronization for SageAPI.
Syncs pricing program and timing data from the upstream Sage
pricing module into the local pricing_cache table.
"""
PricingSync - Sync pricing_program / pricing_program_timing from Sage DB to pricing_cache.
from __future__ import annotations
Source tables (Sage DB):
- pricing_program
- pricing_program_timing
from typing import Any
Target table (cache DB):
- pricing_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class PricingSync(BaseSync):
"""Incremental sync for pricing data from Sage upstream."""
MODULE_NAME = "pricing"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_pricing"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None:
super().__init__(sync_name='pricing', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch pricing data updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/pricing endpoint.
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
"""
debug(f'PricingSync: fetching incremental data since {since_timestamp}')
# Placeholder: call upstream Sage API
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert pricing records into pricing_cache table.
TODO: Implement database upsert logic.
Fetch incremental data from pricing_program and pricing_program_timing.
Joins program info with timing/schedule data.
"""
if not records:
return 0
info(f'PricingSync: persisting {len(records)} pricing records')
# Placeholder: upsert into pricing_cache
return len(records)
if since_timestamp:
where_clause = f"WHERE pp.updated_at > '{since_timestamp}' OR ppt.updated_at > '{since_timestamp}'"
else:
where_clause = ""
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
"""Extract the maximum updated_at from the record batch."""
sql = f"""
SELECT
pp.id AS program_id,
pp.program_name,
pp.program_code,
pp.program_type,
pp.status AS program_status,
pp.description,
pp.created_at AS program_created_at,
pp.updated_at AS program_updated_at,
ppt.id AS timing_id,
ppt.start_time,
ppt.end_time,
ppt.duration,
ppt.repeat_rule,
ppt.timezone,
ppt.status AS timing_status,
ppt.updated_at AS timing_updated_at
FROM {sor.dbname}.pricing_program pp
LEFT JOIN {sor.dbname}.pricing_program_timing ppt
ON pp.id = ppt.program_id
{where_clause}
ORDER BY COALESCE(pp.updated_at, pp.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into pricing_cache using INSERT ... ON DUPLICATE KEY UPDATE.
The composite key is (program_id, timing_id).
"""
import time
synced_at = str(int(time.time()))
for rec in records:
rec['synced_at'] = synced_at
insert_sql = """
INSERT INTO pricing_cache (
program_id, program_name, program_code, program_type,
program_status, description, program_created_at, program_updated_at,
timing_id, start_time, end_time, duration,
repeat_rule, timezone, timing_status, timing_updated_at,
synced_at
) VALUES (
${program_id}$, ${program_name}$, ${program_code}$, ${program_type}$,
${program_status}$, ${description}$, ${program_created_at}$, ${program_updated_at}$,
${timing_id}$, ${start_time}$, ${end_time}$, ${duration}$,
${repeat_rule}$, ${timezone}$, ${timing_status}$, ${timing_updated_at}$,
${synced_at}$
)
ON DUPLICATE KEY UPDATE
program_name = VALUES(program_name),
program_code = VALUES(program_code),
program_type = VALUES(program_type),
program_status = VALUES(program_status),
description = VALUES(description),
program_created_at = VALUES(program_created_at),
program_updated_at = VALUES(program_updated_at),
start_time = VALUES(start_time),
end_time = VALUES(end_time),
duration = VALUES(duration),
repeat_rule = VALUES(repeat_rule),
timezone = VALUES(timezone),
timing_status = VALUES(timing_status),
timing_updated_at = VALUES(timing_updated_at),
synced_at = VALUES(synced_at)
"""
try:
await sor.execute(insert_sql, rec)
except Exception as e:
logger.warning(
"[%s] persist failed for program_id=%s: %s",
self.__class__.__name__, rec.get('program_id'), e,
)
raise
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from program or timing records."""
if not records:
return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
return max(timestamps) if timestamps else None
_pricing_sync_instance: PricingSync | None = None
def get_pricing_sync() -> PricingSync:
"""Get or create the PricingSync singleton."""
global _pricing_sync_instance
if _pricing_sync_instance is None:
_pricing_sync_instance = PricingSync()
return _pricing_sync_instance
async def sync_pricing(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a pricing data sync cycle."""
syncer = get_pricing_sync()
if since_timestamp:
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()
latest = None
for r in records:
ts = r.get('program_updated_at') or r.get('timing_updated_at')
if ts and (latest is None or str(ts) > str(latest)):
latest = str(ts)
return latest

View File

@ -1,65 +1,129 @@
"""UAPI data synchronization for SageAPI.
Syncs uapi application and caller configuration from the upstream
Sage uapi module into the local uapi_cache table.
"""
UAPISync - Sync uapi / upapp from Sage DB to uapi_cache.
from __future__ import annotations
Source tables (Sage DB):
- uapi
- upapp
from typing import Any
Target table (cache DB):
- uapi_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class UAPISync(BaseSync):
"""Incremental sync for uapi data from Sage upstream."""
MODULE_NAME = "uapi"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_uapi"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None:
super().__init__(sync_name='uapi', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch uapi data updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/uapi endpoint.
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
"""
debug(f'UAPISync: fetching incremental data since {since_timestamp}')
# Placeholder: call upstream Sage API
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert uapi records into uapi_cache table.
TODO: Implement database upsert logic.
Fetch incremental data from uapi and upapp tables.
Joins API definitions with app registration data.
"""
if not records:
return 0
info(f'UAPISync: persisting {len(records)} uapi records')
# Placeholder: upsert into uapi_cache
return len(records)
if since_timestamp:
where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR up.updated_at > '{since_timestamp}'"
else:
where_clause = ""
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
"""Extract the maximum updated_at from the record batch."""
sql = f"""
SELECT
u.id AS uapi_id,
u.api_name,
u.api_path,
u.api_method,
u.api_version,
u.api_desc,
u.status AS uapi_status,
u.auth_required,
u.created_at AS uapi_created_at,
u.updated_at AS uapi_updated_at,
up.id AS upapp_id,
up.app_name,
up.app_code,
up.app_type,
up.app_desc,
up.app_owner,
up.status AS upapp_status,
up.updated_at AS upapp_updated_at
FROM {sor.dbname}.uapi u
LEFT JOIN {sor.dbname}.upapp up ON u.upapp_id = up.id
{where_clause}
ORDER BY COALESCE(u.updated_at, u.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into uapi_cache using INSERT ... ON DUPLICATE KEY UPDATE.
The composite key is (uapi_id, upapp_id).
"""
import time
synced_at = str(int(time.time()))
for rec in records:
rec['synced_at'] = synced_at
insert_sql = """
INSERT INTO uapi_cache (
uapi_id, api_name, api_path, api_method, api_version,
api_desc, uapi_status, auth_required,
uapi_created_at, uapi_updated_at,
upapp_id, app_name, app_code, app_type,
app_desc, app_owner, upapp_status, upapp_updated_at,
synced_at
) VALUES (
${uapi_id}$, ${api_name}$, ${api_path}$, ${api_method}$, ${api_version}$,
${api_desc}$, ${uapi_status}$, ${auth_required}$,
${uapi_created_at}$, ${uapi_updated_at}$,
${upapp_id}$, ${app_name}$, ${app_code}$, ${app_type}$,
${app_desc}$, ${app_owner}$, ${upapp_status}$, ${upapp_updated_at}$,
${synced_at}$
)
ON DUPLICATE KEY UPDATE
api_name = VALUES(api_name),
api_path = VALUES(api_path),
api_method = VALUES(api_method),
api_version = VALUES(api_version),
api_desc = VALUES(api_desc),
uapi_status = VALUES(uapi_status),
auth_required = VALUES(auth_required),
uapi_created_at = VALUES(uapi_created_at),
uapi_updated_at = VALUES(uapi_updated_at),
app_name = VALUES(app_name),
app_code = VALUES(app_code),
app_type = VALUES(app_type),
app_desc = VALUES(app_desc),
app_owner = VALUES(app_owner),
upapp_status = VALUES(upapp_status),
upapp_updated_at = VALUES(upapp_updated_at),
synced_at = VALUES(synced_at)
"""
try:
await sor.execute(insert_sql, rec)
except Exception as e:
logger.warning(
"[%s] persist failed for uapi_id=%s: %s",
self.__class__.__name__, rec.get('uapi_id'), e,
)
raise
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from uapi or upapp records."""
if not records:
return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
return max(timestamps) if timestamps else None
_uapi_sync_instance: UAPISync | None = None
def get_uapi_sync() -> UAPISync:
"""Get or create the UAPISync singleton."""
global _uapi_sync_instance
if _uapi_sync_instance is None:
_uapi_sync_instance = UAPISync()
return _uapi_sync_instance
async def sync_uapi(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a uapi data sync cycle."""
syncer = get_uapi_sync()
if since_timestamp:
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()
latest = None
for r in records:
for key in ('uapi_updated_at', 'upapp_updated_at'):
ts = r.get(key)
if ts and (latest is None or str(ts) > str(latest)):
latest = str(ts)
return latest

View File

@ -1,76 +1,143 @@
"""User data synchronization for SageAPI.
Syncs user data from the upstream Sage system into the local
users_cache table. Uses incremental sync based on updated_at
timestamp to minimize data transfer.
"""
UserSync - Sync users / organi / organization from Sage DB to users_cache.
from __future__ import annotations
Source tables (Sage DB):
- users
- organi
- organization
from typing import Any
Target table (cache DB):
- users_cache
"""
import logging
from typing import Dict, List, Optional
from appPublic.log import debug, info
from .base_sync import BaseSync
logger = logging.getLogger(__name__)
class UserSync(BaseSync):
"""Incremental sync for user data from Sage upstream."""
MODULE_NAME = "users"
SOURCE_DBNAME = "sage"
CACHE_DBNAME = "sageapi"
STATE_KEY = "sync_users"
BATCH_SIZE = 500
def __init__(self, batch_size: int = 500) -> None:
super().__init__(sync_name='users', batch_size=batch_size)
async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]:
"""Fetch users updated since the last sync checkpoint.
TODO: Implement upstream API call to Sage /api/users endpoint.
async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]:
"""
debug(f'UserSync: fetching incremental data since {since_timestamp}')
# Placeholder: call upstream Sage API
# GET /api/users?updated_at_gt={since_timestamp}&limit={batch_size}
return []
async def persist(self, records: list[dict[str, Any]]) -> int:
"""Upsert user records into users_cache table.
TODO: Implement database upsert logic.
Fetch incremental data from users, organi, organization tables.
Uses LEFT JOIN to combine user data with organization info.
"""
if not records:
return 0
info(f'UserSync: persisting {len(records)} user records')
# Placeholder: upsert into users_cache
return len(records)
if since_timestamp:
where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR o.updated_at > '{since_timestamp}'"
else:
where_clause = ""
def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None:
"""Extract the maximum updated_at from the record batch."""
sql = f"""
SELECT
u.id AS user_id,
u.username,
u.email,
u.status AS user_status,
u.created_at AS user_created_at,
u.updated_at AS user_updated_at,
oi.organi_id AS organi_id,
oi.organi_name AS organi_name,
oi.parent_id AS organi_parent_id,
org.id AS org_id,
org.org_name AS org_name,
org.org_type AS org_type,
org.status AS org_status,
org.updated_at AS org_updated_at
FROM {sor.dbname}.users u
LEFT JOIN {sor.dbname}.organi oi ON u.organi_id = oi.id
LEFT JOIN {sor.dbname}.organization org ON oi.organization_id = org.id
{where_clause}
ORDER BY COALESCE(u.updated_at, u.created_at) ASC
"""
records = await sor.sqlExe(sql, {})
return [dict(r) for r in records] if records else []
async def persist(self, sor, records: List[Dict]) -> None:
"""
UPSERT into users_cache.
For each record: try UPDATE first; if no rows affected, INSERT.
"""
for rec in records:
# Try update first
update_sql = """
UPDATE users_cache SET
username = ${username}$,
email = ${email}$,
user_status = ${user_status}$,
user_created_at = ${user_created_at}$,
user_updated_at = ${user_updated_at}$,
organi_id = ${organi_id}$,
organi_name = ${organi_name}$,
organi_parent_id= ${organi_parent_id}$,
org_id = ${org_id}$,
org_name = ${org_name}$,
org_type = ${org_type}$,
org_status = ${org_status}$,
org_updated_at = ${org_updated_at}$,
synced_at = ${synced_at}$
WHERE user_id = ${user_id}$
"""
rec['synced_at'] = str(int(__import__('time').time()))
await sor.execute(update_sql, rec)
# If no rows updated (sqlor returns affected count), insert
# sqlor's execute returns the cursor; we check rowcount via sqlExe
# A simpler approach: use INSERT ... ON DUPLICATE KEY UPDATE
insert_sql = """
INSERT INTO users_cache (
user_id, username, email, user_status,
user_created_at, user_updated_at,
organi_id, organi_name, organi_parent_id,
org_id, org_name, org_type, org_status,
org_updated_at, synced_at
) VALUES (
${user_id}$, ${username}$, ${email}$, ${user_status}$,
${user_created_at}$, ${user_updated_at}$,
${organi_id}$, ${organi_name}$, ${organi_parent_id}$,
${org_id}$, ${org_name}$, ${org_type}$, ${org_status}$,
${org_updated_at}$, ${synced_at}$
)
ON DUPLICATE KEY UPDATE
username = VALUES(username),
email = VALUES(email),
user_status = VALUES(user_status),
user_created_at = VALUES(user_created_at),
user_updated_at = VALUES(user_updated_at),
organi_id = VALUES(organi_id),
organi_name = VALUES(organi_name),
organi_parent_id= VALUES(organi_parent_id),
org_id = VALUES(org_id),
org_name = VALUES(org_name),
org_type = VALUES(org_type),
org_status = VALUES(org_status),
org_updated_at = VALUES(org_updated_at),
synced_at = VALUES(synced_at)
"""
# Execute INSERT with ON DUPLICATE KEY UPDATE as safety net
# The UPDATE above handles most cases; this catches any race conditions
# Skip if the user_id is None
if rec.get('user_id') is not None:
try:
await sor.execute(insert_sql, rec)
except Exception:
# Duplicate key is expected if UPDATE already succeeded
pass
def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]:
"""Extract the maximum updated_at from all records."""
if not records:
return None
timestamps = [r.get('updated_at') for r in records if r.get('updated_at')]
return max(timestamps) if timestamps else None
# Module-level singleton instance
_user_sync_instance: UserSync | None = None
def get_user_sync() -> UserSync:
"""Get or create the UserSync singleton."""
global _user_sync_instance
if _user_sync_instance is None:
_user_sync_instance = UserSync()
return _user_sync_instance
async def sync_users(since_timestamp: str | None = None) -> dict[str, Any]:
"""Run a user data sync cycle.
Args:
since_timestamp: Optional override for the checkpoint timestamp.
Returns:
Sync result dict from BaseSync.run().
"""
syncer = get_user_sync()
if since_timestamp:
# Override checkpoint for this run
await syncer._save_checkpoint(since_timestamp)
return await syncer.run()
latest = None
for r in records:
ts = r.get('user_updated_at') or r.get('org_updated_at')
if ts and (latest is None or str(ts) > str(latest)):
latest = str(ts)
return latest

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,115 +1,460 @@
"""HTTP client for upstream Sage API calls.
"""
SageHttpClient - Async HTTP client using aiohttp with DAPI signature support.
Provides a reusable async HTTP client with connection pooling,
retry logic, and automatic DAPI/UAPI authentication headers.
Used to call external APIs (e.g., LLM provider APIs). Not for calling Sage local interfaces.
Features:
- Automatic DAPI signature header injection
- Connection pooling
- Retry with exponential backoff
- Configurable timeouts
- Support for GET, POST, PUT, DELETE, PATCH
Usage:
from sageapi.utils.http_client import SageHttpClient
client = SageHttpClient(
api_key="your-api-key",
api_secret="your-api-secret",
base_url="https://api.example.com",
max_retries=3,
timeout=30.0,
)
async with client:
resp = await client.post("/v1/chat", json={"message": "hello"})
data = await resp.json()
"""
from __future__ import annotations
import hashlib
import hmac
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Optional
import asyncio
from typing import Any
import aiohttp
from aiohttp import ClientTimeout
from appPublic.log import debug, error
logger = logging.getLogger(__name__)
# Default configuration
DEFAULT_TIMEOUT = 30.0
DEFAULT_MAX_RETRIES = 3
DEFAULT_BACKOFF_FACTOR = 0.5
DEFAULT_MAX_CONNECTIONS = 100
DEFAULT_CONNECTION_POOL_LIMIT = 100
@dataclass
class RetryConfig:
"""Configuration for request retries."""
max_retries: int = DEFAULT_MAX_RETRIES
backoff_factor: float = DEFAULT_BACKOFF_FACTOR
retry_on_status: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 504})
retry_on_timeout: bool = True
retry_on_connection_error: bool = True
def compute_dapi_signature(
method: str,
path: str,
timestamp: str,
secret: str,
body: Optional[bytes] = None,
) -> str:
"""
Compute HMAC-SHA256 signature for DAPI authentication.
Signs: "{method}\\n{path}\\n{timestamp}\\n{body_hash}"
"""
if body:
body_hash = hashlib.sha256(body).hexdigest()
else:
body_hash = ""
string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}"
return hmac.new(
secret.encode("utf-8"),
string_to_sign.encode("utf-8"),
hashlib.sha256,
).hexdigest()
class DAPISigner:
"""Handles DAPI signature generation for outgoing requests."""
def __init__(self, api_key: str, api_secret: str):
self.api_key = api_key
self.api_secret = api_secret
def sign_request(
self,
method: str,
path: str,
body: Optional[bytes] = None,
) -> dict[str, str]:
"""
Generate DAPI authentication headers.
Returns dict with X-DAPI-Key, X-DAPI-Timestamp, X-DAPI-Signature.
"""
timestamp = str(time.time())
signature = compute_dapi_signature(
method=method.upper(),
path=path,
timestamp=timestamp,
secret=self.api_secret,
body=body,
)
return {
"X-DAPI-Key": self.api_key,
"X-DAPI-Timestamp": timestamp,
"X-DAPI-Signature": signature,
}
class SageHttpClient:
"""Async HTTP client for calling upstream Sage APIs.
"""
Async HTTP client with DAPI signature support.
Handles connection pooling, authentication headers, and
retry logic for transient failures.
Uses aiohttp under the hood with connection pooling, retry logic,
and automatic DAPI header injection.
Args:
base_url: Base URL for all requests (e.g., "https://api.example.com").
api_key: DAPI API key for request signing.
api_secret: DAPI API secret for request signing.
timeout: Request timeout in seconds.
max_retries: Maximum number of retries on transient failures.
backoff_factor: Exponential backoff multiplier.
max_connections: Maximum number of connections in the pool.
headers: Additional default headers to include in every request.
signer: Optional custom DAPISigner instance. If None, one is created from api_key/api_secret.
Usage:
async with SageHttpClient(base_url="...", api_key="...", api_secret="...") as client:
resp = await client.get("/health")
resp = await client.post("/v1/chat", json={"msg": "hi"})
"""
def __init__(
self,
base_url: str = 'http://127.0.0.1:9180',
dapi_key: str = '',
dapi_secret: str = '',
timeout: float = 30.0,
max_retries: int = 3,
) -> None:
self.base_url = base_url.rstrip('/')
self.dapi_key = dapi_key
self.dapi_secret = dapi_secret
base_url: str = "",
api_key: str = "",
api_secret: str = "",
timeout: float = DEFAULT_TIMEOUT,
max_retries: int = DEFAULT_MAX_RETRIES,
backoff_factor: float = DEFAULT_BACKOFF_FACTOR,
max_connections: int = DEFAULT_MAX_CONNECTIONS,
headers: Optional[dict[str, str]] = None,
signer: Optional[DAPISigner] = None,
auto_sign: bool = True,
):
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.max_retries = max_retries
self.backoff_factor = backoff_factor
self.default_headers = headers or {}
self.auto_sign = auto_sign
def _build_headers(self) -> dict[str, str]:
"""Build request headers with DAPI authentication.
if signer is not None:
self.signer = signer
else:
self.signer = DAPISigner(api_key=api_key, api_secret=api_secret)
TODO: Implement actual DAPI signature generation.
# Internal state
self._session: Optional[aiohttp.ClientSession] = None
self._connector: Optional[aiohttp.TCPConnector] = None
self._connector_limit = max_connections
self._closed = False
def _build_url(self, path: str) -> str:
"""Build full URL from base_url and path."""
if path.startswith("http://") or path.startswith("https://"):
return path
return f"{self.base_url}/{path.lstrip('/')}"
def _get_relative_path(self, path: str) -> str:
"""Extract the relative path for signing purposes."""
if path.startswith("http://") or path.startswith("https://"):
from urllib.parse import urlparse
parsed = urlparse(path)
return parsed.path
return path
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create the aiohttp session with connection pooling."""
if self._session is None or self._session.closed:
self._connector = aiohttp.TCPConnector(
limit=self._connector_limit,
limit_per_host=self._connector_limit,
ttl_dns_cache=300,
keepalive_timeout=30,
)
self._session = aiohttp.ClientSession(
connector=self._connector,
timeout=ClientTimeout(total=self.timeout),
)
return self._session
async def _sign_and_prepare(
self,
method: str,
path: str,
headers: Optional[dict[str, str]] = None,
data: Any = None,
json_body: Any = None,
) -> tuple[str, dict[str, str], Optional[bytes]]:
"""
import time
import hashlib
import hmac
Prepare request with DAPI signature headers.
timestamp = str(int(time.time()))
string_to_sign = f'{self.dapi_key}:{timestamp}'
signature = hmac.new(
self.dapi_secret.encode('utf-8') if self.dapi_secret else b'',
string_to_sign.encode('utf-8'),
hashlib.sha256,
).hexdigest()
Returns (url, merged_headers, raw_body_bytes).
"""
url = self._build_url(path)
relative_path = self._get_relative_path(path)
return {
'Content-Type': 'application/json',
'X-DAPI-Key': self.dapi_key,
'X-DAPI-Timestamp': timestamp,
'X-DAPI-Signature': signature,
merged_headers = {**self.default_headers}
if headers:
merged_headers.update(headers)
# Determine body for signing
raw_body: Optional[bytes] = None
if json_body is not None:
import json
raw_body = json.dumps(json_body).encode("utf-8")
merged_headers.setdefault("Content-Type", "application/json")
elif data is not None:
if isinstance(data, bytes):
raw_body = data
elif isinstance(data, str):
raw_body = data.encode("utf-8")
else:
# Form data - sign the encoded form
from urllib.parse import urlencode
raw_body = urlencode(data).encode("utf-8")
# Add DAPI signature headers if auto-sign is enabled and signer has credentials
if self.auto_sign and self.signer.api_key and self.signer.api_secret:
dapi_headers = self.signer.sign_request(
method=method,
path=relative_path,
body=raw_body,
)
merged_headers.update(dapi_headers)
return url, merged_headers, raw_body
async def _execute_with_retry(
self,
method: str,
url: str,
headers: dict[str, str],
**kwargs: Any,
) -> aiohttp.ClientResponse:
"""
Execute request with retry and exponential backoff.
"""
session = await self._get_session()
last_response: Optional[aiohttp.ClientResponse] = None
last_exception: Optional[Exception] = None
for attempt in range(self.max_retries + 1):
try:
response = await session.request(
method=method,
url=url,
headers=headers,
**kwargs,
)
# Check if we should retry based on status
if response.status in {429, 500, 502, 503, 504}:
last_response = response
if attempt < self.max_retries:
wait_time = self.backoff_factor * (2 ** attempt)
logger.warning(
"HTTP %s on %s %s, retrying in %.1fs (attempt %d/%d)",
response.status,
method,
url,
wait_time,
attempt + 1,
self.max_retries,
)
try:
response.release()
except Exception:
pass
await self._async_sleep(wait_time)
continue
else:
# Retries exhausted - raise an error
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"HTTP {response.status} after {self.max_retries + 1} attempts",
)
return response
except aiohttp.ServerTimeoutError as e:
last_exception = e
if attempt < self.max_retries:
wait_time = self.backoff_factor * (2 ** attempt)
logger.warning(
"Timeout on %s %s, retrying in %.1fs (attempt %d/%d)",
method,
url,
wait_time,
attempt + 1,
self.max_retries,
)
await self._async_sleep(wait_time)
continue
except (aiohttp.ClientConnectionError, aiohttp.ClientOSError) as e:
last_exception = e
if attempt < self.max_retries:
wait_time = self.backoff_factor * (2 ** attempt)
logger.warning(
"Connection error on %s %s, retrying in %.1fs (attempt %d/%d)",
method,
url,
wait_time,
attempt + 1,
self.max_retries,
)
await self._async_sleep(wait_time)
continue
except Exception:
raise
# All retries exhausted
if last_response:
return last_response
if last_exception:
raise last_exception
raise RuntimeError(f"Request failed after {self.max_retries + 1} attempts")
@staticmethod
async def _async_sleep(seconds: float) -> None:
"""Non-blocking sleep."""
import asyncio
await asyncio.sleep(seconds)
async def request(
self,
method: str,
path: str,
*,
headers: Optional[dict[str, str]] = None,
data: Any = None,
json: Any = None,
params: Optional[dict[str, Any]] = None,
timeout: Optional[float] = None,
allow_redirects: bool = True,
) -> aiohttp.ClientResponse:
"""
Send an HTTP request with DAPI signing and retry.
Args:
method: HTTP method (GET, POST, PUT, DELETE, PATCH).
path: URL path or full URL.
headers: Additional request headers.
data: Form data or raw bytes.
json: JSON-serializable body (automatically encoded).
params: Query string parameters.
timeout: Override timeout for this request.
allow_redirects: Whether to follow redirects.
Returns:
aiohttp.ClientResponse
"""
url, merged_headers, raw_body = await self._sign_and_prepare(
method=method,
path=path,
headers=headers,
data=data,
json_body=json,
)
request_kwargs: dict[str, Any] = {
"allow_redirects": allow_redirects,
}
async def get(
self,
path: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
"""Send a GET request to the upstream Sage API.
if raw_body is not None:
if json is not None:
request_kwargs["data"] = raw_body
else:
request_kwargs["data"] = raw_body
elif json is not None:
import json as _json
Args:
path: API path (relative to base_url).
params: Optional query parameters.
headers: Optional additional headers.
request_kwargs["data"] = _json.dumps(json).encode("utf-8")
Returns:
Parsed JSON response.
"""
url = f'{self.base_url}{path}'
request_headers = {**self._build_headers(), **(headers or {})}
if data is not None and json is None:
request_kwargs["data"] = data
debug(f'HTTP GET {url} params={params}')
if params:
request_kwargs["params"] = params
# TODO: Replace with actual async HTTP implementation
# using aiohttp or the framework's built-in HTTP client.
# This is a placeholder that will be filled in once the
# specific HTTP library choice is confirmed.
raise NotImplementedError(
'SageHttpClient.get: HTTP library not yet integrated. '
'Implement with aiohttp or framework HTTP client.'
# Per-request timeout override
if timeout is not None:
request_kwargs["timeout"] = ClientTimeout(total=timeout)
return await self._execute_with_retry(
method=method.upper(),
url=url,
headers=merged_headers,
**request_kwargs,
)
async def post(
self,
path: str,
data: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
"""Send a POST request to the upstream Sage API.
async def get(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""Send a GET request."""
return await self.request("GET", path, **kwargs)
Args:
path: API path (relative to base_url).
data: JSON body data.
headers: Optional additional headers.
async def post(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""Send a POST request."""
return await self.request("POST", path, **kwargs)
Returns:
Parsed JSON response.
"""
url = f'{self.base_url}{path}'
request_headers = {**self._build_headers(), **(headers or {})}
async def put(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""Send a PUT request."""
return await self.request("PUT", path, **kwargs)
debug(f'HTTP POST {url}')
async def delete(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""Send a DELETE request."""
return await self.request("DELETE", path, **kwargs)
# TODO: Replace with actual async HTTP implementation.
raise NotImplementedError(
'SageHttpClient.post: HTTP library not yet integrated. '
'Implement with aiohttp or framework HTTP client.'
async def patch(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""Send a PATCH request."""
return await self.request("PATCH", path, **kwargs)
async def close(self) -> None:
"""Close the HTTP client and release resources."""
if self._session and not self._session.closed:
await self._session.close()
self._closed = True
async def __aenter__(self) -> "SageHttpClient":
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
def __del__(self) -> None:
# Best-effort cleanup; prefer using async context manager
if not self._closed and self._session and not self._session.closed:
logger.warning(
"SageHttpClient was not properly closed. "
"Use 'async with SageHttpClient(...)' for proper cleanup."
)

122
sync/cache_tables.sql Normal file
View File

@ -0,0 +1,122 @@
-- =============================================================================
-- SageAPI Cache Tables DDL
-- These tables store synchronized data from the Sage database.
-- Run against the sageapi database.
-- =============================================================================
-- Checkpoint table: tracks the last sync timestamp for each module
CREATE TABLE IF NOT EXISTS sync_state (
state_key VARCHAR(64) NOT NULL PRIMARY KEY,
last_sync_ts VARCHAR(64) DEFAULT NULL,
created_at VARCHAR(32) NOT NULL,
updated_at VARCHAR(32) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- users_cache: synced from Sage users / organi / organization tables
-- =============================================================================
CREATE TABLE IF NOT EXISTS users_cache (
user_id BIGINT NOT NULL PRIMARY KEY,
username VARCHAR(128) DEFAULT NULL,
email VARCHAR(256) DEFAULT NULL,
user_status INT DEFAULT 0,
user_created_at VARCHAR(32) DEFAULT NULL,
user_updated_at VARCHAR(32) DEFAULT NULL,
organi_id BIGINT DEFAULT NULL,
organi_name VARCHAR(256) DEFAULT NULL,
organi_parent_id BIGINT DEFAULT NULL,
org_id BIGINT DEFAULT NULL,
org_name VARCHAR(256) DEFAULT NULL,
org_type VARCHAR(64) DEFAULT NULL,
org_status INT DEFAULT 0,
org_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
UNIQUE KEY uk_organi (organi_id),
KEY idx_updated (user_updated_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- pricing_cache: synced from Sage pricing_program / pricing_program_timing
-- =============================================================================
CREATE TABLE IF NOT EXISTS pricing_cache (
program_id BIGINT NOT NULL,
program_name VARCHAR(256) DEFAULT NULL,
program_code VARCHAR(128) DEFAULT NULL,
program_type VARCHAR(64) DEFAULT NULL,
program_status INT DEFAULT 0,
description TEXT DEFAULT NULL,
program_created_at VARCHAR(32) DEFAULT NULL,
program_updated_at VARCHAR(32) DEFAULT NULL,
timing_id BIGINT DEFAULT NULL,
start_time VARCHAR(32) DEFAULT NULL,
end_time VARCHAR(32) DEFAULT NULL,
duration INT DEFAULT NULL,
repeat_rule VARCHAR(256) DEFAULT NULL,
timezone VARCHAR(64) DEFAULT NULL,
timing_status INT DEFAULT 0,
timing_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
PRIMARY KEY (program_id, timing_id),
KEY idx_program_updated (program_updated_at)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- uapi_cache: synced from Sage uapi / upapp tables
-- =============================================================================
CREATE TABLE IF NOT EXISTS uapi_cache (
uapi_id BIGINT NOT NULL,
api_name VARCHAR(256) DEFAULT NULL,
api_path VARCHAR(512) DEFAULT NULL,
api_method VARCHAR(16) DEFAULT 'GET',
api_version VARCHAR(32) DEFAULT NULL,
api_desc TEXT DEFAULT NULL,
uapi_status INT DEFAULT 0,
auth_required TINYINT DEFAULT 0,
uapi_created_at VARCHAR(32) DEFAULT NULL,
uapi_updated_at VARCHAR(32) DEFAULT NULL,
upapp_id BIGINT DEFAULT NULL,
app_name VARCHAR(256) DEFAULT NULL,
app_code VARCHAR(128) DEFAULT NULL,
app_type VARCHAR(64) DEFAULT NULL,
app_desc TEXT DEFAULT NULL,
app_owner VARCHAR(128) DEFAULT NULL,
upapp_status INT DEFAULT 0,
upapp_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
PRIMARY KEY (uapi_id, upapp_id),
KEY idx_uapi_updated (uapi_updated_at),
KEY idx_upapp_id (upapp_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- =============================================================================
-- llmage_cache: synced from Sage llm / llmcatelog / llm_api_map tables
-- =============================================================================
CREATE TABLE IF NOT EXISTS llmage_cache (
llm_id BIGINT NOT NULL,
model_name VARCHAR(256) DEFAULT NULL,
model_version VARCHAR(64) DEFAULT NULL,
provider VARCHAR(128) DEFAULT NULL,
model_type VARCHAR(64) DEFAULT NULL,
llm_status INT DEFAULT 0,
llm_description TEXT DEFAULT NULL,
llm_created_at VARCHAR(32) DEFAULT NULL,
llm_updated_at VARCHAR(32) DEFAULT NULL,
catelog_id BIGINT DEFAULT NULL,
catelog_name VARCHAR(256) DEFAULT NULL,
catelog_code VARCHAR(128) DEFAULT NULL,
catelog_sort INT DEFAULT 0,
catelog_status INT DEFAULT 0,
catelog_updated_at VARCHAR(32) DEFAULT NULL,
api_map_id BIGINT DEFAULT NULL,
api_name VARCHAR(256) DEFAULT NULL,
api_endpoint VARCHAR(512) DEFAULT NULL,
api_version VARCHAR(32) DEFAULT NULL,
auth_type VARCHAR(32) DEFAULT NULL,
rate_limit INT DEFAULT NULL,
api_map_status INT DEFAULT 0,
api_map_updated_at VARCHAR(32) DEFAULT NULL,
synced_at VARCHAR(32) NOT NULL,
PRIMARY KEY (llm_id, catelog_id, api_map_id),
KEY idx_llm_updated (llm_updated_at),
KEY idx_catelog_id (catelog_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# Tests package

Binary file not shown.

458
tests/test_dapi_auth.py Normal file
View File

@ -0,0 +1,458 @@
"""Tests for sageapi.middleware.dapi_auth"""
import hashlib
import hmac
import time
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
from sageapi.middleware.dapi_auth import (
DEFAULT_TIMESTAMP_TOLERANCE_SEC,
DapiAuthError,
DapiAuthMiddleware,
DapiKeyRecord,
authenticate_request,
compute_dapi_signature,
lookup_api_key,
verify_timestamp,
)
def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str:
"""Helper to create a valid DAPI signature."""
return compute_dapi_signature(method, path, timestamp, secret, body if body else None)
# ---------------------------------------------------------------------------
# compute_dapi_signature tests
# ---------------------------------------------------------------------------
class TestComputeDapiSignature:
def test_basic_signature(self):
sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret")
assert isinstance(sig, str)
assert len(sig) == 64 # SHA-256 hex length
def test_signature_with_body(self):
body = b'{"key": "value"}'
sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", body)
assert isinstance(sig, str)
assert len(sig) == 64
def test_signature_deterministic(self):
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret")
assert sig1 == sig2
def test_different_body_different_signature(self):
sig1 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body1")
sig2 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body2")
assert sig1 != sig2
def test_empty_body_same_as_no_body(self):
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret", None)
assert sig1 == sig2
def test_manual_verification(self):
"""Verify the signature matches the expected HMAC-SHA256 output."""
method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret"
body = b'{"message":"hello"}'
body_hash = hashlib.sha256(body).hexdigest()
string_to_sign = f"{method}\n{path}\n{ts}\n{body_hash}"
expected = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha256).hexdigest()
actual = compute_dapi_signature(method, path, ts, secret, body)
assert actual == expected
# ---------------------------------------------------------------------------
# verify_timestamp tests
# ---------------------------------------------------------------------------
class TestVerifyTimestamp:
def test_valid_timestamp(self):
ts = str(time.time())
assert verify_timestamp(ts) is True
def test_expired_timestamp(self):
ts = str(time.time() - 600) # 10 minutes ago
assert verify_timestamp(ts) is False
def test_future_timestamp(self):
ts = str(time.time() + 600) # 10 minutes in future
assert verify_timestamp(ts) is False
def test_invalid_timestamp(self):
assert verify_timestamp("not-a-number") is False
assert verify_timestamp("") is False
def test_custom_tolerance(self):
ts = str(time.time() - 10)
assert verify_timestamp(ts, tolerance_sec=5) is False
assert verify_timestamp(ts, tolerance_sec=15) is True
# ---------------------------------------------------------------------------
# DapiKeyRecord tests
# ---------------------------------------------------------------------------
class TestDapiKeyRecord:
def test_active_key(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None)
assert rec.is_active is True
def test_inactive_key(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="inactive", expire_date=None)
assert rec.is_active is False
def test_not_expired_when_no_expiry(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None)
assert rec.is_expired is False
def test_expired_with_past_date(self):
past = datetime(2020, 1, 1, tzinfo=timezone.utc)
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=past)
assert rec.is_expired is True
def test_not_expired_with_future_date(self):
future = datetime(2099, 1, 1, tzinfo=timezone.utc)
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=future)
assert rec.is_expired is False
def test_expired_with_iso_string_past(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2020-01-01T00:00:00Z")
assert rec.is_expired is True
def test_not_expired_with_iso_string_future(self):
rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2099-01-01T00:00:00Z")
assert rec.is_expired is False
# ---------------------------------------------------------------------------
# lookup_api_key tests
# ---------------------------------------------------------------------------
class TestLookupApiKey:
@pytest.fixture(autouse=True)
def mock_sage_modules(self):
"""Create mock ahserver and sqlor modules for testing."""
import sys
import types
# Create fake modules
ahserver = types.ModuleType("ahserver")
ahserver_serverenv = types.ModuleType("ahserver.serverenv")
ahserver_serverenv.ServerEnv = MagicMock
ahserver.serverenv = ahserver_serverenv
sqlor = types.ModuleType("sqlor")
sqlor_dbpools = types.ModuleType("sqlor.dbpools")
mock_ctx = AsyncMock()
mock_sor = AsyncMock()
mock_sor.R = AsyncMock(return_value=[])
mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
sqlor.dbpools = sqlor_dbpools
sys.modules["ahserver"] = ahserver
sys.modules["ahserver.serverenv"] = ahserver_serverenv
sys.modules["sqlor"] = sqlor
sys.modules["sqlor.dbpools"] = sqlor_dbpools
yield mock_ctx, mock_sor
# Cleanup
for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]:
if mod in sys.modules:
del sys.modules[mod]
@pytest.mark.asyncio
async def test_lookup_found(self, mock_sage_modules):
mock_ctx, mock_sor = mock_sage_modules
mock_rec = {
"id": 1,
"apikey": "test-key",
"secret": "test-secret",
"status": "active",
"expire_date": "2099-01-01T00:00:00Z",
"description": "Test key",
}
mock_sor.R.return_value = [mock_rec]
result = await lookup_api_key("test-key")
assert result is not None
assert result.apikey == "test-key"
assert result.secret == "test-secret"
assert result.is_active is True
@pytest.mark.asyncio
async def test_lookup_not_found(self, mock_sage_modules):
mock_ctx, mock_sor = mock_sage_modules
mock_sor.R.return_value = []
result = await lookup_api_key("nonexistent")
assert result is None
# ---------------------------------------------------------------------------
# authenticate_request tests
# ---------------------------------------------------------------------------
class TestAuthenticateRequest:
def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request:
scope = {
"type": "http",
"method": method,
"path": path,
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
"query_string": b"",
}
async def receive():
return {"type": "http.request", "body": body, "more_body": False}
return Request(scope, receive)
@pytest.mark.asyncio
async def test_missing_api_key_header(self):
req = self._make_request({
"X-DAPI-Timestamp": str(time.time()),
"X-DAPI-Signature": "abc",
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "X-DAPI-Key" in exc.value.detail
@pytest.mark.asyncio
async def test_missing_timestamp_header(self):
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Signature": "abc",
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "X-DAPI-Timestamp" in exc.value.detail
@pytest.mark.asyncio
async def test_missing_signature_header(self):
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": str(time.time()),
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "X-DAPI-Signature" in exc.value.detail
@pytest.mark.asyncio
async def test_expired_timestamp(self):
ts = str(time.time() - 600)
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "timestamp" in exc.value.detail.lower()
@pytest.mark.asyncio
async def test_invalid_api_key(self):
ts = str(time.time())
req = self._make_request({
"X-DAPI-Key": "bad-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=None):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "Invalid API key" in exc.value.detail
@pytest.mark.asyncio
async def test_inactive_key(self):
ts = str(time.time())
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 403
assert "inactive" in exc.value.detail.lower()
@pytest.mark.asyncio
async def test_expired_key(self):
ts = str(time.time())
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "abc",
})
key_rec = DapiKeyRecord(
id=1, apikey="test-key", secret="secret", status="active",
expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc),
)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 403
assert "expired" in exc.value.detail.lower()
@pytest.mark.asyncio
async def test_invalid_signature(self):
ts = str(time.time())
body = b'{"test": "data"}'
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": "invalid-signature",
}, body=body, method="POST")
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
with pytest.raises(DapiAuthError) as exc:
await authenticate_request(req)
assert exc.value.status_code == 401
assert "Invalid signature" in exc.value.detail
@pytest.mark.asyncio
async def test_valid_authentication(self):
ts = str(time.time())
secret = "my-secret"
body = b'{"test": "data"}'
sig = compute_dapi_signature("POST", "/test", ts, secret, body)
req = self._make_request({
"X-DAPI-Key": "test-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": sig,
}, body=body, method="POST")
key_rec = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
result = await authenticate_request(req)
assert result.apikey == "test-key"
assert result.secret == secret
# ---------------------------------------------------------------------------
# DapiAuthMiddleware tests
# ---------------------------------------------------------------------------
class TestDapiAuthMiddleware:
@pytest.fixture(autouse=True)
def mock_sage_modules(self):
"""Create mock ahserver and sqlor modules."""
import sys
import types
ahserver = types.ModuleType("ahserver")
ahserver_serverenv = types.ModuleType("ahserver.serverenv")
ahserver_serverenv.ServerEnv = MagicMock
ahserver.serverenv = ahserver_serverenv
sqlor = types.ModuleType("sqlor")
sqlor_dbpools = types.ModuleType("sqlor.dbpools")
mock_sor = AsyncMock()
mock_sor.R = AsyncMock(return_value=[])
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx)
sqlor.dbpools = sqlor_dbpools
sys.modules["ahserver"] = ahserver
sys.modules["ahserver.serverenv"] = ahserver_serverenv
sys.modules["sqlor"] = sqlor
sys.modules["sqlor.dbpools"] = sqlor_dbpools
yield
for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]:
if mod in sys.modules:
del sys.modules[mod]
def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC):
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
async def protected(request: Request):
return PlainTextResponse("OK - authenticated")
async def health(request: Request):
return PlainTextResponse("healthy")
app = Starlette(
routes=[
Route("/protected", endpoint=protected, methods=["GET"]),
Route("/health", endpoint=health, methods=["GET"]),
]
)
app.add_middleware(
DapiAuthMiddleware,
exclude_paths=exclude_paths or ["/health"],
tolerance_sec=tolerance_sec,
)
return app
def test_unauthenticated_request_rejected(self):
app = self._build_test_app()
client = TestClient(app)
resp = client.get("/protected")
assert resp.status_code == 401
assert "error" in resp.json()
def test_excluded_path_bypasses_auth(self):
app = self._build_test_app()
client = TestClient(app)
resp = client.get("/health")
assert resp.status_code == 200
assert resp.text == "healthy"
def test_valid_request_accepted(self):
ts = str(time.time())
secret = "test-secret"
sig = compute_dapi_signature("GET", "/protected", ts, secret)
key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None)
app = self._build_test_app()
client = TestClient(app)
with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec):
resp = client.get(
"/protected",
headers={
"X-DAPI-Key": "valid-key",
"X-DAPI-Timestamp": ts,
"X-DAPI-Signature": sig,
},
)
assert resp.status_code == 200
assert resp.text == "OK - authenticated"

288
tests/test_http_client.py Normal file
View File

@ -0,0 +1,288 @@
"""Tests for sageapi.utils.http_client"""
import hashlib
import hmac
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from sageapi.utils.http_client import (
DAPISigner,
SageHttpClient,
RetryConfig,
compute_dapi_signature,
)
# ---------------------------------------------------------------------------
# compute_dapi_signature tests (also imported from middleware)
# ---------------------------------------------------------------------------
class TestComputeDapiSignature:
def test_basic(self):
sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "secret")
assert len(sig) == 64
def test_with_body(self):
body = b'{"msg":"hi"}'
sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "secret", body)
assert len(sig) == 64
def test_deterministic(self):
sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret")
sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret")
assert sig1 == sig2
# ---------------------------------------------------------------------------
# DAPISigner tests
# ---------------------------------------------------------------------------
class TestDAPISigner:
def test_sign_request(self):
signer = DAPISigner(api_key="my-key", api_secret="my-secret")
headers = signer.sign_request("GET", "/test")
assert "X-DAPI-Key" in headers
assert "X-DAPI-Timestamp" in headers
assert "X-DAPI-Signature" in headers
assert headers["X-DAPI-Key"] == "my-key"
assert headers["X-DAPI-Timestamp"] # should be a valid timestamp string
def test_sign_request_with_body(self):
signer = DAPISigner(api_key="k", api_secret="s")
body = b'{"test": true}'
headers = signer.sign_request("POST", "/v1/chat", body)
assert headers["X-DAPI-Key"] == "k"
# Verify timestamp is recent
ts = float(headers["X-DAPI-Timestamp"])
assert abs(time.time() - ts) < 2
def test_sign_request_produces_valid_signature(self):
signer = DAPISigner(api_key="k", api_secret="secret")
body = b'hello'
headers = signer.sign_request("POST", "/path", body)
expected = compute_dapi_signature("POST", "/path", headers["X-DAPI-Timestamp"], "secret", body)
assert headers["X-DAPI-Signature"] == expected
def test_sign_request_uppercases_method(self):
signer = DAPISigner(api_key="k", api_secret="s")
headers = signer.sign_request("get", "/path")
expected = compute_dapi_signature("GET", "/path", headers["X-DAPI-Timestamp"], "s")
assert headers["X-DAPI-Signature"] == expected
# ---------------------------------------------------------------------------
# RetryConfig tests
# ---------------------------------------------------------------------------
class TestRetryConfig:
def test_defaults(self):
config = RetryConfig()
assert config.max_retries == 3
assert config.backoff_factor == 0.5
assert 429 in config.retry_on_status
assert 500 in config.retry_on_status
def test_custom(self):
config = RetryConfig(max_retries=5, backoff_factor=1.0)
assert config.max_retries == 5
assert config.backoff_factor == 1.0
# ---------------------------------------------------------------------------
# SageHttpClient tests
# ---------------------------------------------------------------------------
class TestSageHttpClient:
def test_init_defaults(self):
client = SageHttpClient(
base_url="https://api.example.com",
api_key="key",
api_secret="secret",
)
assert client.base_url == "https://api.example.com"
assert client.signer.api_key == "key"
assert client.signer.api_secret == "secret"
assert client.auto_sign is True
def test_init_with_custom_signer(self):
signer = DAPISigner(api_key="k", api_secret="s")
client = SageHttpClient(base_url="https://api.example.com", signer=signer)
assert client.signer is signer
def test_init_without_auto_sign(self):
client = SageHttpClient(
base_url="https://api.example.com",
auto_sign=False,
)
assert client.auto_sign is False
def test_build_url(self):
client = SageHttpClient(base_url="https://api.example.com")
assert client._build_url("/v1/test") == "https://api.example.com/v1/test"
assert client._build_url("v1/test") == "https://api.example.com/v1/test"
# Full URL should pass through
assert client._build_url("https://other.com/path") == "https://other.com/path"
def test_get_relative_path(self):
client = SageHttpClient(base_url="https://api.example.com")
assert client._get_relative_path("/v1/chat") == "/v1/chat"
assert client._get_relative_path("https://api.example.com/v1/chat?page=1") == "/v1/chat"
def test_sign_and_prepare_adds_dapi_headers(self):
client = SageHttpClient(
base_url="https://api.example.com",
api_key="k",
api_secret="s",
)
import asyncio
url, headers, body = asyncio.get_event_loop().run_until_complete(
client._sign_and_prepare("GET", "/test")
)
assert "X-DAPI-Key" in headers
assert "X-DAPI-Timestamp" in headers
assert "X-DAPI-Signature" in headers
assert headers["X-DAPI-Key"] == "k"
def test_sign_and_prepare_with_json_body(self):
client = SageHttpClient(
base_url="https://api.example.com",
api_key="k",
api_secret="s",
)
import asyncio
url, headers, body = asyncio.get_event_loop().run_until_complete(
client._sign_and_prepare("POST", "/test", json_body={"msg": "hi"})
)
assert body is not None
assert json.loads(body) == {"msg": "hi"}
assert headers.get("Content-Type") == "application/json"
@pytest.mark.asyncio
async def test_context_manager(self):
client = SageHttpClient(
base_url="https://httpbin.org",
api_key="k",
api_secret="s",
max_retries=0,
)
async with client as c:
assert c is client
assert client._closed is True
@pytest.mark.asyncio
async def test_close(self):
client = SageHttpClient(
base_url="https://httpbin.org",
api_key="k",
api_secret="s",
max_retries=0,
)
await client._get_session() # create session
await client.close()
assert client._closed is True
@pytest.mark.asyncio
async def test_request_with_retry_on_502(self):
"""Test that 502 responses trigger retries and raise after exhaustion."""
client = SageHttpClient(
base_url="",
api_key="k",
api_secret="s",
max_retries=2,
backoff_factor=0.01, # fast for tests
auto_sign=False,
)
mock_response = AsyncMock()
mock_response.status = 502
mock_response.request_info = MagicMock()
mock_response.history = []
mock_response.release = MagicMock()
mock_session = AsyncMock()
mock_session.request = AsyncMock(return_value=mock_response)
mock_session.closed = False
client._session = mock_session
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
await client.request("GET", "/test")
assert exc_info.value.status == 502
# Should have been called 3 times (1 initial + 2 retries)
assert mock_session.request.call_count == 3
class TestSageHttpClientIntegration:
"""Integration tests against httpbin.org (requires network)."""
@pytest.mark.asyncio
async def test_get_request(self):
client = SageHttpClient(
base_url="https://httpbin.org",
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.get("/get")
assert resp.status == 200
data = await resp.json()
assert "url" in data
@pytest.mark.asyncio
async def test_post_request_with_json(self):
client = SageHttpClient(
base_url="https://httpbin.org",
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.post("/post", json={"key": "value"})
assert resp.status == 200
data = await resp.json()
assert data["json"] == {"key": "value"}
@pytest.mark.asyncio
async def test_headers_sent(self):
client = SageHttpClient(
base_url="https://httpbin.org",
headers={"X-Custom-Header": "test-value"},
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.get("/headers")
assert resp.status == 200
data = await resp.json()
assert data["headers"].get("X-Custom-Header") == "test-value"
@pytest.mark.asyncio
async def test_query_params(self):
client = SageHttpClient(
base_url="https://httpbin.org",
auto_sign=False,
max_retries=0,
timeout=10.0,
)
async with client:
resp = await client.get("/get", params={"foo": "bar", "baz": "qux"})
assert resp.status == 200
data = await resp.json()
assert data["args"]["foo"] == "bar"
assert data["args"]["baz"] == "qux"