diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7a60b85 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/sageapi/__init__.py b/sageapi/__init__.py index e69de29..57872ca 100644 --- a/sageapi/__init__.py +++ b/sageapi/__init__.py @@ -0,0 +1,26 @@ +"""SageAPI - Sage data caching and proxy API server. + +Provides cached access to Sage data (users, pricing, uapi, llmage) +with DAPI authentication and independent multi-instance deployment. +""" + +__version__ = '0.1.0' + +# Public API exports +from sageapi.sync import run_all_syncs, UserSync, PricingSync, UapiSync, LlmageSync +from sageapi.router import Router, setup_routes +from sageapi.cache.cache_manager import CacheManager +from sageapi.middleware.dapi_auth import authenticate_request, DapiAuthMiddleware + +__all__ = [ + 'run_all_syncs', + 'UserSync', + 'PricingSync', + 'UapiSync', + 'LlmageSync', + 'Router', + 'setup_routes', + 'CacheManager', + 'authenticate_request', + 'DapiAuthMiddleware', +] diff --git a/sageapi/api/__pycache__/__init__.cpython-310.pyc b/sageapi/api/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..f3ed3f2 Binary files /dev/null and b/sageapi/api/__pycache__/__init__.cpython-310.pyc differ diff --git a/sageapi/api/__pycache__/accounting.cpython-310.pyc b/sageapi/api/__pycache__/accounting.cpython-310.pyc new file mode 100644 index 0000000..151b90e Binary files /dev/null and b/sageapi/api/__pycache__/accounting.cpython-310.pyc differ diff --git a/sageapi/api/__pycache__/balance.cpython-310.pyc b/sageapi/api/__pycache__/balance.cpython-310.pyc new file mode 100644 index 0000000..efffddb Binary files /dev/null and b/sageapi/api/__pycache__/balance.cpython-310.pyc differ diff --git a/sageapi/api/__pycache__/health.cpython-310.pyc b/sageapi/api/__pycache__/health.cpython-310.pyc new file mode 100644 index 0000000..0cbd039 Binary files /dev/null and b/sageapi/api/__pycache__/health.cpython-310.pyc differ diff --git a/sageapi/api/__pycache__/pricing.cpython-310.pyc b/sageapi/api/__pycache__/pricing.cpython-310.pyc new file mode 100644 index 0000000..30e9a9c Binary files /dev/null and b/sageapi/api/__pycache__/pricing.cpython-310.pyc differ diff --git a/sageapi/api/__pycache__/users.cpython-310.pyc b/sageapi/api/__pycache__/users.cpython-310.pyc new file mode 100644 index 0000000..331b3a8 Binary files /dev/null and b/sageapi/api/__pycache__/users.cpython-310.pyc differ diff --git a/sageapi/api/accounting.py b/sageapi/api/accounting.py index 537149a..d61279d 100644 --- a/sageapi/api/accounting.py +++ b/sageapi/api/accounting.py @@ -2,7 +2,7 @@ Provides endpoints for creating and querying accounting records. Writing goes directly to the accounting_records table; reads -are served from the same table with optional date range filtering. +are served from the same table with optional filtering. """ from __future__ import annotations @@ -14,64 +14,86 @@ from typing import Any from appPublic.log import debug, error from sqlor.dbpools import DBPools +from ahserver.serverenv import ServerEnv async def create_accounting_record( customer_id: str, amount: float, - record_type: str = 'charge', - description: str = '', - **extra: Any, + llmid: str = '', + model_name: str = '', + pricing_id: str = '', + input_tokens: int | None = None, + output_tokens: int | None = None, + total_tokens: int | None = None, + quantity: float | None = None, + currency: str = 'CNY', + request_id: str = '', + transno: str = '', ) -> str: - """Create a new accounting record. - - Args: - customer_id: The customer identifier. - amount: The accounting amount (positive for charges, negative for credits). - record_type: Type of record (charge, credit, adjustment, etc.). - description: Optional description of the transaction. - - Returns: - JSON string with success flag and the created record ID. - """ + """Create a new accounting record with idempotency via request_id.""" result: dict[str, Any] = {'success': False, 'record_id': None} try: - from ahserver.serverenv import ServerEnv env = ServerEnv() dbname = env.get_module_dbname('sageapi') - if not dbname: result['error'] = 'No database configured for sageapi module' return json.dumps(result, ensure_ascii=False, default=str) - record_id = str(uuid.uuid4()) + record_id = request_id or str(uuid.uuid4()) now = time.strftime('%Y-%m-%d %H:%M:%S') + # Check idempotency + if request_id: + async with DBPools().sqlorContext(dbname) as sor: + existing = await sor.sqlExe( + "SELECT id FROM accounting_records WHERE request_id = ${request_id}$", + {'request_id': request_id}, + ) + if isinstance(existing, list) and existing: + result['success'] = True + result['record_id'] = existing[0].get('id', existing[0].get('id') if isinstance(existing[0], dict) else existing[0]) + result['duplicate'] = True + return json.dumps(result, ensure_ascii=False, default=str) + sql = """ INSERT INTO accounting_records - (id, customer_id, amount, record_type, description, created_at, extra) + (id, customer_id, llmid, model_name, pricing_id, + input_tokens, output_tokens, total_tokens, quantity, + amount, currency, request_id, transno, status, + created_at, updated_at) VALUES - (${id}$, ${customer_id}$, ${amount}$, ${record_type}$, ${description}$, ${created_at}$, ${extra}$) + (${id}$, ${customer_id}$, ${llmid}$, ${model_name}$, ${pricing_id}$, + ${input_tokens}$, ${output_tokens}$, ${total_tokens}$, ${quantity}$, + ${amount}$, ${currency}$, ${request_id}$, ${transno}$, 'accounted', + ${created_at}$, ${updated_at}$) """ + params = { + 'id': record_id, + 'customer_id': customer_id, + 'llmid': llmid, + 'model_name': model_name, + 'pricing_id': pricing_id, + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'total_tokens': total_tokens, + 'quantity': quantity, + 'amount': amount, + 'currency': currency, + 'request_id': request_id, + 'transno': transno, + 'created_at': now, + 'updated_at': now, + } async with DBPools().sqlorContext(dbname) as sor: - await sor.sqlExe(sql, { - 'id': record_id, - 'customer_id': customer_id, - 'amount': amount, - 'record_type': record_type, - 'description': description, - 'created_at': now, - 'extra': json.dumps(extra, ensure_ascii=False) if extra else None, - }) - - result['success'] = True - result['record_id'] = record_id - debug(f'Accounting record created: id={record_id}, customer={customer_id}, amount={amount}') + await sor.sqlExe(sql, params) + result['success'] = True + result['record_id'] = record_id except Exception as e: - error(f'Accounting record creation failed: {e}') + error(f'create_accounting_record error: {e}') result['error'] = str(e) return json.dumps(result, ensure_ascii=False, default=str) @@ -79,30 +101,19 @@ async def create_accounting_record( async def query_accounting_records( customer_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - limit: int = 100, - offset: int = 0, + date_from: str | None = None, + date_to: str | None = None, + llmid: str | None = None, + status: str | None = None, + page: int = 1, + page_size: int = 50, ) -> str: - """Query accounting records with optional filters. - - Args: - customer_id: Filter by customer ID. - start_date: Filter records from this date (inclusive, YYYY-MM-DD). - end_date: Filter records up to this date (inclusive, YYYY-MM-DD). - limit: Maximum number of records to return. - offset: Number of records to skip. - - Returns: - JSON string with success flag and record data. - """ + """Query accounting records with filters and pagination.""" result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} try: - from ahserver.serverenv import ServerEnv env = ServerEnv() dbname = env.get_module_dbname('sageapi') - if not dbname: result['error'] = 'No database configured for sageapi module' return json.dumps(result, ensure_ascii=False, default=str) @@ -113,41 +124,56 @@ async def query_accounting_records( if customer_id: conditions.append('customer_id = ${customer_id}$') params['customer_id'] = customer_id - if start_date: - conditions.append('created_at >= ${start_date}$') - params['start_date'] = start_date - if end_date: - conditions.append('created_at <= ${end_date}$') - params['end_date'] = end_date + if date_from: + conditions.append('created_at >= ${date_from}$') + params['date_from'] = date_from + if date_to: + conditions.append('created_at <= ${date_to}$') + params['date_to'] = date_to + if llmid: + conditions.append('llmid = ${llmid}$') + params['llmid'] = llmid + if status: + conditions.append('status = ${status}$') + params['status'] = status - where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else '' + where = 'WHERE ' + ' AND '.join(conditions) if conditions else '' # Count query - count_sql = f""" - SELECT COUNT(*) as cnt FROM accounting_records {where_clause} - """ - async with DBPools().sqlorContext(dbname) as sor: - count_rows = await sor.sqlExe(count_sql, params) - total = count_rows[0]['cnt'] if count_rows else 0 - result['total'] = total + count_sql = f"SELECT COUNT(*) as cnt FROM accounting_records {where}" + offset = (page - 1) * page_size - if total > 0: - data_sql = f""" - SELECT id, customer_id, amount, record_type, description, created_at, extra - FROM accounting_records - {where_clause} - ORDER BY created_at DESC - LIMIT ${limit}$ OFFSET ${offset}$ - """ - params['limit'] = limit - params['offset'] = offset - rows = await sor.sqlExe(data_sql, params) - result['data'] = [dict(r) for r in (rows or [])] + # Data query + data_sql = f""" + SELECT id, customer_id, llmid, model_name, pricing_id, + input_tokens, output_tokens, total_tokens, quantity, + amount, currency, request_id, transno, status, + created_at, updated_at + FROM accounting_records + {where} + ORDER BY created_at DESC + LIMIT {page_size} OFFSET {offset} + """ + + async with DBPools().sqlorContext(dbname) as sor: + count_result = await sor.sqlExe(count_sql, params) + if isinstance(count_result, list) and count_result: + result['total'] = count_result[0].get('cnt', 0) + elif isinstance(count_result, dict): + result['total'] = count_result.get('cnt', 0) + + data = await sor.sqlExe(data_sql, params) + if isinstance(data, dict): + result['data'] = data.get('rows', []) + elif isinstance(data, list): + result['data'] = data result['success'] = True + result['page'] = page + result['page_size'] = page_size except Exception as e: - error(f'Accounting query failed: {e}') + error(f'query_accounting_records error: {e}') result['error'] = str(e) return json.dumps(result, ensure_ascii=False, default=str) diff --git a/sageapi/api/balance.py b/sageapi/api/balance.py index d541544..ee0c547 100644 --- a/sageapi/api/balance.py +++ b/sageapi/api/balance.py @@ -1,7 +1,8 @@ """Customer balance query API handler. Provides the RESTful endpoint for querying customer account balances. -Reads from the local customer_balance cache table. +Reads from the local customer_balance cache table, with fallback to +real-time query from Sage acc_balance table. """ from __future__ import annotations @@ -10,15 +11,18 @@ import json from typing import Any from appPublic.log import debug, error -from sqlor.dbpools import DBPools +from sqlor.dbpools import DBPools, get_sor_context +from ahserver.serverenv import ServerEnv async def get_customer_balance(customer_id: str | None = None) -> str: """Query customer balance. + First checks the local customer_balance cache. If not found, + falls back to real-time query from Sage acc_balance table. + Args: - customer_id: Optional customer ID filter. If not provided, - returns all customer balances. + customer_id: Optional customer ID filter. Returns: JSON string with success flag and balance data. @@ -26,42 +30,92 @@ async def get_customer_balance(customer_id: str | None = None) -> str: result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} try: - from ahserver.serverenv import ServerEnv env = ServerEnv() - dbname = env.get_module_dbname('sageapi') - - if not dbname: + cache_dbname = env.get_module_dbname('sageapi') + if not cache_dbname: result['error'] = 'No database configured for sageapi module' return json.dumps(result, ensure_ascii=False, default=str) - params: dict[str, Any] = {} - where_clause = '' - if customer_id: - where_clause = 'WHERE customer_id = ${customer_id}$' - params['customer_id'] = customer_id + async with DBPools().sqlorContext(cache_dbname) as sor: + params: dict[str, Any] = {} + where = '' + if customer_id: + where = 'WHERE id = ${customer_id}$' + params['customer_id'] = customer_id - sql = f""" - SELECT customer_id, balance, currency, updated_at - FROM customer_balance - {where_clause} - ORDER BY customer_id - """ - - async with DBPools().sqlorContext(dbname) as sor: + sql = f""" + SELECT id, balance, currency, credit_limit, + last_recharge, last_consumption, + status, cached_at + FROM customer_balance + {where} + ORDER BY id + """ data = await sor.sqlExe(sql, params) if isinstance(data, dict): - result['total'] = data.get('total', 0) - result['data'] = [dict(r) for r in data.get('rows', [])] - else: - rows = [dict(r) for r in (data or [])] - result['data'] = rows - result['total'] = len(rows) + result['total'] = data.get('total', len(data.get('rows', []))) + result['data'] = data.get('rows', []) + elif isinstance(data, list): + result['total'] = len(data) + result['data'] = data + + # If cache miss for specific customer, try real-time from Sage + if customer_id and not result['data']: + result['data'] = await _query_sage_balance(env, customer_id) + result['total'] = len(result['data']) result['success'] = True - debug(f'Balance query: returned {result["total"]} records') except Exception as e: - error(f'Balance query failed: {e}') + error(f'get_customer_balance error: {e}') + result['error'] = str(e) + + return json.dumps(result, ensure_ascii=False, default=str) + + +async def _query_sage_balance(env: ServerEnv, customer_id: str) -> list[dict]: + """Fallback: query Sage acc_balance table directly.""" + try: + async with get_sor_context(env, 'sage') as sor: + sql = """ + SELECT customer_id, balance, currency, status, updated_at + FROM acc_balance + WHERE customer_id = ${customer_id}$ + """ + rows = await sor.sqlExe(sql, {'customer_id': customer_id}) + if isinstance(rows, list): + return rows + elif isinstance(rows, dict): + return rows.get('rows', []) + except Exception as e: + error(f'_query_sage_balance error: {e}') + return [] + + +async def update_customer_balance(customer_id: str, balance: float) -> str: + """Update customer balance in cache (called by sync or accounting).""" + result: dict[str, Any] = {'success': False} + + try: + env = ServerEnv() + cache_dbname = env.get_module_dbname('sageapi') + + sql = """ + INSERT INTO customer_balance (id, balance, cached_at) + VALUES (${customer_id}$, ${balance}$, NOW()) + ON DUPLICATE KEY UPDATE + balance = ${balance}$, + cached_at = NOW() + """ + async with DBPools().sqlorContext(cache_dbname) as sor: + await sor.sqlExe(sql, { + 'customer_id': customer_id, + 'balance': balance, + }) + result['success'] = True + + except Exception as e: + error(f'update_customer_balance error: {e}') result['error'] = str(e) return json.dumps(result, ensure_ascii=False, default=str) diff --git a/sageapi/api/health.py b/sageapi/api/health.py index 1fed77b..dd11fdd 100644 --- a/sageapi/api/health.py +++ b/sageapi/api/health.py @@ -1,7 +1,6 @@ """Health check API handler. -Provides a simple endpoint for load balancer health checks and -system status monitoring. No authentication required. +Provides endpoints for service health and readiness checks. """ from __future__ import annotations @@ -10,51 +9,101 @@ import json import time from typing import Any -from appPublic.log import debug +from appPublic.log import debug, error from sqlor.dbpools import DBPools +from ahserver.serverenv import ServerEnv + +_START_TIME = time.time() async def health_check() -> str: - """Health check endpoint. + """Basic health check - returns service status.""" + uptime = time.time() - _START_TIME - Returns system status including database connectivity, - cache stats, and uptime information. - - Returns: - JSON string with health status. - """ - result: dict[str, Any] = { + result = { 'status': 'ok', - 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), - 'database': 'unknown', - 'cache': {}, + 'service': 'sageapi', + 'uptime_seconds': round(uptime, 1), + 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S%z'), + } + return json.dumps(result, ensure_ascii=False, default=str) + + +async def readiness_check() -> str: + """Readiness check - verifies database connectivity.""" + result: dict[str, Any] = { + 'status': 'unknown', + 'checks': {}, } - # Check database connectivity + # Check cache database connection try: - from ahserver.serverenv import ServerEnv env = ServerEnv() dbname = env.get_module_dbname('sageapi') - - if dbname: - async with DBPools().sqlorContext(dbname) as sor: - await sor.sqlExe('SELECT 1') - result['database'] = 'connected' + if not dbname: + result['checks']['cache_db'] = { + 'status': 'fail', + 'error': 'No database configured for sageapi module', + } else: - result['database'] = 'not_configured' - result['status'] = 'degraded' - + async with DBPools().sqlorContext(dbname) as sor: + rows = await sor.sqlExe('SELECT 1 as ping') + result['checks']['cache_db'] = { + 'status': 'ok', + 'dbname': dbname, + } except Exception as e: - result['database'] = f'error: {str(e)}' - result['status'] = 'unhealthy' + error(f'readiness_check cache_db error: {e}') + result['checks']['cache_db'] = { + 'status': 'fail', + 'error': str(e), + } - # Cache stats + # Check Sage database connection try: - from ..cache.cache_manager import _get_cache_manager - cm = _get_cache_manager() - result['cache'] = cm.stats() - except Exception: - result['cache'] = {'error': 'cache not initialized'} + from sqlor.dbpools import get_sor_context + async with get_sor_context(env, 'sage') as sor: + rows = await sor.sqlExe('SELECT 1 as ping') + result['checks']['sage_db'] = {'status': 'ok'} + except Exception as e: + error(f'readiness_check sage_db error: {e}') + result['checks']['sage_db'] = { + 'status': 'fail', + 'error': str(e), + } + + # Check sync state + try: + async with DBPools().sqlorContext(dbname) as sor: + sql = """ + SELECT entity_type, sync_status, last_sync_time + FROM sync_state + ORDER BY last_sync_time DESC + """ + rows = await sor.sqlExe(sql) + if isinstance(rows, list): + result['checks']['sync_status'] = { + 'status': 'ok', + 'entities': [ + { + 'entity_type': r.get('entity_type', ''), + 'sync_status': r.get('sync_status', ''), + 'last_sync_time': str(r.get('last_sync_time', '')), + } + for r in rows + ], + } + except Exception as e: + result['checks']['sync_status'] = { + 'status': 'fail', + 'error': str(e), + } + + # Overall status + all_ok = all( + check.get('status') == 'ok' + for check in result['checks'].values() + ) + result['status'] = 'ready' if all_ok else 'degraded' - debug(f'Health check: status={result["status"]}') return json.dumps(result, ensure_ascii=False, default=str) diff --git a/sageapi/api/pricing.py b/sageapi/api/pricing.py index 7af1a9d..f4ff158 100644 --- a/sageapi/api/pricing.py +++ b/sageapi/api/pricing.py @@ -1,7 +1,6 @@ -"""Pricing query API handler. +"""Pricing API handler. -Provides the RESTful endpoint for querying pricing information. -Reads from the local pricing_cache table synced from Sage. +Provides endpoint for querying cached pricing data. """ from __future__ import annotations @@ -11,72 +10,110 @@ from typing import Any from appPublic.log import debug, error from sqlor.dbpools import DBPools +from ahserver.serverenv import ServerEnv async def query_pricing( - program_id: str | None = None, - model: str | None = None, - limit: int = 200, - offset: int = 0, + ppid: str | None = None, + llmid: str | None = None, + pricing_type: str | None = None, + status: str | None = None, + page: int = 1, + page_size: int = 50, ) -> str: - """Query pricing information from the local cache. - - Args: - program_id: Filter by pricing program ID. - model: Filter by model name (partial match). - limit: Maximum number of records to return. - offset: Number of records to skip. - - Returns: - JSON string with success flag and pricing data. - """ + """Query pricing from cache with filters.""" result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} try: - from ahserver.serverenv import ServerEnv env = ServerEnv() dbname = env.get_module_dbname('sageapi') - if not dbname: result['error'] = 'No database configured for sageapi module' return json.dumps(result, ensure_ascii=False, default=str) conditions = [] - params: dict[str, Any] = {'limit': limit, 'offset': offset} + params: dict[str, Any] = {} - if program_id: - conditions.append('program_id = ${program_id}$') - params['program_id'] = program_id - if model: - conditions.append('model LIKE ${model}$') - params['model'] = f'%{model}%' + if ppid: + conditions.append('id = ${ppid}$') + params['ppid'] = ppid + if llmid: + conditions.append('llmid = ${llmid}$') + params['llmid'] = llmid + if pricing_type: + conditions.append('pricing_type = ${pricing_type}$') + params['pricing_type'] = pricing_type + if status: + conditions.append('status = ${status}$') + params['status'] = status + else: + conditions.append("status = 'active'") - where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else '' + where = 'WHERE ' + ' AND '.join(conditions) if conditions else '' + + count_sql = f"SELECT COUNT(*) as cnt FROM pricing_cache {where}" + offset = (page - 1) * page_size + data_sql = f""" + SELECT id, llmid, model_name, pricing_type, + input_price, output_price, unit_price, + currency, status, effective_from, effective_to, + cached_at + FROM pricing_cache + {where} + ORDER BY model_name + LIMIT {page_size} OFFSET {offset} + """ - # Count query - count_sql = f'SELECT COUNT(*) as cnt FROM pricing_cache {where_clause}' async with DBPools().sqlorContext(dbname) as sor: - count_rows = await sor.sqlExe(count_sql, params) - total = count_rows[0]['cnt'] if count_rows else 0 - result['total'] = total + count_result = await sor.sqlExe(count_sql, params) + if isinstance(count_result, list) and count_result: + result['total'] = count_result[0].get('cnt', 0) + elif isinstance(count_result, dict): + result['total'] = count_result.get('cnt', 0) - if total > 0: - data_sql = f""" - SELECT program_id, model, input_price, output_price, - unit, currency, updated_at - FROM pricing_cache - {where_clause} - ORDER BY program_id, model - LIMIT ${limit}$ OFFSET ${offset}$ - """ - rows = await sor.sqlExe(data_sql, params) - result['data'] = [dict(r) for r in (rows or [])] + data = await sor.sqlExe(data_sql, params) + if isinstance(data, dict): + result['data'] = data.get('rows', []) + elif isinstance(data, list): + result['data'] = data result['success'] = True - debug(f'Pricing query: returned {result["total"]} records') except Exception as e: - error(f'Pricing query failed: {e}') + error(f'query_pricing error: {e}') + result['error'] = str(e) + + return json.dumps(result, ensure_ascii=False, default=str) + + +async def get_pricing_by_llmid(llmid: str) -> str: + """Get all active pricing entries for a specific model.""" + result: dict[str, Any] = {'success': False, 'data': []} + + try: + env = ServerEnv() + dbname = env.get_module_dbname('sageapi') + + sql = """ + SELECT id, llmid, model_name, pricing_type, + input_price, output_price, unit_price, + currency, status, cached_at + FROM pricing_cache + WHERE llmid = ${llmid}$ AND status = 'active' + ORDER BY pricing_type + """ + + async with DBPools().sqlorContext(dbname) as sor: + data = await sor.sqlExe(sql, {'llmid': llmid}) + if isinstance(data, list): + result['data'] = data + result['success'] = True + elif isinstance(data, dict): + result['data'] = data.get('rows', []) + result['success'] = True + + except Exception as e: + error(f'get_pricing_by_llmid error: {e}') result['error'] = str(e) return json.dumps(result, ensure_ascii=False, default=str) diff --git a/sageapi/api/users.py b/sageapi/api/users.py index d0bd3b2..219d334 100644 --- a/sageapi/api/users.py +++ b/sageapi/api/users.py @@ -1,7 +1,6 @@ -"""User query API handler. +"""Users API handler. -Provides the RESTful endpoint for querying user information. -Reads from the local users_cache table synced from Sage. +Provides endpoint for querying cached user data. """ from __future__ import annotations @@ -11,73 +10,91 @@ from typing import Any from appPublic.log import debug, error from sqlor.dbpools import DBPools +from ahserver.serverenv import ServerEnv -async def query_users( - user_id: str | None = None, - keyword: str | None = None, - limit: int = 100, - offset: int = 0, -) -> str: - """Query user information from the local cache. - - Args: - user_id: Filter by specific user ID. - keyword: Search keyword (matches username, email, or phone). - limit: Maximum number of records to return. - offset: Number of records to skip. - - Returns: - JSON string with success flag and user data. - """ +async def query_users(keyword: str | None = None, orgid: str | None = None, page: int = 1, page_size: int = 50) -> str: + """Query users from cache with keyword search.""" result: dict[str, Any] = {'success': False, 'data': [], 'total': 0} try: - from ahserver.serverenv import ServerEnv env = ServerEnv() dbname = env.get_module_dbname('sageapi') - if not dbname: result['error'] = 'No database configured for sageapi module' return json.dumps(result, ensure_ascii=False, default=str) conditions = [] - params: dict[str, Any] = {'limit': limit, 'offset': offset} + params: dict[str, Any] = {} - if user_id: - conditions.append('user_id = ${user_id}$') - params['user_id'] = user_id if keyword: - conditions.append( - '(username LIKE ${keyword}$ OR email LIKE ${keyword}$ OR phone LIKE ${keyword}$)' - ) + conditions.append("username LIKE ${keyword}$") params['keyword'] = f'%{keyword}%' + if orgid: + conditions.append('orgid = ${orgid}$') + params['orgid'] = orgid - where_clause = 'WHERE ' + ' AND '.join(conditions) if conditions else '' + where = 'WHERE ' + ' AND '.join(conditions) if conditions else '' + + count_sql = f"SELECT COUNT(*) as cnt FROM users_cache {where}" + offset = (page - 1) * page_size + data_sql = f""" + SELECT id, username, orgid, orgname, email, phone, + status, created_at, updated_at, cached_at + FROM users_cache + {where} + ORDER BY username + LIMIT {page_size} OFFSET {offset} + """ - # Count query - count_sql = f'SELECT COUNT(*) as cnt FROM users_cache {where_clause}' async with DBPools().sqlorContext(dbname) as sor: - count_rows = await sor.sqlExe(count_sql, params) - total = count_rows[0]['cnt'] if count_rows else 0 - result['total'] = total + count_result = await sor.sqlExe(count_sql, params) + if isinstance(count_result, list) and count_result: + result['total'] = count_result[0].get('cnt', 0) + elif isinstance(count_result, dict): + result['total'] = count_result.get('cnt', 0) - if total > 0: - data_sql = f""" - SELECT user_id, username, email, phone, status, updated_at - FROM users_cache - {where_clause} - ORDER BY user_id - LIMIT ${limit}$ OFFSET ${offset}$ - """ - rows = await sor.sqlExe(data_sql, params) - result['data'] = [dict(r) for r in (rows or [])] + data = await sor.sqlExe(data_sql, params) + if isinstance(data, dict): + result['data'] = data.get('rows', []) + elif isinstance(data, list): + result['data'] = data result['success'] = True - debug(f'User query: returned {result["total"]} records') except Exception as e: - error(f'User query failed: {e}') + error(f'query_users error: {e}') + result['error'] = str(e) + + return json.dumps(result, ensure_ascii=False, default=str) + + +async def get_user_by_id(user_id: str) -> str: + """Get a single user by ID.""" + result: dict[str, Any] = {'success': False, 'data': None} + + try: + env = ServerEnv() + dbname = env.get_module_dbname('sageapi') + + sql = """ + SELECT id, username, orgid, orgname, email, phone, + status, created_at, updated_at, cached_at + FROM users_cache + WHERE id = ${user_id}$ + """ + + async with DBPools().sqlorContext(dbname) as sor: + data = await sor.sqlExe(sql, {'user_id': user_id}) + if isinstance(data, list) and data: + result['data'] = data[0] + result['success'] = True + elif isinstance(data, dict) and data.get('rows'): + result['data'] = data['rows'][0] + result['success'] = True + + except Exception as e: + error(f'get_user_by_id error: {e}') result['error'] = str(e) return json.dumps(result, ensure_ascii=False, default=str) diff --git a/sageapi/cache/__pycache__/__init__.cpython-310.pyc b/sageapi/cache/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..4fe27af Binary files /dev/null and b/sageapi/cache/__pycache__/__init__.cpython-310.pyc differ diff --git a/sageapi/cache/__pycache__/cache_manager.cpython-310.pyc b/sageapi/cache/__pycache__/cache_manager.cpython-310.pyc new file mode 100644 index 0000000..6a85f8c Binary files /dev/null and b/sageapi/cache/__pycache__/cache_manager.cpython-310.pyc differ diff --git a/sageapi/init.py b/sageapi/init.py index 0c02353..5994414 100644 --- a/sageapi/init.py +++ b/sageapi/init.py @@ -2,40 +2,69 @@ Registers all public functions to ServerEnv so they are accessible from dspy scripts and other modules via the global environment. +Also sets up route registration and database event bindings. """ -from appPublic.log import debug +from appPublic.log import debug, info from sqlor.dbpools import DBPools from ahserver.serverenv import ServerEnv # --------------------------------------------------------------------------- # Auth # --------------------------------------------------------------------------- -from .auth.dapi_auth import dapi_auth_middleware -from .auth.uapi_sign import uapi_sign_verify +from .middleware.dapi_auth import authenticate_request, DapiAuthMiddleware # --------------------------------------------------------------------------- # Sync # --------------------------------------------------------------------------- -from .sync.base_sync import BaseSync -from .sync.user_sync import sync_users -from .sync.pricing_sync import sync_pricing -from .sync.uapi_sync import sync_uapi -from .sync.llmage_sync import sync_llmage +from .sync.base_sync import BaseSync, run_all_syncs +from .sync.user_sync import UserSync +from .sync.pricing_sync import PricingSync +from .sync.uapi_sync import UapiSync +from .sync.llmage_sync import LlmageSync + +# Module-level convenience functions for sync +async def sync_users() -> str: + """Convenience: run user sync.""" + syncer = UserSync() + return await syncer.sync() + +async def sync_pricing() -> str: + """Convenience: run pricing sync.""" + syncer = PricingSync() + return await syncer.sync() + +async def sync_uapi() -> str: + """Convenience: run uapi sync.""" + syncer = UapiSync() + return await syncer.sync() + +async def sync_llmage() -> str: + """Convenience: run llmage sync.""" + syncer = LlmageSync() + return await syncer.sync() # --------------------------------------------------------------------------- # Cache # --------------------------------------------------------------------------- from .cache.cache_manager import CacheManager +# Global cache instance (per-process) +_cache_manager = CacheManager(max_entries=10000, default_ttl=300) + # --------------------------------------------------------------------------- # API # --------------------------------------------------------------------------- -from .api.balance import get_customer_balance +from .api.balance import get_customer_balance, update_customer_balance from .api.accounting import create_accounting_record, query_accounting_records -from .api.users import query_users -from .api.pricing import query_pricing -from .api.health import health_check +from .api.users import query_users, get_user_by_id +from .api.pricing import query_pricing, get_pricing_by_llmid +from .api.health import health_check, readiness_check + +# --------------------------------------------------------------------------- +# Router +# --------------------------------------------------------------------------- +from .router import Router, setup_routes # --------------------------------------------------------------------------- # Utils @@ -45,66 +74,79 @@ from .utils.crypto import encrypt_payload, decrypt_payload def _bind_sageapi_events(dbpools: DBPools, dbname: str) -> None: - """Bind database events to SageAPI cache invalidation handlers. - - When sync state or accounting records change in the database, - the corresponding cache entries are invalidated automatically. - """ + """Bind database events to SageAPI cache invalidation handlers.""" bindings = [ - # sync_state table: clear sync-related caches on change - (f'{dbname}:sync_state:c:after', CacheManager.invalidate_sync_state), - (f'{dbname}:sync_state:u:after', CacheManager.invalidate_sync_state), - (f'{dbname}:sync_state:d:after', CacheManager.invalidate_sync_state), - # accounting_records: clear accounting cache on change - (f'{dbname}:accounting_records:c:after', CacheManager.invalidate_accounting), - (f'{dbname}:accounting_records:u:after', CacheManager.invalidate_accounting), - (f'{dbname}:accounting_records:d:after', CacheManager.invalidate_accounting), + (f'{dbname}:sync_state:c:after', _cache_manager.invalidate_sync_state), + (f'{dbname}:sync_state:u:after', _cache_manager.invalidate_sync_state), + (f'{dbname}:sync_state:d:after', _cache_manager.invalidate_sync_state), + (f'{dbname}:accounting_records:c:after', _cache_manager.invalidate_accounting), + (f'{dbname}:accounting_records:u:after', _cache_manager.invalidate_accounting), + (f'{dbname}:accounting_records:d:after', _cache_manager.invalidate_accounting), ] for event_name, handler in bindings: - dbpools.bind(event_name, handler) - debug(f'SageAPI event bound: {event_name}') + try: + dbpools.bind(event_name, handler) + debug(f'SageAPI event bound: {event_name}') + except Exception as e: + debug(f'SageAPI event bind skipped: {event_name} ({e})') def load_sageapi() -> None: """Register all SageAPI functions into ServerEnv. Called by the Sage server during module loading phase. - All registered functions become available as globals in dspy scripts. """ env = ServerEnv() # Auth - env.dapi_auth_middleware = dapi_auth_middleware - env.uapi_sign_verify = uapi_sign_verify + env.authenticate_request = authenticate_request + env.DapiAuthMiddleware = DapiAuthMiddleware # Sync env.sync_users = sync_users env.sync_pricing = sync_pricing env.sync_uapi = sync_uapi env.sync_llmage = sync_llmage + env.run_all_syncs = run_all_syncs env.BaseSync = BaseSync + env.UserSync = UserSync + env.PricingSync = PricingSync + env.UapiSync = UapiSync + env.LlmageSync = LlmageSync # Cache - env.cache_manager = CacheManager() + env.cache_manager = _cache_manager # API env.get_customer_balance = get_customer_balance + env.update_customer_balance = update_customer_balance env.create_accounting_record = create_accounting_record env.query_accounting_records = query_accounting_records env.query_users = query_users + env.get_user_by_id = get_user_by_id env.query_pricing = query_pricing + env.get_pricing_by_llmid = get_pricing_by_llmid env.health_check = health_check + env.readiness_check = readiness_check + + # Router + router = Router() + setup_routes(router) + env.sageapi_router = router + info(f'SageAPI: {len(router.get_routes())} routes registered') # Utils env.SageHttpClient = SageHttpClient env.encrypt_payload = encrypt_payload env.decrypt_payload = decrypt_payload - # Bind database events for automatic cache invalidation + # Bind database events dbpools = DBPools() dbname = env.get_module_dbname('sageapi') if dbname: _bind_sageapi_events(dbpools, dbname) - debug(f'SageAPI event listeners bound for database: {dbname}') + info(f'SageAPI: event listeners bound for database: {dbname}') else: - debug('SageAPI event listeners skipped: no database configured for sageapi module') + debug('SageAPI: event listeners skipped (no database configured)') + + info('SageAPI module loaded successfully') diff --git a/sageapi/middleware/__init__.py b/sageapi/middleware/__init__.py new file mode 100644 index 0000000..93fa6f7 --- /dev/null +++ b/sageapi/middleware/__init__.py @@ -0,0 +1 @@ +# Middleware package diff --git a/sageapi/middleware/__pycache__/__init__.cpython-310.pyc b/sageapi/middleware/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..4b369c3 Binary files /dev/null and b/sageapi/middleware/__pycache__/__init__.cpython-310.pyc differ diff --git a/sageapi/middleware/__pycache__/dapi_auth.cpython-310.pyc b/sageapi/middleware/__pycache__/dapi_auth.cpython-310.pyc new file mode 100644 index 0000000..853acdc Binary files /dev/null and b/sageapi/middleware/__pycache__/dapi_auth.cpython-310.pyc differ diff --git a/sageapi/middleware/dapi_auth.py b/sageapi/middleware/dapi_auth.py new file mode 100644 index 0000000..83ec06f --- /dev/null +++ b/sageapi/middleware/dapi_auth.py @@ -0,0 +1,254 @@ +""" +DAPI Authentication Middleware for sageapi. + +Authenticates incoming requests using DAPI signature headers by querying +the Sage downapikey table. + +Usage with FastAPI / Starlette: + from sageapi.middleware.dapi_auth import DapiAuthMiddleware + app.add_middleware(DapiAuthMiddleware) + + # Or for specific routes, use the get_dapi_key() dependency. +""" + +import hashlib +import hmac +import time +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Optional + +from starlette.datastructures import Headers +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import JSONResponse + +# Headers expected from the client +HEADER_API_KEY = "X-DAPI-Key" +HEADER_TIMESTAMP = "X-DAPI-Timestamp" +HEADER_SIGNATURE = "X-DAPI-Signature" + +# Default timestamp tolerance: 5 minutes +DEFAULT_TIMESTAMP_TOLERANCE_SEC = 300 + + +@dataclass +class DapiKeyRecord: + """Represents a record from the downapikey table.""" + + id: Any + apikey: str + secret: str + status: str + expire_date: Any # datetime or string + description: str = "" + + @property + def is_active(self) -> bool: + return self.status == "active" + + @property + def is_expired(self) -> bool: + """Check if the key has expired based on expire_date.""" + from datetime import datetime, timezone + + if self.expire_date is None: + return False + # Handle both datetime objects and ISO string + if isinstance(self.expire_date, str): + try: + expire_dt = datetime.fromisoformat(self.expire_date.replace("Z", "+00:00")) + except (ValueError, AttributeError): + return False + else: + expire_dt = self.expire_date + + # Ensure expire_dt is timezone-aware (UTC) + if expire_dt.tzinfo is None: + expire_dt = expire_dt.replace(tzinfo=timezone.utc) + + now = datetime.now(timezone.utc) + return now > expire_dt + + +class DapiAuthError(Exception): + """Raised when DAPI authentication fails.""" + + def __init__(self, status_code: int, detail: str): + self.status_code = status_code + self.detail = detail + super().__init__(detail) + + +def compute_dapi_signature(method: str, path: str, timestamp: str, secret: str, body: Optional[bytes] = None) -> str: + """ + Compute the DAPI HMAC-SHA256 signature. + + The signed string is: "{method}\n{path}\n{timestamp}\n{body_hash}" + where body_hash is SHA-256 hex digest of the request body (or empty string if no body). + """ + if body: + body_hash = hashlib.sha256(body).hexdigest() + else: + body_hash = "" + + string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}" + + signature = hmac.new( + secret.encode("utf-8"), + string_to_sign.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + + return signature + + +def verify_timestamp(timestamp_str: str, tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC) -> bool: + """Verify that the timestamp is within the allowed window.""" + try: + ts = float(timestamp_str) + except (ValueError, TypeError): + return False + + now = time.time() + return abs(now - ts) <= tolerance_sec + + +async def lookup_api_key(api_key: str) -> Optional[DapiKeyRecord]: + """ + Look up an API key in the Sage downapikey table. + + Returns a DapiKeyRecord if found, None otherwise. + """ + from ahserver.serverenv import ServerEnv + from sqlor.dbpools import get_sor_context + + env = ServerEnv() + async with get_sor_context(env, "dapi") as sor: + recs = await sor.R("downapikey", {"apikey": api_key}) + + if not recs: + return None + + rec = recs[0] # Take the first match + return DapiKeyRecord( + id=rec.get("id"), + apikey=rec.get("apikey", ""), + secret=rec.get("secret", ""), + status=rec.get("status", "inactive"), + expire_date=rec.get("expire_date"), + description=rec.get("description", ""), + ) + + +async def authenticate_request( + request: Request, + tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC, +) -> DapiKeyRecord: + """ + Authenticate a request using DAPI headers. + + Returns the DapiKeyRecord on success. + Raises DapiAuthError on failure. + """ + headers = request.headers + + # 1. Check required headers + api_key = headers.get(HEADER_API_KEY) + if not api_key: + raise DapiAuthError(401, f"Missing header: {HEADER_API_KEY}") + + timestamp_str = headers.get(HEADER_TIMESTAMP) + if not timestamp_str: + raise DapiAuthError(401, f"Missing header: {HEADER_TIMESTAMP}") + + signature = headers.get(HEADER_SIGNATURE) + if not signature: + raise DapiAuthError(401, f"Missing header: {HEADER_SIGNATURE}") + + # 2. Validate timestamp window + if not verify_timestamp(timestamp_str, tolerance_sec): + raise DapiAuthError(401, "Request timestamp is outside the allowed window") + + # 3. Look up the API key in the database + key_record = await lookup_api_key(api_key) + if key_record is None: + raise DapiAuthError(401, "Invalid API key") + + # 4. Check key status + if not key_record.is_active: + raise DapiAuthError(403, "API key is inactive") + + # 5. Check expiration + if key_record.is_expired: + raise DapiAuthError(403, "API key has expired") + + # 6. Read request body for signature verification + body = await request.body() + + # 7. Compute expected signature and compare + expected_signature = compute_dapi_signature( + method=request.method, + path=request.url.path, + timestamp=timestamp_str, + secret=key_record.secret, + body=body if body else None, + ) + + if not hmac.compare_digest(signature, expected_signature): + raise DapiAuthError(401, "Invalid signature") + + return key_record + + +class DapiAuthMiddleware(BaseHTTPMiddleware): + """ + Starlette/FastAPI middleware that enforces DAPI authentication on all requests. + + Attributes: + exclude_paths: List of path prefixes to skip authentication (e.g., ['/health', '/docs']). + tolerance_sec: Timestamp tolerance in seconds (default: 300). + """ + + def __init__( + self, + app: Any, + exclude_paths: Optional[list[str]] = None, + tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC, + ): + super().__init__(app) + self.exclude_paths = exclude_paths or [] + self.tolerance_sec = tolerance_sec + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> JSONResponse: + # Skip excluded paths + for prefix in self.exclude_paths: + if request.url.path.startswith(prefix): + return await call_next(request) + + try: + key_record = await authenticate_request(request, self.tolerance_sec) + # Attach the key record to request state for downstream use + request.state.dapi_key = key_record + request.state.dapi_key_id = key_record.id + except DapiAuthError as e: + return JSONResponse( + status_code=e.status_code, + content={"error": e.detail}, + ) + + return await call_next(request) + + +def requires_dapi_auth( + request: Request, + tolerance_sec: int = DEFAULT_TIMESTAMP_TOLERANCE_SEC, +) -> Awaitable[DapiKeyRecord]: + """ + Dependency function for FastAPI route-level DAPI authentication. + + Usage: + @app.get("/protected") + async def protected_route(key: DapiKeyRecord = Depends(requires_dapi_auth)): + ... + """ + return authenticate_request(request, tolerance_sec) diff --git a/sageapi/router.py b/sageapi/router.py index b926481..b13b301 100644 --- a/sageapi/router.py +++ b/sageapi/router.py @@ -29,10 +29,10 @@ class Router: Args: method: HTTP method (GET, POST, PUT, DELETE). - path: URL path pattern, e.g. '/api/v1/balance'. + path: URL path pattern. handler: Callable that handles the request. auth: Authentication method ('dapi', 'uapi', 'none'). - description: Human-readable description of the endpoint. + description: Human-readable description. """ self._routes.append({ 'method': method.upper(), @@ -55,64 +55,75 @@ class Router: return None -# Global router instance -router = Router() +def setup_routes(router: Router) -> None: + """Register all SageAPI routes. + Health endpoints (no auth): + GET /api/v1/health + GET /api/v1/health/ready -def register_routes() -> None: - """Register all SageAPI API routes. + Balance endpoints (dapi auth): + GET /api/v1/balance + POST /api/v1/balance/update - Called during module initialization to populate the router - with all available endpoints. + Accounting endpoints (dapi auth): + POST /api/v1/accounting + GET /api/v1/accounting + + Users endpoints (dapi auth): + GET /api/v1/users + GET /api/v1/users/{user_id} + + Pricing endpoints (dapi auth): + GET /api/v1/pricing + GET /api/v1/pricing/model/{llmid} """ - from .api.health import health_check - from .api.balance import get_customer_balance - from .api.accounting import create_accounting_record, query_accounting_records - from .api.users import query_users - from .api.pricing import query_pricing + # Health (no auth) + from sageapi.api.health import health_check, readiness_check + router.register('GET', '/api/v1/health', health_check, auth='none', description='Health check') + router.register('GET', '/api/v1/health/ready', readiness_check, auth='none', description='Readiness check') - # Health check (no auth required) - router.register( - 'GET', '/api/v1/health', - handler=health_check, - auth='none', - description='Health check endpoint', - ) - - # Customer balance - router.register( - 'GET', '/api/v1/balance', - handler=get_customer_balance, - auth='dapi', - description='Query customer balance', - ) + # Balance + from sageapi.api.balance import get_customer_balance, update_customer_balance + router.register('GET', '/api/v1/balance', get_customer_balance, auth='dapi', description='Query customer balance') + router.register('POST', '/api/v1/balance/update', update_customer_balance, auth='dapi', description='Update customer balance') # Accounting - router.register( - 'POST', '/api/v1/accounting', - handler=create_accounting_record, - auth='dapi', - description='Create an accounting record', - ) - router.register( - 'GET', '/api/v1/accounting', - handler=query_accounting_records, - auth='dapi', - description='Query accounting records', - ) + from sageapi.api.accounting import create_accounting_record, query_accounting_records + router.register('POST', '/api/v1/accounting', create_accounting_record, auth='dapi', description='Create accounting record') + router.register('GET', '/api/v1/accounting', query_accounting_records, auth='dapi', description='Query accounting records') # Users - router.register( - 'GET', '/api/v1/users', - handler=query_users, - auth='dapi', - description='Query user information', - ) + from sageapi.api.users import query_users, get_user_by_id + router.register('GET', '/api/v1/users', query_users, auth='dapi', description='Query users') + router.register('GET', '/api/v1/users/detail', get_user_by_id, auth='dapi', description='Get user by ID') # Pricing - router.register( - 'GET', '/api/v1/pricing', - handler=query_pricing, - auth='dapi', - description='Query pricing information', - ) + from sageapi.api.pricing import query_pricing, get_pricing_by_llmid + router.register('GET', '/api/v1/pricing', query_pricing, auth='dapi', description='Query pricing') + router.register('GET', '/api/v1/pricing/model', get_pricing_by_llmid, auth='dapi', description='Get pricing by model ID') + + # Admin (dapi auth with admin role) + from sageapi.sync.base_sync import run_all_syncs + router.register('POST', '/api/v1/admin/sync', run_all_syncs, auth='dapi', description='Trigger full sync') + router.register('GET', '/api/v1/admin/sync/status', _sync_status, auth='dapi', description='Sync status') + + +async def _sync_status() -> str: + """Return current sync status for all entities.""" + import json + from sqlor.dbpools import DBPools + from ahserver.serverenv import ServerEnv + + result = {'success': False, 'data': []} + try: + env = ServerEnv() + dbname = env.get_module_dbname('sageapi') + sql = "SELECT entity_type, sync_status, last_sync_time, error_msg FROM sync_state ORDER BY entity_type" + async with DBPools().sqlorContext(dbname) as sor: + rows = await sor.sqlExe(sql) + result['data'] = rows if isinstance(rows, list) else rows.get('rows', []) + result['success'] = True + except Exception as e: + result['error'] = str(e) + return json.dumps(result, ensure_ascii=False, default=str) diff --git a/sageapi/sync/__init__.py b/sageapi/sync/__init__.py index e69de29..b141bd3 100644 --- a/sageapi/sync/__init__.py +++ b/sageapi/sync/__init__.py @@ -0,0 +1,16 @@ +"""SageAPI sync engine package.""" + +from .base_sync import BaseSync, run_all_syncs +from .user_sync import UserSync +from .pricing_sync import PricingSync +from .uapi_sync import UapiSync +from .llmage_sync import LlmageSync + +__all__ = [ + 'BaseSync', + 'run_all_syncs', + 'UserSync', + 'PricingSync', + 'UapiSync', + 'LlmageSync', +] diff --git a/sageapi/sync/__pycache__/__init__.cpython-310.pyc b/sageapi/sync/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..8c70662 Binary files /dev/null and b/sageapi/sync/__pycache__/__init__.cpython-310.pyc differ diff --git a/sageapi/sync/__pycache__/base_sync.cpython-310.pyc b/sageapi/sync/__pycache__/base_sync.cpython-310.pyc new file mode 100644 index 0000000..d723b3c Binary files /dev/null and b/sageapi/sync/__pycache__/base_sync.cpython-310.pyc differ diff --git a/sageapi/sync/__pycache__/llmage_sync.cpython-310.pyc b/sageapi/sync/__pycache__/llmage_sync.cpython-310.pyc new file mode 100644 index 0000000..d0d967a Binary files /dev/null and b/sageapi/sync/__pycache__/llmage_sync.cpython-310.pyc differ diff --git a/sageapi/sync/__pycache__/pricing_sync.cpython-310.pyc b/sageapi/sync/__pycache__/pricing_sync.cpython-310.pyc new file mode 100644 index 0000000..af7c781 Binary files /dev/null and b/sageapi/sync/__pycache__/pricing_sync.cpython-310.pyc differ diff --git a/sageapi/sync/__pycache__/uapi_sync.cpython-310.pyc b/sageapi/sync/__pycache__/uapi_sync.cpython-310.pyc new file mode 100644 index 0000000..3eb0cb1 Binary files /dev/null and b/sageapi/sync/__pycache__/uapi_sync.cpython-310.pyc differ diff --git a/sageapi/sync/__pycache__/user_sync.cpython-310.pyc b/sageapi/sync/__pycache__/user_sync.cpython-310.pyc new file mode 100644 index 0000000..488f303 Binary files /dev/null and b/sageapi/sync/__pycache__/user_sync.cpython-310.pyc differ diff --git a/sageapi/sync/base_sync.py b/sageapi/sync/base_sync.py index 787c0b6..b9d070b 100644 --- a/sageapi/sync/base_sync.py +++ b/sageapi/sync/base_sync.py @@ -1,125 +1,248 @@ -"""Base synchronization class for SageAPI. - -Provides the foundation for all data sync workers. Handles common -concerns: checkpoint management, retry logic, batch processing, -and error reporting. """ +BaseSync - Abstract base class for all Sage data sync modules. -from __future__ import annotations - +Provides: +- Checkpoint management (read/write sync_state table) +- Batch processing with configurable size +- Retry logic with exponential backoff +- Common sync flow orchestration +""" import time +import logging from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict, List, Optional -from appPublic.log import debug, error, info +from sqlor.dbpools import get_sor_context +from ahserver.serverenv import ServerEnv + +logger = logging.getLogger(__name__) class BaseSync(ABC): - """Abstract base class for data synchronization workers. + """ + Abstract base class for incremental sync from Sage DB to local cache. - Each concrete sync subclass implements the data fetch and - persist logic for a specific upstream data source. + Subclasses must implement: + - fetch_incremental(sor, since_timestamp): fetch delta from Sage DB + - persist(sor, records): UPSERT records into local cache table + - get_latest_timestamp(records): extract max updated_at from records """ - def __init__(self, sync_name: str, batch_size: int = 500) -> None: - self.sync_name = sync_name - self.batch_size = batch_size - self._last_checkpoint: dict[str, Any] = {} + # --- subclass overrides --- + MODULE_NAME: str = "" # Sage module name (e.g. 'users', 'pricing') + SOURCE_DBNAME: str = "sage" # db alias for Sage source DB + CACHE_DBNAME: str = "sageapi" # db alias for local cache DB + BATCH_SIZE: int = 500 # batch size for persist + MAX_RETRIES: int = 3 # max retry attempts per batch + RETRY_DELAY: float = 1.0 # initial retry delay in seconds - @abstractmethod - async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]: - """Fetch incremental data from the upstream source. + # sync_state table key + STATE_KEY: str = "" - Args: - since_timestamp: Only fetch records modified after this timestamp. - None means full sync. + def __init__(self, env: Optional[ServerEnv] = None): + self.env = env or ServerEnv() - Returns: - List of records to be persisted. - """ - ... + # ------------------------------------------------------------------ # + # Checkpoint helpers – sync_state table lives in the CACHE DB # + # ------------------------------------------------------------------ # - @abstractmethod - async def persist(self, records: list[dict[str, Any]]) -> int: - """Persist fetched records to the local database. + async def _read_checkpoint(self) -> Optional[str]: + """Read last sync timestamp from sync_state table.""" + async with get_sor_context(self.env, self.CACHE_DBNAME) as sor: + recs = await sor.R('sync_state', {'state_key': self.STATE_KEY}) + if recs and len(recs) > 0: + return recs[0].get('last_sync_ts') + return None - Args: - records: List of records to upsert. + async def _write_checkpoint(self, timestamp: str) -> None: + """Write new sync timestamp into sync_state table.""" + async with get_sor_context(self.env, self.CACHE_DBNAME) as sor: + now_ts = str(int(time.time())) + existing = await sor.R('sync_state', {'state_key': self.STATE_KEY}) + if existing and len(existing) > 0: + await sor.U('sync_state', { + 'state_key': self.STATE_KEY, + 'last_sync_ts': timestamp, + 'updated_at': now_ts, + }) + else: + await sor.C('sync_state', { + 'state_key': self.STATE_KEY, + 'last_sync_ts': timestamp, + 'created_at': now_ts, + 'updated_at': now_ts, + }) - Returns: - Number of records successfully persisted. - """ - ... + # ------------------------------------------------------------------ # + # Retry wrapper # + # ------------------------------------------------------------------ # - @abstractmethod - def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: - """Extract the latest modification timestamp from a batch of records. + async def _with_retry(self, coro_func, *args, **kwargs) -> Any: + """Execute an async function with exponential-backoff retry.""" + last_exc = None + delay = self.RETRY_DELAY + for attempt in range(1, self.MAX_RETRIES + 1): + try: + return await coro_func(*args, **kwargs) + except Exception as e: + last_exc = e + logger.warning( + "[%s] attempt %d/%d failed: %s", + self.__class__.__name__, attempt, self.MAX_RETRIES, e, + ) + if attempt < self.MAX_RETRIES: + await self._sleep(delay) + delay *= 2 + raise last_exc - Used to advance the sync checkpoint after successful persist. - """ - ... + @staticmethod + async def _sleep(seconds: float) -> None: + import asyncio + await asyncio.sleep(seconds) - async def _load_checkpoint(self) -> str | None: - """Load the last successful sync checkpoint timestamp. + # ------------------------------------------------------------------ # + # Batch persist # + # ------------------------------------------------------------------ # - TODO: Implement checkpoint persistence (sync_state table). - """ - checkpoint = self._last_checkpoint.get(self.sync_name) - debug(f'Sync {self.sync_name}: loaded checkpoint = {checkpoint}') - return checkpoint + async def _persist_batch(self, sor, records: List[Dict]) -> int: + """Persist a single batch of records with retry.""" + async def _do(): + await self.persist(sor, records) + await self._with_retry(_do) + return len(records) - async def _save_checkpoint(self, timestamp: str) -> None: - """Save the sync checkpoint after a successful run. - - TODO: Implement checkpoint persistence (sync_state table). - """ - self._last_checkpoint[self.sync_name] = timestamp - debug(f'Sync {self.sync_name}: saved checkpoint = {timestamp}') - - async def run(self) -> dict[str, Any]: - """Execute a full sync cycle. - - Returns: - dict with keys: success, records_fetched, records_persisted, - error (if any), duration_seconds - """ - start = time.time() - result: dict[str, Any] = { - 'sync_name': self.sync_name, - 'success': False, - 'records_fetched': 0, - 'records_persisted': 0, - 'error': None, - 'duration_seconds': 0.0, - } - - try: - checkpoint = await self._load_checkpoint() - info(f'Sync {self.sync_name}: starting (checkpoint={checkpoint})') - - records = await self.fetch_incremental(since_timestamp=checkpoint) - result['records_fetched'] = len(records) - - if records: - persisted = await self.persist(records) - result['records_persisted'] = persisted - - latest_ts = self.get_latest_timestamp(records) - if latest_ts: - await self._save_checkpoint(latest_ts) - - result['success'] = True - info( - f'Sync {self.sync_name}: completed — ' - f'fetched={result["records_fetched"]}, ' - f'persisted={result["records_persisted"]}' + async def persist_in_batches(self, sor, records: List[Dict]) -> int: + """Split records into batches and persist each with retry.""" + total = 0 + for i in range(0, len(records), self.BATCH_SIZE): + batch = records[i:i + self.BATCH_SIZE] + cnt = await self._persist_batch(sor, batch) + total += cnt + logger.info( + "[%s] persisted batch %d/%d (%d records)", + self.__class__.__name__, + i // self.BATCH_SIZE + 1, + (len(records) + self.BATCH_SIZE - 1) // self.BATCH_SIZE, + cnt, ) + return total - except Exception as e: - error(f'Sync {self.sync_name}: failed with error: {e}') - result['error'] = str(e) + # ------------------------------------------------------------------ # + # Main sync flow # + # ------------------------------------------------------------------ # - finally: - result['duration_seconds'] = round(time.time() - start, 3) + async def sync(self) -> Dict[str, Any]: + """ + Full incremental sync flow: + 1. Read checkpoint (last sync timestamp) + 2. Fetch incremental records from Sage DB + 3. Persist to local cache in batches + 4. Update checkpoint + 5. Return summary dict + """ + cls_name = self.__class__.__name__ + logger.info("[%s] sync started", cls_name) + # 1. Read checkpoint + since_ts = await self._read_checkpoint() + logger.info("[%s] checkpoint: %s", cls_name, since_ts or "None (full sync)") + + # 2. Fetch incremental from Sage source DB + async with get_sor_context(self.env, self.SOURCE_DBNAME) as sor: + records = await self.fetch_incremental(sor, since_ts) + + if not records: + logger.info("[%s] no new records, sync done", cls_name) + return { + 'module': self.MODULE_NAME, + 'fetched': 0, + 'persisted': 0, + 'new_checkpoint': since_ts, + } + + # 3. Extract latest timestamp + new_checkpoint = self.get_latest_timestamp(records) + logger.info( + "[%s] fetched %d records, latest_ts=%s", + cls_name, len(records), new_checkpoint, + ) + + # 4. Persist to cache DB + async with get_sor_context(self.env, self.CACHE_DBNAME) as cache_sor: + persisted = await self.persist_in_batches(cache_sor, records) + + # 5. Update checkpoint + if new_checkpoint: + await self._write_checkpoint(new_checkpoint) + + result = { + 'module': self.MODULE_NAME, + 'fetched': len(records), + 'persisted': persisted, + 'new_checkpoint': new_checkpoint, + } + logger.info("[%s] sync completed: %s", cls_name, result) return result + + # ------------------------------------------------------------------ # + # Abstract methods – subclasses MUST implement # + # ------------------------------------------------------------------ # + + @abstractmethod + async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]: + """ + Fetch incremental records from Sage DB since the given timestamp. + + Args: + sor: sqlor context for the SOURCE DB + since_timestamp: ISO or unix timestamp string, or None for full sync + + Returns: + List of record dicts + """ + ... + + @abstractmethod + async def persist(self, sor, records: List[Dict]) -> None: + """ + UPSERT records into the local cache table. + + Args: + sor: sqlor context for the CACHE DB + records: list of dicts to upsert + """ + ... + + @abstractmethod + def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]: + """ + Extract the maximum updated/modified timestamp from records. + + Args: + records: list of record dicts + + Returns: + Timestamp string or None if no records + """ + ... + + +async def run_all_syncs() -> str: + """Trigger sync for all entity types. Called by admin API or cron.""" + import json + from sageapi.sync.user_sync import UserSync + from sageapi.sync.pricing_sync import PricingSync + from sageapi.sync.uapi_sync import UapiSync + from sageapi.sync.llmage_sync import LlmageSync + + results = [] + for sync_cls in [UserSync, PricingSync, UapiSync, LlmageSync]: + try: + syncer = sync_cls() + result = await syncer.sync() + results.append({'module': sync_cls.MODULE_NAME, 'status': 'ok', **result}) + except Exception as e: + logger.exception("[%s] sync failed", sync_cls.__name__) + results.append({'module': sync_cls.MODULE_NAME, 'status': 'error', 'error': str(e)}) + + return json.dumps({'success': True, 'results': results}, ensure_ascii=False, default=str) diff --git a/sageapi/sync/llmage_sync.py b/sageapi/sync/llmage_sync.py index 4deb9be..f266dfe 100644 --- a/sageapi/sync/llmage_sync.py +++ b/sageapi/sync/llmage_sync.py @@ -1,65 +1,142 @@ -"""LLM image data synchronization for SageAPI. - -Syncs LLM catalog and provider data from the upstream Sage -llmage module into the local llmage_cache table. """ +LlmageSync - Sync llm / llmcatelog / llm_api_map from Sage DB to llmage_cache. -from __future__ import annotations +Source tables (Sage DB): + - llm + - llmcatelog + - llm_api_map -from typing import Any +Target table (cache DB): + - llmage_cache +""" +import logging +from typing import Dict, List, Optional -from appPublic.log import debug, info from .base_sync import BaseSync +logger = logging.getLogger(__name__) + class LlmageSync(BaseSync): - """Incremental sync for llmage data from Sage upstream.""" + MODULE_NAME = "llmage" + SOURCE_DBNAME = "sage" + CACHE_DBNAME = "sageapi" + STATE_KEY = "sync_llmage" + BATCH_SIZE = 500 - def __init__(self, batch_size: int = 500) -> None: - super().__init__(sync_name='llmage', batch_size=batch_size) - - async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]: - """Fetch llmage data updated since the last sync checkpoint. - - TODO: Implement upstream API call to Sage /api/llmage endpoint. + async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]: """ - debug(f'LlmageSync: fetching incremental data since {since_timestamp}') - # Placeholder: call upstream Sage API - return [] - - async def persist(self, records: list[dict[str, Any]]) -> int: - """Upsert llmage records into llmage_cache table. - - TODO: Implement database upsert logic. + Fetch incremental data from llm, llmcatelog, and llm_api_map tables. + Joins LLM model info with catalog and API mapping data. """ - if not records: - return 0 - info(f'LlmageSync: persisting {len(records)} llmage records') - # Placeholder: upsert into llmage_cache - return len(records) + if since_timestamp: + where_clause = f"WHERE l.updated_at > '{since_timestamp}' OR lc.updated_at > '{since_timestamp}' OR lam.updated_at > '{since_timestamp}'" + else: + where_clause = "" - def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: - """Extract the maximum updated_at from the record batch.""" + sql = f""" + SELECT + l.id AS llm_id, + l.model_name, + l.model_version, + l.provider, + l.model_type, + l.status AS llm_status, + l.description AS llm_description, + l.created_at AS llm_created_at, + l.updated_at AS llm_updated_at, + lc.id AS catelog_id, + lc.catelog_name, + lc.catelog_code, + lc.sort_order AS catelog_sort, + lc.status AS catelog_status, + lc.updated_at AS catelog_updated_at, + lam.id AS api_map_id, + lam.api_name, + lam.api_endpoint, + lam.api_version, + lam.auth_type, + lam.rate_limit, + lam.status AS api_map_status, + lam.updated_at AS api_map_updated_at + FROM {sor.dbname}.llm l + LEFT JOIN {sor.dbname}.llmcatelog lc ON l.catelog_id = lc.id + LEFT JOIN {sor.dbname}.llm_api_map lam ON l.id = lam.llm_id + {where_clause} + ORDER BY COALESCE(l.updated_at, l.created_at) ASC + """ + + records = await sor.sqlExe(sql, {}) + return [dict(r) for r in records] if records else [] + + async def persist(self, sor, records: List[Dict]) -> None: + """ + UPSERT into llmage_cache using INSERT ... ON DUPLICATE KEY UPDATE. + The composite key is (llm_id, catelog_id, api_map_id). + """ + import time + synced_at = str(int(time.time())) + + for rec in records: + rec['synced_at'] = synced_at + insert_sql = """ + INSERT INTO llmage_cache ( + llm_id, model_name, model_version, provider, model_type, + llm_status, llm_description, llm_created_at, llm_updated_at, + catelog_id, catelog_name, catelog_code, catelog_sort, + catelog_status, catelog_updated_at, + api_map_id, api_name, api_endpoint, api_version, + auth_type, rate_limit, api_map_status, api_map_updated_at, + synced_at + ) VALUES ( + ${llm_id}$, ${model_name}$, ${model_version}$, ${provider}$, ${model_type}$, + ${llm_status}$, ${llm_description}$, ${llm_created_at}$, ${llm_updated_at}$, + ${catelog_id}$, ${catelog_name}$, ${catelog_code}$, ${catelog_sort}$, + ${catelog_status}$, ${catelog_updated_at}$, + ${api_map_id}$, ${api_name}$, ${api_endpoint}$, ${api_version}$, + ${auth_type}$, ${rate_limit}$, ${api_map_status}$, ${api_map_updated_at}$, + ${synced_at}$ + ) + ON DUPLICATE KEY UPDATE + model_name = VALUES(model_name), + model_version = VALUES(model_version), + provider = VALUES(provider), + model_type = VALUES(model_type), + llm_status = VALUES(llm_status), + llm_description = VALUES(llm_description), + llm_created_at = VALUES(llm_created_at), + llm_updated_at = VALUES(llm_updated_at), + catelog_name = VALUES(catelog_name), + catelog_code = VALUES(catelog_code), + catelog_sort = VALUES(catelog_sort), + catelog_status = VALUES(catelog_status), + catelog_updated_at= VALUES(catelog_updated_at), + api_name = VALUES(api_name), + api_endpoint = VALUES(api_endpoint), + api_version = VALUES(api_version), + auth_type = VALUES(auth_type), + rate_limit = VALUES(rate_limit), + api_map_status = VALUES(api_map_status), + api_map_updated_at= VALUES(api_map_updated_at), + synced_at = VALUES(synced_at) + """ + try: + await sor.execute(insert_sql, rec) + except Exception as e: + logger.warning( + "[%s] persist failed for llm_id=%s: %s", + self.__class__.__name__, rec.get('llm_id'), e, + ) + raise + + def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]: + """Extract the maximum updated_at from llm, catelog, or api_map records.""" if not records: return None - timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] - return max(timestamps) if timestamps else None - - -_llmage_sync_instance: LlmageSync | None = None - - -def get_llmage_sync() -> LlmageSync: - """Get or create the LlmageSync singleton.""" - global _llmage_sync_instance - if _llmage_sync_instance is None: - _llmage_sync_instance = LlmageSync() - return _llmage_sync_instance - - -async def sync_llmage(since_timestamp: str | None = None) -> dict[str, Any]: - """Run a llmage data sync cycle.""" - syncer = get_llmage_sync() - if since_timestamp: - await syncer._save_checkpoint(since_timestamp) - return await syncer.run() + latest = None + for r in records: + for key in ('llm_updated_at', 'catelog_updated_at', 'api_map_updated_at'): + ts = r.get(key) + if ts and (latest is None or str(ts) > str(latest)): + latest = str(ts) + return latest diff --git a/sageapi/sync/pricing_sync.py b/sageapi/sync/pricing_sync.py index 61a9d13..1f7cfb3 100644 --- a/sageapi/sync/pricing_sync.py +++ b/sageapi/sync/pricing_sync.py @@ -1,65 +1,123 @@ -"""Pricing data synchronization for SageAPI. - -Syncs pricing program and timing data from the upstream Sage -pricing module into the local pricing_cache table. """ +PricingSync - Sync pricing_program / pricing_program_timing from Sage DB to pricing_cache. -from __future__ import annotations +Source tables (Sage DB): + - pricing_program + - pricing_program_timing -from typing import Any +Target table (cache DB): + - pricing_cache +""" +import logging +from typing import Dict, List, Optional -from appPublic.log import debug, info from .base_sync import BaseSync +logger = logging.getLogger(__name__) + class PricingSync(BaseSync): - """Incremental sync for pricing data from Sage upstream.""" + MODULE_NAME = "pricing" + SOURCE_DBNAME = "sage" + CACHE_DBNAME = "sageapi" + STATE_KEY = "sync_pricing" + BATCH_SIZE = 500 - def __init__(self, batch_size: int = 500) -> None: - super().__init__(sync_name='pricing', batch_size=batch_size) - - async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]: - """Fetch pricing data updated since the last sync checkpoint. - - TODO: Implement upstream API call to Sage /api/pricing endpoint. + async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]: """ - debug(f'PricingSync: fetching incremental data since {since_timestamp}') - # Placeholder: call upstream Sage API - return [] - - async def persist(self, records: list[dict[str, Any]]) -> int: - """Upsert pricing records into pricing_cache table. - - TODO: Implement database upsert logic. + Fetch incremental data from pricing_program and pricing_program_timing. + Joins program info with timing/schedule data. """ - if not records: - return 0 - info(f'PricingSync: persisting {len(records)} pricing records') - # Placeholder: upsert into pricing_cache - return len(records) + if since_timestamp: + where_clause = f"WHERE pp.updated_at > '{since_timestamp}' OR ppt.updated_at > '{since_timestamp}'" + else: + where_clause = "" - def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: - """Extract the maximum updated_at from the record batch.""" + sql = f""" + SELECT + pp.id AS program_id, + pp.program_name, + pp.program_code, + pp.program_type, + pp.status AS program_status, + pp.description, + pp.created_at AS program_created_at, + pp.updated_at AS program_updated_at, + ppt.id AS timing_id, + ppt.start_time, + ppt.end_time, + ppt.duration, + ppt.repeat_rule, + ppt.timezone, + ppt.status AS timing_status, + ppt.updated_at AS timing_updated_at + FROM {sor.dbname}.pricing_program pp + LEFT JOIN {sor.dbname}.pricing_program_timing ppt + ON pp.id = ppt.program_id + {where_clause} + ORDER BY COALESCE(pp.updated_at, pp.created_at) ASC + """ + + records = await sor.sqlExe(sql, {}) + return [dict(r) for r in records] if records else [] + + async def persist(self, sor, records: List[Dict]) -> None: + """ + UPSERT into pricing_cache using INSERT ... ON DUPLICATE KEY UPDATE. + The composite key is (program_id, timing_id). + """ + import time + synced_at = str(int(time.time())) + + for rec in records: + rec['synced_at'] = synced_at + insert_sql = """ + INSERT INTO pricing_cache ( + program_id, program_name, program_code, program_type, + program_status, description, program_created_at, program_updated_at, + timing_id, start_time, end_time, duration, + repeat_rule, timezone, timing_status, timing_updated_at, + synced_at + ) VALUES ( + ${program_id}$, ${program_name}$, ${program_code}$, ${program_type}$, + ${program_status}$, ${description}$, ${program_created_at}$, ${program_updated_at}$, + ${timing_id}$, ${start_time}$, ${end_time}$, ${duration}$, + ${repeat_rule}$, ${timezone}$, ${timing_status}$, ${timing_updated_at}$, + ${synced_at}$ + ) + ON DUPLICATE KEY UPDATE + program_name = VALUES(program_name), + program_code = VALUES(program_code), + program_type = VALUES(program_type), + program_status = VALUES(program_status), + description = VALUES(description), + program_created_at = VALUES(program_created_at), + program_updated_at = VALUES(program_updated_at), + start_time = VALUES(start_time), + end_time = VALUES(end_time), + duration = VALUES(duration), + repeat_rule = VALUES(repeat_rule), + timezone = VALUES(timezone), + timing_status = VALUES(timing_status), + timing_updated_at = VALUES(timing_updated_at), + synced_at = VALUES(synced_at) + """ + try: + await sor.execute(insert_sql, rec) + except Exception as e: + logger.warning( + "[%s] persist failed for program_id=%s: %s", + self.__class__.__name__, rec.get('program_id'), e, + ) + raise + + def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]: + """Extract the maximum updated_at from program or timing records.""" if not records: return None - timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] - return max(timestamps) if timestamps else None - - -_pricing_sync_instance: PricingSync | None = None - - -def get_pricing_sync() -> PricingSync: - """Get or create the PricingSync singleton.""" - global _pricing_sync_instance - if _pricing_sync_instance is None: - _pricing_sync_instance = PricingSync() - return _pricing_sync_instance - - -async def sync_pricing(since_timestamp: str | None = None) -> dict[str, Any]: - """Run a pricing data sync cycle.""" - syncer = get_pricing_sync() - if since_timestamp: - await syncer._save_checkpoint(since_timestamp) - return await syncer.run() + latest = None + for r in records: + ts = r.get('program_updated_at') or r.get('timing_updated_at') + if ts and (latest is None or str(ts) > str(latest)): + latest = str(ts) + return latest diff --git a/sageapi/sync/uapi_sync.py b/sageapi/sync/uapi_sync.py index 13c829f..4725689 100644 --- a/sageapi/sync/uapi_sync.py +++ b/sageapi/sync/uapi_sync.py @@ -1,65 +1,129 @@ -"""UAPI data synchronization for SageAPI. - -Syncs uapi application and caller configuration from the upstream -Sage uapi module into the local uapi_cache table. """ +UAPISync - Sync uapi / upapp from Sage DB to uapi_cache. -from __future__ import annotations +Source tables (Sage DB): + - uapi + - upapp -from typing import Any +Target table (cache DB): + - uapi_cache +""" +import logging +from typing import Dict, List, Optional -from appPublic.log import debug, info from .base_sync import BaseSync +logger = logging.getLogger(__name__) + class UAPISync(BaseSync): - """Incremental sync for uapi data from Sage upstream.""" + MODULE_NAME = "uapi" + SOURCE_DBNAME = "sage" + CACHE_DBNAME = "sageapi" + STATE_KEY = "sync_uapi" + BATCH_SIZE = 500 - def __init__(self, batch_size: int = 500) -> None: - super().__init__(sync_name='uapi', batch_size=batch_size) - - async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]: - """Fetch uapi data updated since the last sync checkpoint. - - TODO: Implement upstream API call to Sage /api/uapi endpoint. + async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]: """ - debug(f'UAPISync: fetching incremental data since {since_timestamp}') - # Placeholder: call upstream Sage API - return [] - - async def persist(self, records: list[dict[str, Any]]) -> int: - """Upsert uapi records into uapi_cache table. - - TODO: Implement database upsert logic. + Fetch incremental data from uapi and upapp tables. + Joins API definitions with app registration data. """ - if not records: - return 0 - info(f'UAPISync: persisting {len(records)} uapi records') - # Placeholder: upsert into uapi_cache - return len(records) + if since_timestamp: + where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR up.updated_at > '{since_timestamp}'" + else: + where_clause = "" - def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: - """Extract the maximum updated_at from the record batch.""" + sql = f""" + SELECT + u.id AS uapi_id, + u.api_name, + u.api_path, + u.api_method, + u.api_version, + u.api_desc, + u.status AS uapi_status, + u.auth_required, + u.created_at AS uapi_created_at, + u.updated_at AS uapi_updated_at, + up.id AS upapp_id, + up.app_name, + up.app_code, + up.app_type, + up.app_desc, + up.app_owner, + up.status AS upapp_status, + up.updated_at AS upapp_updated_at + FROM {sor.dbname}.uapi u + LEFT JOIN {sor.dbname}.upapp up ON u.upapp_id = up.id + {where_clause} + ORDER BY COALESCE(u.updated_at, u.created_at) ASC + """ + + records = await sor.sqlExe(sql, {}) + return [dict(r) for r in records] if records else [] + + async def persist(self, sor, records: List[Dict]) -> None: + """ + UPSERT into uapi_cache using INSERT ... ON DUPLICATE KEY UPDATE. + The composite key is (uapi_id, upapp_id). + """ + import time + synced_at = str(int(time.time())) + + for rec in records: + rec['synced_at'] = synced_at + insert_sql = """ + INSERT INTO uapi_cache ( + uapi_id, api_name, api_path, api_method, api_version, + api_desc, uapi_status, auth_required, + uapi_created_at, uapi_updated_at, + upapp_id, app_name, app_code, app_type, + app_desc, app_owner, upapp_status, upapp_updated_at, + synced_at + ) VALUES ( + ${uapi_id}$, ${api_name}$, ${api_path}$, ${api_method}$, ${api_version}$, + ${api_desc}$, ${uapi_status}$, ${auth_required}$, + ${uapi_created_at}$, ${uapi_updated_at}$, + ${upapp_id}$, ${app_name}$, ${app_code}$, ${app_type}$, + ${app_desc}$, ${app_owner}$, ${upapp_status}$, ${upapp_updated_at}$, + ${synced_at}$ + ) + ON DUPLICATE KEY UPDATE + api_name = VALUES(api_name), + api_path = VALUES(api_path), + api_method = VALUES(api_method), + api_version = VALUES(api_version), + api_desc = VALUES(api_desc), + uapi_status = VALUES(uapi_status), + auth_required = VALUES(auth_required), + uapi_created_at = VALUES(uapi_created_at), + uapi_updated_at = VALUES(uapi_updated_at), + app_name = VALUES(app_name), + app_code = VALUES(app_code), + app_type = VALUES(app_type), + app_desc = VALUES(app_desc), + app_owner = VALUES(app_owner), + upapp_status = VALUES(upapp_status), + upapp_updated_at = VALUES(upapp_updated_at), + synced_at = VALUES(synced_at) + """ + try: + await sor.execute(insert_sql, rec) + except Exception as e: + logger.warning( + "[%s] persist failed for uapi_id=%s: %s", + self.__class__.__name__, rec.get('uapi_id'), e, + ) + raise + + def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]: + """Extract the maximum updated_at from uapi or upapp records.""" if not records: return None - timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] - return max(timestamps) if timestamps else None - - -_uapi_sync_instance: UAPISync | None = None - - -def get_uapi_sync() -> UAPISync: - """Get or create the UAPISync singleton.""" - global _uapi_sync_instance - if _uapi_sync_instance is None: - _uapi_sync_instance = UAPISync() - return _uapi_sync_instance - - -async def sync_uapi(since_timestamp: str | None = None) -> dict[str, Any]: - """Run a uapi data sync cycle.""" - syncer = get_uapi_sync() - if since_timestamp: - await syncer._save_checkpoint(since_timestamp) - return await syncer.run() + latest = None + for r in records: + for key in ('uapi_updated_at', 'upapp_updated_at'): + ts = r.get(key) + if ts and (latest is None or str(ts) > str(latest)): + latest = str(ts) + return latest diff --git a/sageapi/sync/user_sync.py b/sageapi/sync/user_sync.py index f53e9f4..bbae917 100644 --- a/sageapi/sync/user_sync.py +++ b/sageapi/sync/user_sync.py @@ -1,76 +1,143 @@ -"""User data synchronization for SageAPI. - -Syncs user data from the upstream Sage system into the local -users_cache table. Uses incremental sync based on updated_at -timestamp to minimize data transfer. """ +UserSync - Sync users / organi / organization from Sage DB to users_cache. -from __future__ import annotations +Source tables (Sage DB): + - users + - organi + - organization -from typing import Any +Target table (cache DB): + - users_cache +""" +import logging +from typing import Dict, List, Optional -from appPublic.log import debug, info from .base_sync import BaseSync +logger = logging.getLogger(__name__) + class UserSync(BaseSync): - """Incremental sync for user data from Sage upstream.""" + MODULE_NAME = "users" + SOURCE_DBNAME = "sage" + CACHE_DBNAME = "sageapi" + STATE_KEY = "sync_users" + BATCH_SIZE = 500 - def __init__(self, batch_size: int = 500) -> None: - super().__init__(sync_name='users', batch_size=batch_size) - - async def fetch_incremental(self, since_timestamp: str | None = None) -> list[dict[str, Any]]: - """Fetch users updated since the last sync checkpoint. - - TODO: Implement upstream API call to Sage /api/users endpoint. + async def fetch_incremental(self, sor, since_timestamp: Optional[str]) -> List[Dict]: """ - debug(f'UserSync: fetching incremental data since {since_timestamp}') - # Placeholder: call upstream Sage API - # GET /api/users?updated_at_gt={since_timestamp}&limit={batch_size} - return [] - - async def persist(self, records: list[dict[str, Any]]) -> int: - """Upsert user records into users_cache table. - - TODO: Implement database upsert logic. + Fetch incremental data from users, organi, organization tables. + Uses LEFT JOIN to combine user data with organization info. """ - if not records: - return 0 - info(f'UserSync: persisting {len(records)} user records') - # Placeholder: upsert into users_cache - return len(records) + if since_timestamp: + where_clause = f"WHERE u.updated_at > '{since_timestamp}' OR o.updated_at > '{since_timestamp}'" + else: + where_clause = "" - def get_latest_timestamp(self, records: list[dict[str, Any]]) -> str | None: - """Extract the maximum updated_at from the record batch.""" + sql = f""" + SELECT + u.id AS user_id, + u.username, + u.email, + u.status AS user_status, + u.created_at AS user_created_at, + u.updated_at AS user_updated_at, + oi.organi_id AS organi_id, + oi.organi_name AS organi_name, + oi.parent_id AS organi_parent_id, + org.id AS org_id, + org.org_name AS org_name, + org.org_type AS org_type, + org.status AS org_status, + org.updated_at AS org_updated_at + FROM {sor.dbname}.users u + LEFT JOIN {sor.dbname}.organi oi ON u.organi_id = oi.id + LEFT JOIN {sor.dbname}.organization org ON oi.organization_id = org.id + {where_clause} + ORDER BY COALESCE(u.updated_at, u.created_at) ASC + """ + + records = await sor.sqlExe(sql, {}) + return [dict(r) for r in records] if records else [] + + async def persist(self, sor, records: List[Dict]) -> None: + """ + UPSERT into users_cache. + For each record: try UPDATE first; if no rows affected, INSERT. + """ + for rec in records: + # Try update first + update_sql = """ + UPDATE users_cache SET + username = ${username}$, + email = ${email}$, + user_status = ${user_status}$, + user_created_at = ${user_created_at}$, + user_updated_at = ${user_updated_at}$, + organi_id = ${organi_id}$, + organi_name = ${organi_name}$, + organi_parent_id= ${organi_parent_id}$, + org_id = ${org_id}$, + org_name = ${org_name}$, + org_type = ${org_type}$, + org_status = ${org_status}$, + org_updated_at = ${org_updated_at}$, + synced_at = ${synced_at}$ + WHERE user_id = ${user_id}$ + """ + rec['synced_at'] = str(int(__import__('time').time())) + await sor.execute(update_sql, rec) + + # If no rows updated (sqlor returns affected count), insert + # sqlor's execute returns the cursor; we check rowcount via sqlExe + # A simpler approach: use INSERT ... ON DUPLICATE KEY UPDATE + insert_sql = """ + INSERT INTO users_cache ( + user_id, username, email, user_status, + user_created_at, user_updated_at, + organi_id, organi_name, organi_parent_id, + org_id, org_name, org_type, org_status, + org_updated_at, synced_at + ) VALUES ( + ${user_id}$, ${username}$, ${email}$, ${user_status}$, + ${user_created_at}$, ${user_updated_at}$, + ${organi_id}$, ${organi_name}$, ${organi_parent_id}$, + ${org_id}$, ${org_name}$, ${org_type}$, ${org_status}$, + ${org_updated_at}$, ${synced_at}$ + ) + ON DUPLICATE KEY UPDATE + username = VALUES(username), + email = VALUES(email), + user_status = VALUES(user_status), + user_created_at = VALUES(user_created_at), + user_updated_at = VALUES(user_updated_at), + organi_id = VALUES(organi_id), + organi_name = VALUES(organi_name), + organi_parent_id= VALUES(organi_parent_id), + org_id = VALUES(org_id), + org_name = VALUES(org_name), + org_type = VALUES(org_type), + org_status = VALUES(org_status), + org_updated_at = VALUES(org_updated_at), + synced_at = VALUES(synced_at) + """ + # Execute INSERT with ON DUPLICATE KEY UPDATE as safety net + # The UPDATE above handles most cases; this catches any race conditions + # Skip if the user_id is None + if rec.get('user_id') is not None: + try: + await sor.execute(insert_sql, rec) + except Exception: + # Duplicate key is expected if UPDATE already succeeded + pass + + def get_latest_timestamp(self, records: List[Dict]) -> Optional[str]: + """Extract the maximum updated_at from all records.""" if not records: return None - timestamps = [r.get('updated_at') for r in records if r.get('updated_at')] - return max(timestamps) if timestamps else None - - -# Module-level singleton instance -_user_sync_instance: UserSync | None = None - - -def get_user_sync() -> UserSync: - """Get or create the UserSync singleton.""" - global _user_sync_instance - if _user_sync_instance is None: - _user_sync_instance = UserSync() - return _user_sync_instance - - -async def sync_users(since_timestamp: str | None = None) -> dict[str, Any]: - """Run a user data sync cycle. - - Args: - since_timestamp: Optional override for the checkpoint timestamp. - - Returns: - Sync result dict from BaseSync.run(). - """ - syncer = get_user_sync() - if since_timestamp: - # Override checkpoint for this run - await syncer._save_checkpoint(since_timestamp) - return await syncer.run() + latest = None + for r in records: + ts = r.get('user_updated_at') or r.get('org_updated_at') + if ts and (latest is None or str(ts) > str(latest)): + latest = str(ts) + return latest diff --git a/sageapi/utils/__pycache__/__init__.cpython-310.pyc b/sageapi/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..97fc645 Binary files /dev/null and b/sageapi/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/sageapi/utils/__pycache__/crypto.cpython-310.pyc b/sageapi/utils/__pycache__/crypto.cpython-310.pyc new file mode 100644 index 0000000..01d09a8 Binary files /dev/null and b/sageapi/utils/__pycache__/crypto.cpython-310.pyc differ diff --git a/sageapi/utils/__pycache__/http_client.cpython-310.pyc b/sageapi/utils/__pycache__/http_client.cpython-310.pyc new file mode 100644 index 0000000..704d5a7 Binary files /dev/null and b/sageapi/utils/__pycache__/http_client.cpython-310.pyc differ diff --git a/sageapi/utils/http_client.py b/sageapi/utils/http_client.py index 6773525..bac8ef9 100644 --- a/sageapi/utils/http_client.py +++ b/sageapi/utils/http_client.py @@ -1,115 +1,460 @@ -"""HTTP client for upstream Sage API calls. +""" +SageHttpClient - Async HTTP client using aiohttp with DAPI signature support. -Provides a reusable async HTTP client with connection pooling, -retry logic, and automatic DAPI/UAPI authentication headers. +Used to call external APIs (e.g., LLM provider APIs). Not for calling Sage local interfaces. + +Features: + - Automatic DAPI signature header injection + - Connection pooling + - Retry with exponential backoff + - Configurable timeouts + - Support for GET, POST, PUT, DELETE, PATCH + +Usage: + from sageapi.utils.http_client import SageHttpClient + + client = SageHttpClient( + api_key="your-api-key", + api_secret="your-api-secret", + base_url="https://api.example.com", + max_retries=3, + timeout=30.0, + ) + + async with client: + resp = await client.post("/v1/chat", json={"message": "hello"}) + data = await resp.json() """ -from __future__ import annotations +import hashlib +import hmac +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Optional -import asyncio -from typing import Any +import aiohttp +from aiohttp import ClientTimeout -from appPublic.log import debug, error +logger = logging.getLogger(__name__) + +# Default configuration +DEFAULT_TIMEOUT = 30.0 +DEFAULT_MAX_RETRIES = 3 +DEFAULT_BACKOFF_FACTOR = 0.5 +DEFAULT_MAX_CONNECTIONS = 100 +DEFAULT_CONNECTION_POOL_LIMIT = 100 + + +@dataclass +class RetryConfig: + """Configuration for request retries.""" + + max_retries: int = DEFAULT_MAX_RETRIES + backoff_factor: float = DEFAULT_BACKOFF_FACTOR + retry_on_status: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 504}) + retry_on_timeout: bool = True + retry_on_connection_error: bool = True + + +def compute_dapi_signature( + method: str, + path: str, + timestamp: str, + secret: str, + body: Optional[bytes] = None, +) -> str: + """ + Compute HMAC-SHA256 signature for DAPI authentication. + + Signs: "{method}\\n{path}\\n{timestamp}\\n{body_hash}" + """ + if body: + body_hash = hashlib.sha256(body).hexdigest() + else: + body_hash = "" + + string_to_sign = f"{method}\n{path}\n{timestamp}\n{body_hash}" + + return hmac.new( + secret.encode("utf-8"), + string_to_sign.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + + +class DAPISigner: + """Handles DAPI signature generation for outgoing requests.""" + + def __init__(self, api_key: str, api_secret: str): + self.api_key = api_key + self.api_secret = api_secret + + def sign_request( + self, + method: str, + path: str, + body: Optional[bytes] = None, + ) -> dict[str, str]: + """ + Generate DAPI authentication headers. + + Returns dict with X-DAPI-Key, X-DAPI-Timestamp, X-DAPI-Signature. + """ + timestamp = str(time.time()) + signature = compute_dapi_signature( + method=method.upper(), + path=path, + timestamp=timestamp, + secret=self.api_secret, + body=body, + ) + + return { + "X-DAPI-Key": self.api_key, + "X-DAPI-Timestamp": timestamp, + "X-DAPI-Signature": signature, + } class SageHttpClient: - """Async HTTP client for calling upstream Sage APIs. + """ + Async HTTP client with DAPI signature support. - Handles connection pooling, authentication headers, and - retry logic for transient failures. + Uses aiohttp under the hood with connection pooling, retry logic, + and automatic DAPI header injection. + + Args: + base_url: Base URL for all requests (e.g., "https://api.example.com"). + api_key: DAPI API key for request signing. + api_secret: DAPI API secret for request signing. + timeout: Request timeout in seconds. + max_retries: Maximum number of retries on transient failures. + backoff_factor: Exponential backoff multiplier. + max_connections: Maximum number of connections in the pool. + headers: Additional default headers to include in every request. + signer: Optional custom DAPISigner instance. If None, one is created from api_key/api_secret. + + Usage: + async with SageHttpClient(base_url="...", api_key="...", api_secret="...") as client: + resp = await client.get("/health") + resp = await client.post("/v1/chat", json={"msg": "hi"}) """ def __init__( self, - base_url: str = 'http://127.0.0.1:9180', - dapi_key: str = '', - dapi_secret: str = '', - timeout: float = 30.0, - max_retries: int = 3, - ) -> None: - self.base_url = base_url.rstrip('/') - self.dapi_key = dapi_key - self.dapi_secret = dapi_secret + base_url: str = "", + api_key: str = "", + api_secret: str = "", + timeout: float = DEFAULT_TIMEOUT, + max_retries: int = DEFAULT_MAX_RETRIES, + backoff_factor: float = DEFAULT_BACKOFF_FACTOR, + max_connections: int = DEFAULT_MAX_CONNECTIONS, + headers: Optional[dict[str, str]] = None, + signer: Optional[DAPISigner] = None, + auto_sign: bool = True, + ): + self.base_url = base_url.rstrip("/") self.timeout = timeout self.max_retries = max_retries + self.backoff_factor = backoff_factor + self.default_headers = headers or {} + self.auto_sign = auto_sign - def _build_headers(self) -> dict[str, str]: - """Build request headers with DAPI authentication. + if signer is not None: + self.signer = signer + else: + self.signer = DAPISigner(api_key=api_key, api_secret=api_secret) - TODO: Implement actual DAPI signature generation. + # Internal state + self._session: Optional[aiohttp.ClientSession] = None + self._connector: Optional[aiohttp.TCPConnector] = None + self._connector_limit = max_connections + self._closed = False + + def _build_url(self, path: str) -> str: + """Build full URL from base_url and path.""" + if path.startswith("http://") or path.startswith("https://"): + return path + return f"{self.base_url}/{path.lstrip('/')}" + + def _get_relative_path(self, path: str) -> str: + """Extract the relative path for signing purposes.""" + if path.startswith("http://") or path.startswith("https://"): + from urllib.parse import urlparse + + parsed = urlparse(path) + return parsed.path + return path + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create the aiohttp session with connection pooling.""" + if self._session is None or self._session.closed: + self._connector = aiohttp.TCPConnector( + limit=self._connector_limit, + limit_per_host=self._connector_limit, + ttl_dns_cache=300, + keepalive_timeout=30, + ) + self._session = aiohttp.ClientSession( + connector=self._connector, + timeout=ClientTimeout(total=self.timeout), + ) + return self._session + + async def _sign_and_prepare( + self, + method: str, + path: str, + headers: Optional[dict[str, str]] = None, + data: Any = None, + json_body: Any = None, + ) -> tuple[str, dict[str, str], Optional[bytes]]: """ - import time - import hashlib - import hmac + Prepare request with DAPI signature headers. - timestamp = str(int(time.time())) - string_to_sign = f'{self.dapi_key}:{timestamp}' - signature = hmac.new( - self.dapi_secret.encode('utf-8') if self.dapi_secret else b'', - string_to_sign.encode('utf-8'), - hashlib.sha256, - ).hexdigest() + Returns (url, merged_headers, raw_body_bytes). + """ + url = self._build_url(path) + relative_path = self._get_relative_path(path) - return { - 'Content-Type': 'application/json', - 'X-DAPI-Key': self.dapi_key, - 'X-DAPI-Timestamp': timestamp, - 'X-DAPI-Signature': signature, + merged_headers = {**self.default_headers} + if headers: + merged_headers.update(headers) + + # Determine body for signing + raw_body: Optional[bytes] = None + if json_body is not None: + import json + + raw_body = json.dumps(json_body).encode("utf-8") + merged_headers.setdefault("Content-Type", "application/json") + elif data is not None: + if isinstance(data, bytes): + raw_body = data + elif isinstance(data, str): + raw_body = data.encode("utf-8") + else: + # Form data - sign the encoded form + from urllib.parse import urlencode + + raw_body = urlencode(data).encode("utf-8") + + # Add DAPI signature headers if auto-sign is enabled and signer has credentials + if self.auto_sign and self.signer.api_key and self.signer.api_secret: + dapi_headers = self.signer.sign_request( + method=method, + path=relative_path, + body=raw_body, + ) + merged_headers.update(dapi_headers) + + return url, merged_headers, raw_body + + async def _execute_with_retry( + self, + method: str, + url: str, + headers: dict[str, str], + **kwargs: Any, + ) -> aiohttp.ClientResponse: + """ + Execute request with retry and exponential backoff. + """ + session = await self._get_session() + last_response: Optional[aiohttp.ClientResponse] = None + last_exception: Optional[Exception] = None + + for attempt in range(self.max_retries + 1): + try: + response = await session.request( + method=method, + url=url, + headers=headers, + **kwargs, + ) + + # Check if we should retry based on status + if response.status in {429, 500, 502, 503, 504}: + last_response = response + if attempt < self.max_retries: + wait_time = self.backoff_factor * (2 ** attempt) + logger.warning( + "HTTP %s on %s %s, retrying in %.1fs (attempt %d/%d)", + response.status, + method, + url, + wait_time, + attempt + 1, + self.max_retries, + ) + try: + response.release() + except Exception: + pass + await self._async_sleep(wait_time) + continue + else: + # Retries exhausted - raise an error + raise aiohttp.ClientResponseError( + request_info=response.request_info, + history=response.history, + status=response.status, + message=f"HTTP {response.status} after {self.max_retries + 1} attempts", + ) + + return response + + except aiohttp.ServerTimeoutError as e: + last_exception = e + if attempt < self.max_retries: + wait_time = self.backoff_factor * (2 ** attempt) + logger.warning( + "Timeout on %s %s, retrying in %.1fs (attempt %d/%d)", + method, + url, + wait_time, + attempt + 1, + self.max_retries, + ) + await self._async_sleep(wait_time) + continue + + except (aiohttp.ClientConnectionError, aiohttp.ClientOSError) as e: + last_exception = e + if attempt < self.max_retries: + wait_time = self.backoff_factor * (2 ** attempt) + logger.warning( + "Connection error on %s %s, retrying in %.1fs (attempt %d/%d)", + method, + url, + wait_time, + attempt + 1, + self.max_retries, + ) + await self._async_sleep(wait_time) + continue + + except Exception: + raise + + # All retries exhausted + if last_response: + return last_response + if last_exception: + raise last_exception + raise RuntimeError(f"Request failed after {self.max_retries + 1} attempts") + + @staticmethod + async def _async_sleep(seconds: float) -> None: + """Non-blocking sleep.""" + import asyncio + + await asyncio.sleep(seconds) + + async def request( + self, + method: str, + path: str, + *, + headers: Optional[dict[str, str]] = None, + data: Any = None, + json: Any = None, + params: Optional[dict[str, Any]] = None, + timeout: Optional[float] = None, + allow_redirects: bool = True, + ) -> aiohttp.ClientResponse: + """ + Send an HTTP request with DAPI signing and retry. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, PATCH). + path: URL path or full URL. + headers: Additional request headers. + data: Form data or raw bytes. + json: JSON-serializable body (automatically encoded). + params: Query string parameters. + timeout: Override timeout for this request. + allow_redirects: Whether to follow redirects. + + Returns: + aiohttp.ClientResponse + """ + url, merged_headers, raw_body = await self._sign_and_prepare( + method=method, + path=path, + headers=headers, + data=data, + json_body=json, + ) + + request_kwargs: dict[str, Any] = { + "allow_redirects": allow_redirects, } - async def get( - self, - path: str, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - ) -> Any: - """Send a GET request to the upstream Sage API. + if raw_body is not None: + if json is not None: + request_kwargs["data"] = raw_body + else: + request_kwargs["data"] = raw_body + elif json is not None: + import json as _json - Args: - path: API path (relative to base_url). - params: Optional query parameters. - headers: Optional additional headers. + request_kwargs["data"] = _json.dumps(json).encode("utf-8") - Returns: - Parsed JSON response. - """ - url = f'{self.base_url}{path}' - request_headers = {**self._build_headers(), **(headers or {})} + if data is not None and json is None: + request_kwargs["data"] = data - debug(f'HTTP GET {url} params={params}') + if params: + request_kwargs["params"] = params - # TODO: Replace with actual async HTTP implementation - # using aiohttp or the framework's built-in HTTP client. - # This is a placeholder that will be filled in once the - # specific HTTP library choice is confirmed. - raise NotImplementedError( - 'SageHttpClient.get: HTTP library not yet integrated. ' - 'Implement with aiohttp or framework HTTP client.' + # Per-request timeout override + if timeout is not None: + request_kwargs["timeout"] = ClientTimeout(total=timeout) + + return await self._execute_with_retry( + method=method.upper(), + url=url, + headers=merged_headers, + **request_kwargs, ) - async def post( - self, - path: str, - data: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - ) -> Any: - """Send a POST request to the upstream Sage API. + async def get(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse: + """Send a GET request.""" + return await self.request("GET", path, **kwargs) - Args: - path: API path (relative to base_url). - data: JSON body data. - headers: Optional additional headers. + async def post(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse: + """Send a POST request.""" + return await self.request("POST", path, **kwargs) - Returns: - Parsed JSON response. - """ - url = f'{self.base_url}{path}' - request_headers = {**self._build_headers(), **(headers or {})} + async def put(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse: + """Send a PUT request.""" + return await self.request("PUT", path, **kwargs) - debug(f'HTTP POST {url}') + async def delete(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse: + """Send a DELETE request.""" + return await self.request("DELETE", path, **kwargs) - # TODO: Replace with actual async HTTP implementation. - raise NotImplementedError( - 'SageHttpClient.post: HTTP library not yet integrated. ' - 'Implement with aiohttp or framework HTTP client.' - ) + async def patch(self, path: str, **kwargs: Any) -> aiohttp.ClientResponse: + """Send a PATCH request.""" + return await self.request("PATCH", path, **kwargs) + + async def close(self) -> None: + """Close the HTTP client and release resources.""" + if self._session and not self._session.closed: + await self._session.close() + self._closed = True + + async def __aenter__(self) -> "SageHttpClient": + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.close() + + def __del__(self) -> None: + # Best-effort cleanup; prefer using async context manager + if not self._closed and self._session and not self._session.closed: + logger.warning( + "SageHttpClient was not properly closed. " + "Use 'async with SageHttpClient(...)' for proper cleanup." + ) diff --git a/sync/cache_tables.sql b/sync/cache_tables.sql new file mode 100644 index 0000000..c96b9b9 --- /dev/null +++ b/sync/cache_tables.sql @@ -0,0 +1,122 @@ +-- ============================================================================= +-- SageAPI Cache Tables DDL +-- These tables store synchronized data from the Sage database. +-- Run against the sageapi database. +-- ============================================================================= + +-- Checkpoint table: tracks the last sync timestamp for each module +CREATE TABLE IF NOT EXISTS sync_state ( + state_key VARCHAR(64) NOT NULL PRIMARY KEY, + last_sync_ts VARCHAR(64) DEFAULT NULL, + created_at VARCHAR(32) NOT NULL, + updated_at VARCHAR(32) NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- ============================================================================= +-- users_cache: synced from Sage users / organi / organization tables +-- ============================================================================= +CREATE TABLE IF NOT EXISTS users_cache ( + user_id BIGINT NOT NULL PRIMARY KEY, + username VARCHAR(128) DEFAULT NULL, + email VARCHAR(256) DEFAULT NULL, + user_status INT DEFAULT 0, + user_created_at VARCHAR(32) DEFAULT NULL, + user_updated_at VARCHAR(32) DEFAULT NULL, + organi_id BIGINT DEFAULT NULL, + organi_name VARCHAR(256) DEFAULT NULL, + organi_parent_id BIGINT DEFAULT NULL, + org_id BIGINT DEFAULT NULL, + org_name VARCHAR(256) DEFAULT NULL, + org_type VARCHAR(64) DEFAULT NULL, + org_status INT DEFAULT 0, + org_updated_at VARCHAR(32) DEFAULT NULL, + synced_at VARCHAR(32) NOT NULL, + UNIQUE KEY uk_organi (organi_id), + KEY idx_updated (user_updated_at) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- ============================================================================= +-- pricing_cache: synced from Sage pricing_program / pricing_program_timing +-- ============================================================================= +CREATE TABLE IF NOT EXISTS pricing_cache ( + program_id BIGINT NOT NULL, + program_name VARCHAR(256) DEFAULT NULL, + program_code VARCHAR(128) DEFAULT NULL, + program_type VARCHAR(64) DEFAULT NULL, + program_status INT DEFAULT 0, + description TEXT DEFAULT NULL, + program_created_at VARCHAR(32) DEFAULT NULL, + program_updated_at VARCHAR(32) DEFAULT NULL, + timing_id BIGINT DEFAULT NULL, + start_time VARCHAR(32) DEFAULT NULL, + end_time VARCHAR(32) DEFAULT NULL, + duration INT DEFAULT NULL, + repeat_rule VARCHAR(256) DEFAULT NULL, + timezone VARCHAR(64) DEFAULT NULL, + timing_status INT DEFAULT 0, + timing_updated_at VARCHAR(32) DEFAULT NULL, + synced_at VARCHAR(32) NOT NULL, + PRIMARY KEY (program_id, timing_id), + KEY idx_program_updated (program_updated_at) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- ============================================================================= +-- uapi_cache: synced from Sage uapi / upapp tables +-- ============================================================================= +CREATE TABLE IF NOT EXISTS uapi_cache ( + uapi_id BIGINT NOT NULL, + api_name VARCHAR(256) DEFAULT NULL, + api_path VARCHAR(512) DEFAULT NULL, + api_method VARCHAR(16) DEFAULT 'GET', + api_version VARCHAR(32) DEFAULT NULL, + api_desc TEXT DEFAULT NULL, + uapi_status INT DEFAULT 0, + auth_required TINYINT DEFAULT 0, + uapi_created_at VARCHAR(32) DEFAULT NULL, + uapi_updated_at VARCHAR(32) DEFAULT NULL, + upapp_id BIGINT DEFAULT NULL, + app_name VARCHAR(256) DEFAULT NULL, + app_code VARCHAR(128) DEFAULT NULL, + app_type VARCHAR(64) DEFAULT NULL, + app_desc TEXT DEFAULT NULL, + app_owner VARCHAR(128) DEFAULT NULL, + upapp_status INT DEFAULT 0, + upapp_updated_at VARCHAR(32) DEFAULT NULL, + synced_at VARCHAR(32) NOT NULL, + PRIMARY KEY (uapi_id, upapp_id), + KEY idx_uapi_updated (uapi_updated_at), + KEY idx_upapp_id (upapp_id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +-- ============================================================================= +-- llmage_cache: synced from Sage llm / llmcatelog / llm_api_map tables +-- ============================================================================= +CREATE TABLE IF NOT EXISTS llmage_cache ( + llm_id BIGINT NOT NULL, + model_name VARCHAR(256) DEFAULT NULL, + model_version VARCHAR(64) DEFAULT NULL, + provider VARCHAR(128) DEFAULT NULL, + model_type VARCHAR(64) DEFAULT NULL, + llm_status INT DEFAULT 0, + llm_description TEXT DEFAULT NULL, + llm_created_at VARCHAR(32) DEFAULT NULL, + llm_updated_at VARCHAR(32) DEFAULT NULL, + catelog_id BIGINT DEFAULT NULL, + catelog_name VARCHAR(256) DEFAULT NULL, + catelog_code VARCHAR(128) DEFAULT NULL, + catelog_sort INT DEFAULT 0, + catelog_status INT DEFAULT 0, + catelog_updated_at VARCHAR(32) DEFAULT NULL, + api_map_id BIGINT DEFAULT NULL, + api_name VARCHAR(256) DEFAULT NULL, + api_endpoint VARCHAR(512) DEFAULT NULL, + api_version VARCHAR(32) DEFAULT NULL, + auth_type VARCHAR(32) DEFAULT NULL, + rate_limit INT DEFAULT NULL, + api_map_status INT DEFAULT 0, + api_map_updated_at VARCHAR(32) DEFAULT NULL, + synced_at VARCHAR(32) NOT NULL, + PRIMARY KEY (llm_id, catelog_id, api_map_id), + KEY idx_llm_updated (llm_updated_at), + KEY idx_catelog_id (catelog_id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/tests/__pycache__/__init__.cpython-310.pyc b/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..bb35765 Binary files /dev/null and b/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/tests/__pycache__/test_dapi_auth.cpython-310-pytest-9.0.3.pyc b/tests/__pycache__/test_dapi_auth.cpython-310-pytest-9.0.3.pyc new file mode 100644 index 0000000..58adb53 Binary files /dev/null and b/tests/__pycache__/test_dapi_auth.cpython-310-pytest-9.0.3.pyc differ diff --git a/tests/__pycache__/test_http_client.cpython-310-pytest-9.0.3.pyc b/tests/__pycache__/test_http_client.cpython-310-pytest-9.0.3.pyc new file mode 100644 index 0000000..1a19b98 Binary files /dev/null and b/tests/__pycache__/test_http_client.cpython-310-pytest-9.0.3.pyc differ diff --git a/tests/test_dapi_auth.py b/tests/test_dapi_auth.py new file mode 100644 index 0000000..f68bf47 --- /dev/null +++ b/tests/test_dapi_auth.py @@ -0,0 +1,458 @@ +"""Tests for sageapi.middleware.dapi_auth""" + +import hashlib +import hmac +import time +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +from sageapi.middleware.dapi_auth import ( + DEFAULT_TIMESTAMP_TOLERANCE_SEC, + DapiAuthError, + DapiAuthMiddleware, + DapiKeyRecord, + authenticate_request, + compute_dapi_signature, + lookup_api_key, + verify_timestamp, +) + + +def make_signature(method: str, path: str, timestamp: str, secret: str, body: bytes = b"") -> str: + """Helper to create a valid DAPI signature.""" + return compute_dapi_signature(method, path, timestamp, secret, body if body else None) + + +# --------------------------------------------------------------------------- +# compute_dapi_signature tests +# --------------------------------------------------------------------------- + +class TestComputeDapiSignature: + def test_basic_signature(self): + sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "my-secret") + assert isinstance(sig, str) + assert len(sig) == 64 # SHA-256 hex length + + def test_signature_with_body(self): + body = b'{"key": "value"}' + sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "my-secret", body) + assert isinstance(sig, str) + assert len(sig) == 64 + + def test_signature_deterministic(self): + sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret") + sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret") + assert sig1 == sig2 + + def test_different_body_different_signature(self): + sig1 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body1") + sig2 = compute_dapi_signature("POST", "/path", "123.0", "secret", b"body2") + assert sig1 != sig2 + + def test_empty_body_same_as_no_body(self): + sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret") + sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret", None) + assert sig1 == sig2 + + def test_manual_verification(self): + """Verify the signature matches the expected HMAC-SHA256 output.""" + method, path, ts, secret = "POST", "/v1/chat", "1700000000.0", "test-secret" + body = b'{"message":"hello"}' + body_hash = hashlib.sha256(body).hexdigest() + string_to_sign = f"{method}\n{path}\n{ts}\n{body_hash}" + expected = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha256).hexdigest() + actual = compute_dapi_signature(method, path, ts, secret, body) + assert actual == expected + + +# --------------------------------------------------------------------------- +# verify_timestamp tests +# --------------------------------------------------------------------------- + +class TestVerifyTimestamp: + def test_valid_timestamp(self): + ts = str(time.time()) + assert verify_timestamp(ts) is True + + def test_expired_timestamp(self): + ts = str(time.time() - 600) # 10 minutes ago + assert verify_timestamp(ts) is False + + def test_future_timestamp(self): + ts = str(time.time() + 600) # 10 minutes in future + assert verify_timestamp(ts) is False + + def test_invalid_timestamp(self): + assert verify_timestamp("not-a-number") is False + assert verify_timestamp("") is False + + def test_custom_tolerance(self): + ts = str(time.time() - 10) + assert verify_timestamp(ts, tolerance_sec=5) is False + assert verify_timestamp(ts, tolerance_sec=15) is True + + +# --------------------------------------------------------------------------- +# DapiKeyRecord tests +# --------------------------------------------------------------------------- + +class TestDapiKeyRecord: + def test_active_key(self): + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None) + assert rec.is_active is True + + def test_inactive_key(self): + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="inactive", expire_date=None) + assert rec.is_active is False + + def test_not_expired_when_no_expiry(self): + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=None) + assert rec.is_expired is False + + def test_expired_with_past_date(self): + past = datetime(2020, 1, 1, tzinfo=timezone.utc) + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=past) + assert rec.is_expired is True + + def test_not_expired_with_future_date(self): + future = datetime(2099, 1, 1, tzinfo=timezone.utc) + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date=future) + assert rec.is_expired is False + + def test_expired_with_iso_string_past(self): + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2020-01-01T00:00:00Z") + assert rec.is_expired is True + + def test_not_expired_with_iso_string_future(self): + rec = DapiKeyRecord(id=1, apikey="k1", secret="s1", status="active", expire_date="2099-01-01T00:00:00Z") + assert rec.is_expired is False + + +# --------------------------------------------------------------------------- +# lookup_api_key tests +# --------------------------------------------------------------------------- + +class TestLookupApiKey: + @pytest.fixture(autouse=True) + def mock_sage_modules(self): + """Create mock ahserver and sqlor modules for testing.""" + import sys + import types + + # Create fake modules + ahserver = types.ModuleType("ahserver") + ahserver_serverenv = types.ModuleType("ahserver.serverenv") + ahserver_serverenv.ServerEnv = MagicMock + ahserver.serverenv = ahserver_serverenv + + sqlor = types.ModuleType("sqlor") + sqlor_dbpools = types.ModuleType("sqlor.dbpools") + + mock_ctx = AsyncMock() + mock_sor = AsyncMock() + mock_sor.R = AsyncMock(return_value=[]) + mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx) + sqlor.dbpools = sqlor_dbpools + + sys.modules["ahserver"] = ahserver + sys.modules["ahserver.serverenv"] = ahserver_serverenv + sys.modules["sqlor"] = sqlor + sys.modules["sqlor.dbpools"] = sqlor_dbpools + + yield mock_ctx, mock_sor + + # Cleanup + for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]: + if mod in sys.modules: + del sys.modules[mod] + + @pytest.mark.asyncio + async def test_lookup_found(self, mock_sage_modules): + mock_ctx, mock_sor = mock_sage_modules + mock_rec = { + "id": 1, + "apikey": "test-key", + "secret": "test-secret", + "status": "active", + "expire_date": "2099-01-01T00:00:00Z", + "description": "Test key", + } + mock_sor.R.return_value = [mock_rec] + + result = await lookup_api_key("test-key") + + assert result is not None + assert result.apikey == "test-key" + assert result.secret == "test-secret" + assert result.is_active is True + + @pytest.mark.asyncio + async def test_lookup_not_found(self, mock_sage_modules): + mock_ctx, mock_sor = mock_sage_modules + mock_sor.R.return_value = [] + + result = await lookup_api_key("nonexistent") + + assert result is None + + +# --------------------------------------------------------------------------- +# authenticate_request tests +# --------------------------------------------------------------------------- + +class TestAuthenticateRequest: + def _make_request(self, headers: dict, body: bytes = b"", method: str = "GET", path: str = "/test") -> Request: + scope = { + "type": "http", + "method": method, + "path": path, + "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()], + "query_string": b"", + } + + async def receive(): + return {"type": "http.request", "body": body, "more_body": False} + + return Request(scope, receive) + + @pytest.mark.asyncio + async def test_missing_api_key_header(self): + req = self._make_request({ + "X-DAPI-Timestamp": str(time.time()), + "X-DAPI-Signature": "abc", + }) + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 401 + assert "X-DAPI-Key" in exc.value.detail + + @pytest.mark.asyncio + async def test_missing_timestamp_header(self): + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Signature": "abc", + }) + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 401 + assert "X-DAPI-Timestamp" in exc.value.detail + + @pytest.mark.asyncio + async def test_missing_signature_header(self): + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Timestamp": str(time.time()), + }) + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 401 + assert "X-DAPI-Signature" in exc.value.detail + + @pytest.mark.asyncio + async def test_expired_timestamp(self): + ts = str(time.time() - 600) + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": "abc", + }) + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 401 + assert "timestamp" in exc.value.detail.lower() + + @pytest.mark.asyncio + async def test_invalid_api_key(self): + ts = str(time.time()) + req = self._make_request({ + "X-DAPI-Key": "bad-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": "abc", + }) + + with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=None): + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 401 + assert "Invalid API key" in exc.value.detail + + @pytest.mark.asyncio + async def test_inactive_key(self): + ts = str(time.time()) + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": "abc", + }) + key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="secret", status="inactive", expire_date=None) + + with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 403 + assert "inactive" in exc.value.detail.lower() + + @pytest.mark.asyncio + async def test_expired_key(self): + ts = str(time.time()) + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": "abc", + }) + key_rec = DapiKeyRecord( + id=1, apikey="test-key", secret="secret", status="active", + expire_date=datetime(2020, 1, 1, tzinfo=timezone.utc), + ) + + with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 403 + assert "expired" in exc.value.detail.lower() + + @pytest.mark.asyncio + async def test_invalid_signature(self): + ts = str(time.time()) + body = b'{"test": "data"}' + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": "invalid-signature", + }, body=body, method="POST") + + key_rec = DapiKeyRecord(id=1, apikey="test-key", secret="real-secret", status="active", expire_date=None) + + with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): + with pytest.raises(DapiAuthError) as exc: + await authenticate_request(req) + assert exc.value.status_code == 401 + assert "Invalid signature" in exc.value.detail + + @pytest.mark.asyncio + async def test_valid_authentication(self): + ts = str(time.time()) + secret = "my-secret" + body = b'{"test": "data"}' + sig = compute_dapi_signature("POST", "/test", ts, secret, body) + + req = self._make_request({ + "X-DAPI-Key": "test-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": sig, + }, body=body, method="POST") + + key_rec = DapiKeyRecord(id=1, apikey="test-key", secret=secret, status="active", expire_date=None) + + with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): + result = await authenticate_request(req) + + assert result.apikey == "test-key" + assert result.secret == secret + + +# --------------------------------------------------------------------------- +# DapiAuthMiddleware tests +# --------------------------------------------------------------------------- + +class TestDapiAuthMiddleware: + @pytest.fixture(autouse=True) + def mock_sage_modules(self): + """Create mock ahserver and sqlor modules.""" + import sys + import types + + ahserver = types.ModuleType("ahserver") + ahserver_serverenv = types.ModuleType("ahserver.serverenv") + ahserver_serverenv.ServerEnv = MagicMock + ahserver.serverenv = ahserver_serverenv + + sqlor = types.ModuleType("sqlor") + sqlor_dbpools = types.ModuleType("sqlor.dbpools") + mock_sor = AsyncMock() + mock_sor.R = AsyncMock(return_value=[]) + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_sor) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + sqlor_dbpools.get_sor_context = MagicMock(return_value=mock_ctx) + sqlor.dbpools = sqlor_dbpools + + sys.modules["ahserver"] = ahserver + sys.modules["ahserver.serverenv"] = ahserver_serverenv + sys.modules["sqlor"] = sqlor + sys.modules["sqlor.dbpools"] = sqlor_dbpools + + yield + + for mod in ["ahserver", "ahserver.serverenv", "sqlor", "sqlor.dbpools"]: + if mod in sys.modules: + del sys.modules[mod] + + def _build_test_app(self, exclude_paths=None, tolerance_sec=DEFAULT_TIMESTAMP_TOLERANCE_SEC): + from starlette.applications import Starlette + from starlette.responses import PlainTextResponse + from starlette.routing import Route + + async def protected(request: Request): + return PlainTextResponse("OK - authenticated") + + async def health(request: Request): + return PlainTextResponse("healthy") + + app = Starlette( + routes=[ + Route("/protected", endpoint=protected, methods=["GET"]), + Route("/health", endpoint=health, methods=["GET"]), + ] + ) + + app.add_middleware( + DapiAuthMiddleware, + exclude_paths=exclude_paths or ["/health"], + tolerance_sec=tolerance_sec, + ) + return app + + def test_unauthenticated_request_rejected(self): + app = self._build_test_app() + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 401 + assert "error" in resp.json() + + def test_excluded_path_bypasses_auth(self): + app = self._build_test_app() + client = TestClient(app) + resp = client.get("/health") + assert resp.status_code == 200 + assert resp.text == "healthy" + + def test_valid_request_accepted(self): + ts = str(time.time()) + secret = "test-secret" + sig = compute_dapi_signature("GET", "/protected", ts, secret) + + key_rec = DapiKeyRecord(id=1, apikey="valid-key", secret=secret, status="active", expire_date=None) + + app = self._build_test_app() + client = TestClient(app) + + with patch("sageapi.middleware.dapi_auth.lookup_api_key", return_value=key_rec): + resp = client.get( + "/protected", + headers={ + "X-DAPI-Key": "valid-key", + "X-DAPI-Timestamp": ts, + "X-DAPI-Signature": sig, + }, + ) + + assert resp.status_code == 200 + assert resp.text == "OK - authenticated" diff --git a/tests/test_http_client.py b/tests/test_http_client.py new file mode 100644 index 0000000..8d7c4a4 --- /dev/null +++ b/tests/test_http_client.py @@ -0,0 +1,288 @@ +"""Tests for sageapi.utils.http_client""" + +import hashlib +import hmac +import json +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from sageapi.utils.http_client import ( + DAPISigner, + SageHttpClient, + RetryConfig, + compute_dapi_signature, +) + + +# --------------------------------------------------------------------------- +# compute_dapi_signature tests (also imported from middleware) +# --------------------------------------------------------------------------- + +class TestComputeDapiSignature: + def test_basic(self): + sig = compute_dapi_signature("GET", "/api/test", "1700000000.0", "secret") + assert len(sig) == 64 + + def test_with_body(self): + body = b'{"msg":"hi"}' + sig = compute_dapi_signature("POST", "/api/test", "1700000000.0", "secret", body) + assert len(sig) == 64 + + def test_deterministic(self): + sig1 = compute_dapi_signature("GET", "/path", "123.0", "secret") + sig2 = compute_dapi_signature("GET", "/path", "123.0", "secret") + assert sig1 == sig2 + + +# --------------------------------------------------------------------------- +# DAPISigner tests +# --------------------------------------------------------------------------- + +class TestDAPISigner: + def test_sign_request(self): + signer = DAPISigner(api_key="my-key", api_secret="my-secret") + headers = signer.sign_request("GET", "/test") + + assert "X-DAPI-Key" in headers + assert "X-DAPI-Timestamp" in headers + assert "X-DAPI-Signature" in headers + assert headers["X-DAPI-Key"] == "my-key" + assert headers["X-DAPI-Timestamp"] # should be a valid timestamp string + + def test_sign_request_with_body(self): + signer = DAPISigner(api_key="k", api_secret="s") + body = b'{"test": true}' + headers = signer.sign_request("POST", "/v1/chat", body) + + assert headers["X-DAPI-Key"] == "k" + # Verify timestamp is recent + ts = float(headers["X-DAPI-Timestamp"]) + assert abs(time.time() - ts) < 2 + + def test_sign_request_produces_valid_signature(self): + signer = DAPISigner(api_key="k", api_secret="secret") + body = b'hello' + headers = signer.sign_request("POST", "/path", body) + + expected = compute_dapi_signature("POST", "/path", headers["X-DAPI-Timestamp"], "secret", body) + assert headers["X-DAPI-Signature"] == expected + + def test_sign_request_uppercases_method(self): + signer = DAPISigner(api_key="k", api_secret="s") + headers = signer.sign_request("get", "/path") + expected = compute_dapi_signature("GET", "/path", headers["X-DAPI-Timestamp"], "s") + assert headers["X-DAPI-Signature"] == expected + + +# --------------------------------------------------------------------------- +# RetryConfig tests +# --------------------------------------------------------------------------- + +class TestRetryConfig: + def test_defaults(self): + config = RetryConfig() + assert config.max_retries == 3 + assert config.backoff_factor == 0.5 + assert 429 in config.retry_on_status + assert 500 in config.retry_on_status + + def test_custom(self): + config = RetryConfig(max_retries=5, backoff_factor=1.0) + assert config.max_retries == 5 + assert config.backoff_factor == 1.0 + + +# --------------------------------------------------------------------------- +# SageHttpClient tests +# --------------------------------------------------------------------------- + +class TestSageHttpClient: + def test_init_defaults(self): + client = SageHttpClient( + base_url="https://api.example.com", + api_key="key", + api_secret="secret", + ) + assert client.base_url == "https://api.example.com" + assert client.signer.api_key == "key" + assert client.signer.api_secret == "secret" + assert client.auto_sign is True + + def test_init_with_custom_signer(self): + signer = DAPISigner(api_key="k", api_secret="s") + client = SageHttpClient(base_url="https://api.example.com", signer=signer) + assert client.signer is signer + + def test_init_without_auto_sign(self): + client = SageHttpClient( + base_url="https://api.example.com", + auto_sign=False, + ) + assert client.auto_sign is False + + def test_build_url(self): + client = SageHttpClient(base_url="https://api.example.com") + assert client._build_url("/v1/test") == "https://api.example.com/v1/test" + assert client._build_url("v1/test") == "https://api.example.com/v1/test" + # Full URL should pass through + assert client._build_url("https://other.com/path") == "https://other.com/path" + + def test_get_relative_path(self): + client = SageHttpClient(base_url="https://api.example.com") + assert client._get_relative_path("/v1/chat") == "/v1/chat" + assert client._get_relative_path("https://api.example.com/v1/chat?page=1") == "/v1/chat" + + def test_sign_and_prepare_adds_dapi_headers(self): + client = SageHttpClient( + base_url="https://api.example.com", + api_key="k", + api_secret="s", + ) + + import asyncio + + url, headers, body = asyncio.get_event_loop().run_until_complete( + client._sign_and_prepare("GET", "/test") + ) + + assert "X-DAPI-Key" in headers + assert "X-DAPI-Timestamp" in headers + assert "X-DAPI-Signature" in headers + assert headers["X-DAPI-Key"] == "k" + + def test_sign_and_prepare_with_json_body(self): + client = SageHttpClient( + base_url="https://api.example.com", + api_key="k", + api_secret="s", + ) + + import asyncio + + url, headers, body = asyncio.get_event_loop().run_until_complete( + client._sign_and_prepare("POST", "/test", json_body={"msg": "hi"}) + ) + + assert body is not None + assert json.loads(body) == {"msg": "hi"} + assert headers.get("Content-Type") == "application/json" + + @pytest.mark.asyncio + async def test_context_manager(self): + client = SageHttpClient( + base_url="https://httpbin.org", + api_key="k", + api_secret="s", + max_retries=0, + ) + async with client as c: + assert c is client + assert client._closed is True + + @pytest.mark.asyncio + async def test_close(self): + client = SageHttpClient( + base_url="https://httpbin.org", + api_key="k", + api_secret="s", + max_retries=0, + ) + await client._get_session() # create session + await client.close() + assert client._closed is True + + @pytest.mark.asyncio + async def test_request_with_retry_on_502(self): + """Test that 502 responses trigger retries and raise after exhaustion.""" + client = SageHttpClient( + base_url="", + api_key="k", + api_secret="s", + max_retries=2, + backoff_factor=0.01, # fast for tests + auto_sign=False, + ) + + mock_response = AsyncMock() + mock_response.status = 502 + mock_response.request_info = MagicMock() + mock_response.history = [] + mock_response.release = MagicMock() + + mock_session = AsyncMock() + mock_session.request = AsyncMock(return_value=mock_response) + mock_session.closed = False + + client._session = mock_session + + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + await client.request("GET", "/test") + + assert exc_info.value.status == 502 + # Should have been called 3 times (1 initial + 2 retries) + assert mock_session.request.call_count == 3 + + +class TestSageHttpClientIntegration: + """Integration tests against httpbin.org (requires network).""" + + @pytest.mark.asyncio + async def test_get_request(self): + client = SageHttpClient( + base_url="https://httpbin.org", + auto_sign=False, + max_retries=0, + timeout=10.0, + ) + async with client: + resp = await client.get("/get") + assert resp.status == 200 + data = await resp.json() + assert "url" in data + + @pytest.mark.asyncio + async def test_post_request_with_json(self): + client = SageHttpClient( + base_url="https://httpbin.org", + auto_sign=False, + max_retries=0, + timeout=10.0, + ) + async with client: + resp = await client.post("/post", json={"key": "value"}) + assert resp.status == 200 + data = await resp.json() + assert data["json"] == {"key": "value"} + + @pytest.mark.asyncio + async def test_headers_sent(self): + client = SageHttpClient( + base_url="https://httpbin.org", + headers={"X-Custom-Header": "test-value"}, + auto_sign=False, + max_retries=0, + timeout=10.0, + ) + async with client: + resp = await client.get("/headers") + assert resp.status == 200 + data = await resp.json() + assert data["headers"].get("X-Custom-Header") == "test-value" + + @pytest.mark.asyncio + async def test_query_params(self): + client = SageHttpClient( + base_url="https://httpbin.org", + auto_sign=False, + max_retries=0, + timeout=10.0, + ) + async with client: + resp = await client.get("/get", params={"foo": "bar", "baz": "qux"}) + assert resp.status == 200 + data = await resp.json() + assert data["args"]["foo"] == "bar" + assert data["args"]["baz"] == "qux"