fix: 修复装饰器 Request 参数与端点 body 参数名冲突

- 使用 inspect.signature 保留原函数签名并在前面追加 http_request: Request
  参数,使 FastAPI 能同时注入 Request 对象和原有的 body 参数
- 将装饰器内部的 request 引用全部改为 http_request,避免与端点函数的
  request: SessionCreateRequest 等 body 参数冲突
This commit is contained in:
yumoqing 2026-04-27 15:53:46 +08:00
parent 6adae569be
commit 5f962dcc90

29
main.py
View File

@ -3,6 +3,7 @@
Hermes Service with global session management and Nginx security support Hermes Service with global session management and Nginx security support
""" """
import inspect
import os import os
import sys import sys
import asyncio import asyncio
@ -164,12 +165,21 @@ def get_real_ip(request: Request) -> str:
def validate_ip_and_apikey(): def validate_ip_and_apikey():
"""Decorator to validate IP and API key for protected endpoints""" """Decorator to validate IP and API key for protected endpoints"""
def decorator(func): def decorator(func):
# Preserve original function signature + prepend http_request: Request
# so FastAPI can inject both the Request and the original parameters.
original_sig = inspect.signature(func)
http_request_param = inspect.Parameter(
'http_request', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
new_params = [http_request_param] + list(original_sig.parameters.values())
new_sig = original_sig.replace(parameters=new_params)
@wraps(func) @wraps(func)
async def wrapper(request: Request, *args, **kwargs): async def wrapper(http_request: Request, *args, **kwargs):
# IP validation # IP validation
if config['security']['enable_ip_check']: if config['security']['enable_ip_check']:
client_ip = get_real_ip(request) client_ip = get_real_ip(http_request)
print(f"DEBUG: Client IP: {client_ip}") # Debug log print(f"DEBUG: Client IP: {client_ip}")
allowed = False allowed = False
for allowed_ip in config['security']['allowed_ips']: for allowed_ip in config['security']['allowed_ips']:
try: try:
@ -177,7 +187,6 @@ def validate_ip_and_apikey():
allowed = True allowed = True
break break
except ValueError: except ValueError:
# Invalid IP or network, skip
continue continue
if not allowed: if not allowed:
@ -188,14 +197,12 @@ def validate_ip_and_apikey():
provided_key = None provided_key = None
if config['security']['auth_method'] == 'bearer': if config['security']['auth_method'] == 'bearer':
# Check Authorization header for Bearer token auth_header = http_request.headers.get("authorization")
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "): if auth_header and auth_header.lower().startswith("bearer "):
provided_key = auth_header[7:].strip() # Remove "Bearer " prefix provided_key = auth_header[7:].strip()
else: else:
# Check custom header (default: X-API-Key)
api_key_header = config['security']['api_key_header'] api_key_header = config['security']['api_key_header']
provided_key = request.headers.get(api_key_header) provided_key = http_request.headers.get(api_key_header)
if not provided_key: if not provided_key:
raise HTTPException(status_code=401, detail="API key required") raise HTTPException(status_code=401, detail="API key required")
@ -203,9 +210,7 @@ def validate_ip_and_apikey():
valid_key = False valid_key = False
for key_config in config['security']['api_keys']: for key_config in config['security']['api_keys']:
if key_config['key'] == provided_key: if key_config['key'] == provided_key:
# Check expiration if set
if 'expires_at' in key_config and key_config['expires_at']: if 'expires_at' in key_config and key_config['expires_at']:
# TODO: Implement expiration check
pass pass
valid_key = True valid_key = True
break break
@ -214,6 +219,8 @@ def validate_ip_and_apikey():
raise HTTPException(status_code=401, detail="Invalid API key") raise HTTPException(status_code=401, detail="Invalid API key")
return await func(*args, **kwargs) return await func(*args, **kwargs)
wrapper.__signature__ = new_sig
return wrapper return wrapper
return decorator return decorator