- 修复 validate_ip_and_apikey 装饰器无法获取 Request 对象的问题: 原代码通过遍历 args 查找 Request 对象,但 FastAPI 端点使用 Pydantic 模型作为参数时找不到 Request,导致装饰器跳过验证并可能引发异常。 改为显式声明 request: Request 作为 wrapper 的第一个参数,由 FastAPI 自动注入。 - 增强 ensure_user_hermes_env 的错误处理: 添加 try/except 包裹,检查 BASE_HERMES_PATH 是否存在, 将未处理的异常转为带详细信息的 HTTPException(500)。
443 lines
16 KiB
Python
443 lines
16 KiB
Python
#!/usr/bin/env python3
"""
Hermes Service with global session management and Nginx security support.
"""

import asyncio
import hmac
import inspect
import ipaddress
import json
import json as json_mod
import os
import shutil
import sys
import threading
import uuid
from datetime import datetime
from functools import wraps
from pathlib import Path
from typing import Optional, Dict, Any, List

import yaml
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


# ---------------------------------------------------------------------------
# Resolve Hermes base path dynamically via get_hermes_home() when available,
# with fallback to the hardcoded default for backward compatibility.
# ---------------------------------------------------------------------------
def _resolve_hermes_home() -> str:
    """Resolve the Hermes home directory.

    Resolution order:
      1. The ``HERMES_HOME`` environment variable (explicit override).
      2. ``get_hermes_home()`` from the bundled ``hermes_constants`` module,
         imported from the default hermes-agent checkout.
      3. The hard-coded default ``/d/hermesai/.hermes``.

    Returns:
        The home directory as a ``str`` (``os.fspath`` is applied so a
        path-like value from ``get_hermes_home()`` still honours the
        annotation).
    """
    env_override = os.environ.get("HERMES_HOME")
    if env_override:
        return env_override

    default_home = "/d/hermesai/.hermes"
    agent_path = os.path.join(default_home, "hermes-agent")

    # Make the bundled agent importable without installing it. Remember
    # whether we added the entry so a failed lookup can undo it instead of
    # leaving a bogus high-priority entry at sys.path[0].
    inserted = agent_path not in sys.path
    if inserted:
        sys.path.insert(0, agent_path)

    try:
        from hermes_constants import get_hermes_home
        return os.fspath(get_hermes_home())
    except Exception:
        # Best-effort: any failure (missing module, bad return value)
        # falls back to the default location.
        if inserted:
            try:
                sys.path.remove(agent_path)
            except ValueError:
                pass
        return default_home


# Root for isolated per-user data: /d/hermesai/users/{user_id}/ holds each
# user's hermes-agent copy and .hermes directory.
USERS_BASE = "/d/hermesai/users"

# Shared Hermes installation that per-user environments are cloned from.
HERMES_HOME = _resolve_hermes_home()
BASE_HERMES_PATH = os.path.join(HERMES_HOME, "hermes-agent")


# Load configuration
CONFIG_FILE = os.path.join(os.path.dirname(__file__), "config.yaml")


def _default_config() -> Dict[str, Any]:
    """Return a fresh copy of the built-in default configuration.

    A new dict is built on every call so a caller mutating the result can
    never alter the defaults seen by a later call.
    """
    return {
        'security': {
            'enable_ip_check': False,
            'allowed_ips': ['127.0.0.1', '::1'],
            'enable_api_key': False,
            'api_keys': [],
            'api_key_header': 'X-API-Key',
            'auth_method': 'header',  # 'header' or 'bearer'
        },
        'nginx': {
            'trusted_proxies': ['127.0.0.1', '::1'],
            'enable_real_ip': True,
        },
        'service': {
            'host': '127.0.0.1',
            'port': 9120,
            'log_level': 'info',
        },
        'cors': {
            'allow_origins': ['*'],
            'allow_credentials': True,
            'allow_methods': ['*'],
            'allow_headers': ['*'],
        },
    }


def _merge_config(loaded: Any) -> Dict[str, Any]:
    """Overlay a parsed config document on top of the defaults.

    Known sections are merged key-by-key, so a ``config.yaml`` that only sets
    a few options still gets every other option from the defaults instead of
    failing later with ``KeyError`` at request time. Unknown sections are
    kept as-is. An empty document (``None``) or an empty section yields the
    defaults.

    Raises:
        ValueError: if the document or a known section is not a mapping.
    """
    merged = _default_config()
    if loaded is None:
        return merged
    if not isinstance(loaded, dict):
        raise ValueError("Configuration file must contain a mapping at the top level")

    for section, values in loaded.items():
        if values is None:
            values = {}
        if isinstance(merged.get(section), dict):
            if not isinstance(values, dict):
                raise ValueError(f"Configuration section '{section}' must be a mapping")
            merged[section].update(values)
        else:
            merged[section] = values
    return merged


