diff --git a/main.py b/main.py index 20e6417..69e3821 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ Hermes Service with global session management and Nginx security support """ +import inspect import os import sys import asyncio @@ -164,12 +165,21 @@ def get_real_ip(request: Request) -> str: def validate_ip_and_apikey(): """Decorator to validate IP and API key for protected endpoints""" def decorator(func): + # Preserve original function signature + prepend http_request: Request + # so FastAPI can inject both the Request and the original parameters. + original_sig = inspect.signature(func) + http_request_param = inspect.Parameter( + 'http_request', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request + ) + new_params = [http_request_param] + list(original_sig.parameters.values()) + new_sig = original_sig.replace(parameters=new_params) + @wraps(func) - async def wrapper(request: Request, *args, **kwargs): + async def wrapper(http_request: Request, *args, **kwargs): # IP validation if config['security']['enable_ip_check']: - client_ip = get_real_ip(request) - print(f"DEBUG: Client IP: {client_ip}") # Debug log + client_ip = get_real_ip(http_request) + print(f"DEBUG: Client IP: {client_ip}") allowed = False for allowed_ip in config['security']['allowed_ips']: try: @@ -177,43 +187,40 @@ def validate_ip_and_apikey(): allowed = True break except ValueError: - # Invalid IP or network, skip continue - + if not allowed: raise HTTPException(status_code=403, detail="IP address not allowed") - + # API Key validation if config['security']['enable_api_key']: provided_key = None - + if config['security']['auth_method'] == 'bearer': - # Check Authorization header for Bearer token - auth_header = request.headers.get("authorization") + auth_header = http_request.headers.get("authorization") if auth_header and auth_header.lower().startswith("bearer "): - provided_key = auth_header[7:].strip() # Remove "Bearer " prefix + provided_key = auth_header[7:].strip() else: - # Check custom header (default: X-API-Key) api_key_header = config['security']['api_key_header'] - provided_key = request.headers.get(api_key_header) - + provided_key = http_request.headers.get(api_key_header) + if not provided_key: raise HTTPException(status_code=401, detail="API key required") - + valid_key = False for key_config in config['security']['api_keys']: if key_config['key'] == provided_key: - # Check expiration if set if 'expires_at' in key_config and key_config['expires_at']: - # TODO: Implement expiration check pass valid_key = True break - + if not valid_key: raise HTTPException(status_code=401, detail="Invalid API key") - + return await func(*args, **kwargs) + + wrapper.__signature__ = new_sig return wrapper return decorator