This commit is contained in:
wangmeihua 2025-09-15 15:40:03 +08:00
parent 68a9c43390
commit 524f1e93d1

139
rag/transaction_manager.py Normal file
View File

@ -0,0 +1,139 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Any, Optional, Callable
from datetime import datetime
import traceback
from appPublic.log import debug, error, info
class OperationType(Enum):
    """Enumeration of the operation types a transaction can record.

    Every member's value is a lower-case string identifier.
    """

    FILE_LOAD = "file_load"
    EMBEDDING = "embedding"
    VDB_INSERT = "vdb_insert"
    TRIPLES_EXTRACT = "triples_extract"
    GDB_INSERT = "gdb_insert"
    VDB_DELETE = "vdb_delete"
    GDB_DELETE = "gdb_delete"
    ENTITY_EXTRACT = "entity_extract"
    TRIPLET_MATCH = "triplet_match"
    VECTOR_SEARCH = "vector_search"
    RERANK = "rerank"
@dataclass
class RollbackOperation:
    """One recorded operation together with what is needed to undo it."""

    # Which kind of operation was performed.
    operation_type: OperationType
    # Arbitrary payload describing the operation; handed to the rollback function.
    data: Dict[str, Any]
    # When the operation was recorded, stored as a string.
    timestamp: str
    # Optional custom rollback function; ``None`` means no custom handler.
    rollback_func: Optional[Callable] = None
class TransactionManager:
    """Record the operations of one transaction and roll them back on demand.

    Operations are kept in the order they were added; ``rollback_all``
    processes them newest-first.

    Attributes:
        rollback_ops: Recorded operations, oldest first.
        transaction_id: Identifier built from the creation time
            (format ``%Y%m%d_%H%M%S_%f``).
    """

    def __init__(self):
        self.rollback_ops: List[RollbackOperation] = []
        self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
                      rollback_func: Optional[Callable] = None):
        """Append an operation record to this transaction.

        Args:
            operation_type: Kind of operation that was performed.
            data: Payload describing the operation; passed back to
                ``rollback_func`` during rollback.
            rollback_func: Optional callable invoked as
                ``rollback_func(data, rollback_context)`` on rollback. It may
                be a plain function or a coroutine function.
        """
        operation = RollbackOperation(
            operation_type=operation_type,
            data=data,
            timestamp=datetime.now().isoformat(),
            rollback_func=rollback_func,
        )
        self.rollback_ops.append(operation)
        debug(f"事务 {self.transaction_id} 添加操作: {operation_type.value}")

    def get_operations_count(self) -> int:
        """Return the number of operations currently recorded."""
        return len(self.rollback_ops)

    async def rollback_all(self, rollback_context: Optional[Dict[str, Any]] = None) -> List[str]:
        """Roll back every recorded operation in reverse order.

        A failure in one rollback is logged and reported in the result list;
        it does not stop the remaining rollbacks. The recorded operations are
        left in place (call ``clear`` to drop them).

        Args:
            rollback_context: Optional shared context forwarded unchanged as
                the second argument of every rollback function.

        Returns:
            One human-readable status message per operation, in the order the
            rollbacks were attempted (newest operation first).
        """
        import inspect  # local import: only needed here

        debug(f"事务 {self.transaction_id} 开始回滚，共 {len(self.rollback_ops)} 个操作")
        rollback_results = []
        # Undo in the opposite order of execution.
        for op in reversed(self.rollback_ops):
            try:
                debug(f"回滚操作: {op.operation_type.value}")
                if op.rollback_func:
                    # Custom handler: accept both sync and async callables.
                    # Awaiting a non-awaitable return value would raise
                    # TypeError *after* the handler already ran, turning a
                    # successful rollback into a reported failure.
                    result = op.rollback_func(op.data, rollback_context)
                    if inspect.isawaitable(result):
                        result = await result
                else:
                    # No custom handler: fall back to the default behaviour.
                    result = await self._default_rollback(op, rollback_context)
                rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
            except Exception as e:
                # Best effort: record the failure and keep rolling back the rest.
                error(f"回滚操作失败 {op.operation_type.value}: {str(e)}")
                rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
        return rollback_results

    async def _default_rollback(self, op: RollbackOperation, context: Optional[Dict[str, Any]]) -> str:
        """Return a status message for an operation without a custom handler.

        This method performs no actual undo work; it only classifies the
        operation and returns a descriptive message. ``context`` is accepted
        for signature symmetry and is not used.
        """
        if op.operation_type in (OperationType.FILE_LOAD, OperationType.EMBEDDING,
                                 OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT):
            # In-memory operations: nothing to undo.
            return "内存操作，已自动清理"
        if op.operation_type in (OperationType.VDB_INSERT, OperationType.VDB_DELETE,
                                 OperationType.GDB_INSERT, OperationType.GDB_DELETE):
            # Database operations are NOT undone here; the caller must supply
            # a rollback_func for them.
            return "需要在具体实现中处理数据库回滚"
        return f"跳过回滚操作: {op.operation_type.value}"

    def clear(self):
        """Discard all recorded operations."""
        self.rollback_ops.clear()
        debug(f"事务 {self.transaction_id} 已清空")
class TransactionContext:
    """Async context manager that scopes a ``TransactionManager``.

    Entering yields the manager so the body can record operations. If the
    body raises, every recorded operation is rolled back and the exception
    propagates; otherwise the recorded operations are cleared.

    Args:
        transaction_name: Label used in log messages.
        rollback_context: Optional context forwarded to
            ``TransactionManager.rollback_all`` when the body fails, so that
            rollback functions receive it as their second argument. Defaults
            to ``None``.
    """

    def __init__(self, transaction_name: str = "",
                 rollback_context: Optional[Dict[str, Any]] = None):
        self.transaction_name = transaction_name
        self.rollback_context = rollback_context
        self.transaction_manager = TransactionManager()
        self.start_time = None
        self.end_time = None

    async def __aenter__(self):
        """Start timing, log the start and return the transaction manager."""
        self.start_time = datetime.now()
        info(f"开始事务: {self.transaction_name} [{self.transaction_manager.transaction_id}]")
        return self.transaction_manager

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Roll back on failure or clear on success; never suppress exceptions."""
        self.end_time = datetime.now()
        duration = (self.end_time - self.start_time).total_seconds()
        if exc_type is not None:
            # The body raised: undo everything that was recorded, handing the
            # configured context to the rollback functions.
            error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
            rollback_results = await self.transaction_manager.rollback_all(self.rollback_context)
            error(f"回滚结果: {rollback_results}")
            info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}")
        else:
            # Normal completion: the records are no longer needed.
            info(f"事务成功: {self.transaction_name}, 耗时: {duration:.2f}")
            self.transaction_manager.clear()
        # Implicitly returns None so any exception from the body propagates.
# Factory functions
def create_transaction_manager() -> TransactionManager:
    """Build and return a new, empty ``TransactionManager``."""
    manager = TransactionManager()
    return manager
def create_transaction_context(name: str = "") -> TransactionContext:
    """Build and return a ``TransactionContext`` labelled with ``name``."""
    context = TransactionContext(transaction_name=name)
    return context