def _load_config(path: str) -> Dict[str, Any]:
    """Read *path* (YAML) if it exists and return the effective configuration."""
    if not os.path.exists(path):
        return _default_config()
    with open(path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file; _merge_config handles it.
        loaded = yaml.safe_load(f)
    return _merge_config(loaded)


config = _load_config(CONFIG_FILE)

print(f"Security config - IP check: {config['security']['enable_ip_check']}, API key: {config['security']['enable_api_key']}")


app = FastAPI(
    title="Hermes Service API",
    version="1.3.0",
)

# Configure CORS straight from the `cors` section of the loaded config.
# NOTE(review): the default combination allow_origins=['*'] with
# allow_credentials=True is risky for a service that carries API keys;
# confirm deployed configs restrict the allowed origins.
app.add_middleware(
    CORSMiddleware,
    **{
        option: config['cors'][option]
        for option in ('allow_origins', 'allow_credentials', 'allow_methods', 'allow_headers')
    },
)


# ---------------------------------------------------------------------------
# Persistent session registry: global_session_id -> {user_id, local_session_id, ...}
# Saved to disk so sessions survive service restarts.
# ---------------------------------------------------------------------------
SESSION_STORE_FILE = os.path.join(
    os.path.dirname(__file__), "data", "sessions.json"
)
# Create the data/ directory up front so the first save cannot fail on it.
os.makedirs(os.path.dirname(SESSION_STORE_FILE), exist_ok=True)

# In-memory registry; _sessions_lock serializes writes of it to disk.
global_sessions: Dict[str, Dict[str, Any]] = {}
_sessions_lock = threading.Lock()


def _save_sessions():
    """Persist ``global_sessions`` to ``SESSION_STORE_FILE``.

    The registry is first written to a temporary sibling file and then moved
    into place with ``os.replace``. A crash or serialization error mid-write
    therefore leaves the previous store intact, instead of a truncated file
    that ``_load_sessions`` would reject (dropping every session).

    Persistence is best-effort: failures are reported and swallowed so they
    never fail the request that triggered the save.
    """
    tmp_path = SESSION_STORE_FILE + ".tmp"
    try:
        with _sessions_lock:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json_mod.dump(global_sessions, f, indent=2)
            os.replace(tmp_path, SESSION_STORE_FILE)
    except Exception as e:
        print(f"WARNING: Failed to save session store: {e}")
        # Don't leave a half-written temp file lying around.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def _load_sessions():
    """Load ``global_sessions`` from ``SESSION_STORE_FILE`` if it exists.

    Leaves the registry untouched when there is no store yet. If the file
    cannot be read or parsed, or does not hold a JSON object (the registry
    is a dict keyed by session id), a warning is printed and the registry
    starts empty.
    """
    global global_sessions
    if not os.path.exists(SESSION_STORE_FILE):
        return
    try:
        with open(SESSION_STORE_FILE, 'r', encoding='utf-8') as f:
            loaded = json_mod.load(f)
        if not isinstance(loaded, dict):
            # e.g. a JSON list would later break `global_sessions[id] = ...`.
            raise ValueError(f"expected a JSON object, got {type(loaded).__name__}")
        global_sessions = loaded
        print(f"Loaded {len(global_sessions)} sessions from store")
    except Exception as e:
        print(f"WARNING: Failed to load session store, starting fresh: {e}")
        global_sessions = {}


_load_sessions()


def get_real_ip(request: Request) -> str:
    """Return the address of the client that originated *request*.

    If ``nginx.enable_real_ip`` is off, or the direct peer is not listed in
    ``nginx.trusted_proxies``, the peer address is returned unchanged.

    If the peer is a trusted proxy, ``X-Forwarded-For`` is scanned from the
    right (hops appended by our own proxies) to the left, skipping trusted
    proxies. The first untrusted hop is returned. The left-most entry is
    whatever the client itself sent, so it must not be trusted: a caller
    could otherwise claim an allow-listed address such as ``127.0.0.1``.
    If every hop is a trusted proxy, the left-most hop is returned.

    Returns ``""`` when the peer address is unknown (``request.client`` is
    None). An empty string never matches an allowlist entry.
    """
    client = request.client
    client_host = client.host if client is not None else ""

    if not config['nginx']['enable_real_ip']:
        return client_host

    trusted_proxies = config['nginx']['trusted_proxies']

    def _is_trusted(ip: str) -> bool:
        # Malformed addresses and malformed proxy entries never match.
        try:
            address = ipaddress.ip_address(ip)
        except ValueError:
            return False
        for proxy in trusted_proxies:
            try:
                if address in ipaddress.ip_network(proxy, strict=False):
                    return True
            except ValueError:
                continue
        return False

    if not _is_trusted(client_host):
        return client_host

    forwarded_for = request.headers.get("x-forwarded-for")
    if not forwarded_for:
        return client_host

    hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
    for hop in reversed(hops):
        if not _is_trusted(hop):
            return hop
    return hops[0] if hops else client_host


def _ip_in_allowlist(client_ip: Any, allowed_networks: Any) -> bool:
    """Return True if *client_ip* lies inside any entry of *allowed_networks*.

    Entries may be single addresses or CIDR networks. A malformed client
    address or a malformed entry never matches.
    """
    try:
        address = ipaddress.ip_address(client_ip)
    except ValueError:
        return False
    for network in allowed_networks:
        try:
            if address in ipaddress.ip_network(network, strict=False):
                return True
        except ValueError:
            continue
    return False


def _extract_api_key(request: Any, security: Dict[str, Any]) -> Optional[str]:
    """Pull the caller's API key from the headers named by ``auth_method``.

    'bearer' reads ``Authorization: Bearer <key>``. Any other method reads
    the custom header named by ``api_key_header``.
    """
    if security['auth_method'] == 'bearer':
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            return auth_header[7:].strip()  # Remove "Bearer " prefix
        return None
    return request.headers.get(security['api_key_header'])


def _api_key_expired(expires_at: Any) -> bool:
    """Return True if a key whose ``expires_at`` is *expires_at* is no longer valid.

    Accepted values:
      * falsy (None, "", absent): the key never expires;
      * ``datetime``: expired once now >= expires_at (naive values are
        compared with local time);
      * ISO-8601 string: a date-only "YYYY-MM-DD" keeps the key valid through
        the end of that day, and a trailing "Z" is read as UTC;
      * ``date`` (PyYAML's safe_load yields one for an unquoted YYYY-MM-DD):
        valid through the end of that day.
    Anything that cannot be interpreted is treated as expired (fail closed).
    """
    if not expires_at:
        return False
    if isinstance(expires_at, str):
        text = expires_at.strip()
        try:
            parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
        except ValueError:
            return True
        if len(text) == 10:  # date-only string
            return datetime.now().date() > parsed.date()
        expires_at = parsed
    if isinstance(expires_at, datetime):
        now = datetime.now(expires_at.tzinfo) if expires_at.tzinfo else datetime.now()
        return now >= expires_at
    try:
        return datetime.now().date() > expires_at
    except TypeError:
        return True


def _api_key_matches(provided_key: str, api_keys: Any) -> bool:
    """Return True if *provided_key* equals an unexpired configured key.

    Entries may be mappings with ``key`` (and optional ``expires_at``), or
    plain strings. Entries without a key are ignored. Keys are compared with
    ``hmac.compare_digest`` so response timing does not reveal how much of
    a guess was correct.
    """
    provided = provided_key.encode("utf-8")
    for entry in api_keys or []:
        if isinstance(entry, dict):
            expected = entry.get('key')
            expires_at = entry.get('expires_at')
        elif isinstance(entry, str):
            expected, expires_at = entry, None
        else:
            continue
        if not expected:
            continue
        if hmac.compare_digest(str(expected).encode("utf-8"), provided):
            if not _api_key_expired(expires_at):
                return True
    return False


def _enforce_security(request: Any) -> None:
    """Apply the configured IP allowlist and API-key checks to *request*.

    Raises:
        HTTPException(403): the client IP is not allow-listed.
        HTTPException(401): the API key is missing, unknown or expired.
        HTTPException(500): a check is enabled but no Request was available
            (fail closed rather than skipping the check).
    """
    security = config['security']
    if not (security['enable_ip_check'] or security['enable_api_key']):
        return
    if request is None:
        raise HTTPException(status_code=500, detail="Request context unavailable for security check")

    # IP validation
    if security['enable_ip_check']:
        client_ip = get_real_ip(request)
        print(f"DEBUG: Client IP: {client_ip}")  # Debug log
        if not _ip_in_allowlist(client_ip, security['allowed_ips']):
            raise HTTPException(status_code=403, detail="IP address not allowed")

    # API Key validation
    if security['enable_api_key']:
        provided_key = _extract_api_key(request, security)
        if not provided_key:
            raise HTTPException(status_code=401, detail="API key required")
        if not _api_key_matches(provided_key, security['api_keys']):
            raise HTTPException(status_code=401, detail="Invalid API key")


def validate_ip_and_apikey():
    """Decorator factory guarding an async endpoint with IP/API-key checks.

    The wrapper needs the incoming ``Request``, but endpoints either declare
    none (``health_check``) or use the name ``request`` for a Pydantic body
    (``create_session``). ``functools.wraps`` makes ``inspect.signature``
    follow ``__wrapped__`` to the endpoint's own signature. FastAPI builds
    each call from that signature, so a ``request`` parameter on the wrapper
    alone is never filled with the real Request. For that reason the
    decorator publishes an explicit ``wrapper.__signature__``:

      * if the endpoint already declares a ``Request``-annotated parameter,
        the wrapper reads that argument and forwards it unchanged;
      * otherwise a keyword-only ``Request`` parameter (``_hermes_request``,
        prefixed with extra underscores if that name is taken) is appended.
        The wrapper consumes it and does not pass it to the endpoint.
    """
    def decorator(func):
        signature = inspect.signature(func)
        params = list(signature.parameters.values())

        request_param = next(
            (
                p.name
                for p in params
                if inspect.isclass(p.annotation) and issubclass(p.annotation, Request)
            ),
            None,
        )
        injected = request_param is None
        if injected:
            request_param = "_hermes_request"
            while request_param in signature.parameters:
                request_param = "_" + request_param
            extra = inspect.Parameter(
                request_param, inspect.Parameter.KEYWORD_ONLY, annotation=Request
            )
            # Keyword-only parameters must precede a trailing **kwargs.
            if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
                params.insert(len(params) - 1, extra)
            else:
                params.append(extra)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            if injected:
                request = kwargs.pop(request_param, None)
            else:
                request = kwargs.get(request_param)
            if request is None:
                # Direct (non-framework) callers may pass the Request positionally.
                request = next((a for a in args if isinstance(a, Request)), None)
            _enforce_security(request)
            return await func(*args, **kwargs)

        # Assigned after @wraps: inspect.signature() stops at an object that
        # has __signature__ instead of unwrapping to the endpoint.
        wrapper.__signature__ = signature.replace(parameters=params)
        return wrapper

    return decorator


def get_user_hermes_path(user_id: str) -> str:
    """Return the isolated base directory for *user_id* under ``USERS_BASE``.

    Empty or missing ids map to the shared "anonymous" user. Every character
    other than alphanumerics and ``-_.`` is stripped.

    Raises:
        HTTPException(400): the sanitized id is empty, ``.`` or ``..``.
            Those would resolve to ``USERS_BASE`` itself or its parent,
            letting one caller escape or share another's environment.
    """
    if not user_id or user_id == "anonymous":
        user_id = "anonymous"
    # str() so non-string ids (e.g. an int from a JSON user_context) work.
    safe_user_id = "".join(c for c in str(user_id) if c.isalnum() or c in "-_.")
    if safe_user_id in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid user_id")
    return os.path.join(USERS_BASE, safe_user_id)


def ensure_user_hermes_env(user_id: str):
    """Ensure *user_id* has an isolated Hermes environment and return its path.

    Layout under ``get_user_hermes_path(user_id)``:
      * ``hermes-agent/``: a copy of ``BASE_HERMES_PATH`` (without .git,
        caches, .venv and web_dist), made only if it does not exist yet;
      * ``hermes-agent/.venv``: a symlink to the base installation's .venv;
      * ``.hermes/``: the user's private data directory.
    The directory and the symlink are re-checked on every call, so a
    partially provisioned environment is repaired instead of staying broken.

    NOTE(review): copytree is blocking I/O. This function is called from
    async endpoints, so a first-time copy stalls the event loop.

    Returns:
        Path of the user's ``hermes-agent`` copy.

    Raises:
        HTTPException: propagated as-is (e.g. 400 for an invalid user_id),
            or 500 wrapping any other failure.
    """
    try:
        user_base_path = get_user_hermes_path(user_id)
        user_hermes_path = os.path.join(user_base_path, "hermes-agent")
        user_dot_hermes = os.path.join(user_base_path, ".hermes")

        if not os.path.exists(user_hermes_path):
            # Check before creating anything so a missing base install does
            # not leave an empty user directory behind.
            if not os.path.exists(BASE_HERMES_PATH):
                raise FileNotFoundError(f"Base Hermes path not found: {BASE_HERMES_PATH}")
            os.makedirs(user_base_path, exist_ok=True, mode=0o700)
            shutil.copytree(
                BASE_HERMES_PATH,
                user_hermes_path,
                dirs_exist_ok=True,
                ignore=shutil.ignore_patterns('.git', '__pycache__', '*.pyc', '.venv', 'web_dist'),
            )

        os.makedirs(user_dot_hermes, exist_ok=True, mode=0o700)

        venv_link = os.path.join(user_hermes_path, '.venv')
        # lexists, not exists: a dangling symlink makes exists() False, and
        # os.symlink would then fail on the existing link.
        if not os.path.lexists(venv_link):
            try:
                os.symlink(os.path.join(BASE_HERMES_PATH, '.venv'), venv_link)
            except FileExistsError:
                pass  # created concurrently by another request

        return user_hermes_path

    except HTTPException:
        raise
    except Exception as e:
        print(f"ERROR: Failed to ensure user Hermes env for {user_id}: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to initialize user environment: {str(e)}") from e


@app.get("/health")
@validate_ip_and_apikey()
async def health_check():
    """Liveness probe, subject to the same IP/API-key checks as the API."""
    payload = {"status": "healthy", "service": "hermes-service"}
    payload["multi_user"] = True
    return payload


@app.get("/api/v1/status")
@validate_ip_and_apikey()
async def get_hermes_status(request: Request):
    """Report whether the base Hermes CLI answers ``--version``.

    Any exception is reported in the response body rather than raised.
    """
    try:
        outcome = await execute_hermes_command(["--version"], user_id=None)
        version = outcome.get("stdout", "").strip()
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
    return {"status": "running", "version": version}


class SessionCreateRequest(BaseModel):
    # Body of POST /api/v1/sessions.
    user_id: str
    initial_message: Optional[str] = None  # accepted; not read by create_session


class CommandRequest(BaseModel):
    # Body of POST /api/v1/execute: CLI arguments passed to hermes_cli.main.
    command: List[str]
    user_context: Optional[Dict[str, Any]] = None  # "user_id" selects the user env
    timeout: int = 300  # seconds


class SessionMessageRequest(BaseModel):
    # Body of POST /api/v1/sessions/{session_id}/messages.
    message: str
    user_context: Optional[Dict[str, Any]] = None


@app.post("/api/v1/sessions")
@validate_ip_and_apikey()
async def create_session(request: SessionCreateRequest):
    """Create a global session bound to the caller's isolated environment.

    The user's environment is provisioned first (see ensure_user_hermes_env).
    The session is then registered in ``global_sessions`` and persisted.

    NOTE(review): the response includes ``hermes_path`` although get_session
    hides it as internal. Kept for compatibility; confirm clients need it.
    """
    user_id = request.user_id
    if not user_id:
        raise HTTPException(status_code=400, detail="user_id is required")

    session_id = str(uuid.uuid4())
    hermes_path = ensure_user_hermes_env(user_id)

    # The global id doubles as the local id until Hermes-side sessions exist.
    global_sessions[session_id] = {
        "user_id": user_id,
        "local_session_id": session_id,
        "created_at": datetime.now().isoformat(),
        "hermes_path": hermes_path,
        "status": "active",
    }
    _save_sessions()

    return {
        "session_id": session_id,
        "user_id": user_id,
        "hermes_path": hermes_path,
        "status": "created",
    }


@app.post("/api/v1/execute")
@validate_ip_and_apikey()
async def execute_command(request: CommandRequest):
    """Run a Hermes CLI command and return its captured result.

    ``user_context["user_id"]``, when present, selects the per-user
    environment. Without it, execute_hermes_command runs against the base
    installation. A failed command is surfaced as HTTP 500 with its stderr
    as the detail.
    """
    context = request.user_context
    user_id = context.get("user_id") if context else None

    outcome = await execute_hermes_command(
        request.command, user_id=user_id, timeout=request.timeout
    )
    if not outcome["success"]:
        raise HTTPException(status_code=500, detail=outcome["stderr"])
    return outcome


@app.post("/api/v1/sessions/{session_id}/messages")
@validate_ip_and_apikey()
async def send_session_message(session_id: str, request: SessionMessageRequest):
    """Send one chat message through the Hermes CLI in the session's user env.

    Runs ``chat -q <message> --source tool`` in non-interactive mode with a
    300 s timeout. On failure, ``response`` carries stderr (or "Command
    failed") and ``success`` is False; no HTTP error is raised.

    NOTE(review): neither the session's local_session_id nor
    ``request.user_context`` is passed to Hermes, so each message starts
    a fresh conversation. Integrating with Hermes' own session management
    is still open.
    """
    if session_id not in global_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    user_id = global_sessions[session_id]["user_id"]

    command_args = ["chat", "-q", request.message, "--source", "tool"]
    result = await execute_hermes_command(command_args, user_id=user_id, timeout=300)

    if result["success"]:
        response_content = result.get("stdout", "")
    else:
        response_content = result.get("stderr", "Command failed")

    return {
        "session_id": session_id,
        "response": response_content,
        "success": result["success"],
    }


@app.get("/api/v1/sessions/{session_id}")
@validate_ip_and_apikey()
async def get_session(session_id: str):
    """Return a copy of a session's metadata, without internal filesystem paths."""
    if session_id not in global_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    # Build a copy that omits hermes_path; don't expose internal paths.
    return {
        field: value
        for field, value in global_sessions[session_id].items()
        if field != "hermes_path"
    }


async def execute_hermes_command(command_args, user_id=None, timeout=300):
    """Run ``python -m hermes_cli.main <command_args>`` and capture its output.

    With a *user_id*, the command runs with that user's hermes-agent copy as
    cwd and the user's ``.hermes`` directory as HOME. Without one, it uses
    the base installation and ``HERMES_HOME``. The interpreter is always the
    base installation's .venv python3.

    Returns:
        dict with ``success`` (exit code 0), ``stdout``, ``stderr`` (decoded
        as UTF-8 with replacement) and ``returncode`` (-1 on timeout or when
        the process could not be started).

    Never raises for command failures: errors, including an HTTPException
    from an invalid user id, are reported in the returned dict. If the
    calling task is cancelled (e.g. the client disconnects), the child
    process is killed and the cancellation propagates.

    NOTE(review): ``env`` copies the service's environment, so a HERMES_HOME
    set for the service is inherited by per-user commands. If the Hermes CLI
    honours that variable, per-user runs may share the global home. Confirm.
    """
    def _kill_quietly(proc):
        # The process may already have exited between the timeout and kill().
        try:
            proc.kill()
        except ProcessLookupError:
            pass

    try:
        if user_id:
            user_base_path = get_user_hermes_path(user_id)
            user_hermes_path = os.path.join(user_base_path, "hermes-agent")
            hermes_dot_path = os.path.join(user_base_path, ".hermes")
        else:
            user_hermes_path = BASE_HERMES_PATH
            hermes_dot_path = HERMES_HOME

        # Resolve Python interpreter path dynamically
        python_path = os.path.join(BASE_HERMES_PATH, ".venv", "bin", "python3")
        cmd = [python_path, "-m", "hermes_cli.main", *command_args]

        env = os.environ.copy()
        env['HOME'] = hermes_dot_path
        env['HERMES_USER_ID'] = str(user_id or 'anonymous')
        env['HERMES_SESSION_ID'] = str(uuid.uuid4())

        process = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=user_hermes_path,
            env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            _kill_quietly(process)
            await process.wait()
            return {
                'success': False,
                'stdout': '',
                'stderr': f'Command timed out after {timeout} seconds',
                'returncode': -1
            }
        except asyncio.CancelledError:
            # Don't leave an orphaned child running after the request is gone.
            _kill_quietly(process)
            raise

        return {
            'success': process.returncode == 0,
            'stdout': stdout.decode('utf-8', errors='replace'),
            'stderr': stderr.decode('utf-8', errors='replace'),
            'returncode': process.returncode
        }

    except Exception as e:
        return {
            'success': False,
            'stdout': '',
            'stderr': str(e),
            'returncode': -1
        }


# Create the per-user root at import time so user environments can be
# provisioned beneath it.
os.makedirs(USERS_BASE, mode=0o755, exist_ok=True)


if __name__ == "__main__":
    # Run the service directly with uvicorn using the `service` config section.
    import uvicorn

    service_cfg = config['service']
    uvicorn.run(
        app,
        host=service_cfg['host'],
        port=service_cfg['port'],
        log_level=service_cfg['log_level'],
    )