rag
This commit is contained in:
parent
68a9c43390
commit
524f1e93d1
139
rag/transaction_manager.py
Normal file
139
rag/transaction_manager.py
Normal file
@ -0,0 +1,139 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
from appPublic.log import debug, error, info
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
"""操作类型枚举"""
|
||||
FILE_LOAD = "file_load"
|
||||
EMBEDDING = "embedding"
|
||||
VDB_INSERT = "vdb_insert"
|
||||
TRIPLES_EXTRACT = "triples_extract"
|
||||
GDB_INSERT = "gdb_insert"
|
||||
VDB_DELETE = "vdb_delete"
|
||||
GDB_DELETE = "gdb_delete"
|
||||
ENTITY_EXTRACT = "entity_extract"
|
||||
TRIPLET_MATCH = "triplet_match"
|
||||
VECTOR_SEARCH = "vector_search"
|
||||
RERANK = "rerank"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RollbackOperation:
|
||||
"""回滚操作记录"""
|
||||
operation_type: OperationType
|
||||
data: Dict[str, Any]
|
||||
timestamp: str
|
||||
rollback_func: Optional[Callable] = None # 自定义回滚函数
|
||||
|
||||
|
||||
class TransactionManager:
|
||||
"""事务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.rollback_ops: List[RollbackOperation] = []
|
||||
self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
|
||||
def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
|
||||
rollback_func: Optional[Callable] = None):
|
||||
"""添加操作记录"""
|
||||
operation = RollbackOperation(
|
||||
operation_type=operation_type,
|
||||
data=data,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
rollback_func=rollback_func
|
||||
)
|
||||
self.rollback_ops.append(operation)
|
||||
debug(f"事务 {self.transaction_id} 添加操作: {operation_type.value}")
|
||||
|
||||
def get_operations_count(self) -> int:
|
||||
"""获取操作数量"""
|
||||
return len(self.rollback_ops)
|
||||
|
||||
async def rollback_all(self, rollback_context: Dict[str, Any] = None) -> List[str]:
|
||||
"""执行所有回滚操作"""
|
||||
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
|
||||
rollback_results = []
|
||||
|
||||
# 按相反顺序执行回滚
|
||||
for op in reversed(self.rollback_ops):
|
||||
try:
|
||||
debug(f"回滚操作: {op.operation_type.value}")
|
||||
|
||||
if op.rollback_func:
|
||||
# 执行自定义回滚函数
|
||||
result = await op.rollback_func(op.data, rollback_context)
|
||||
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||
else:
|
||||
# 默认回滚处理
|
||||
result = await self._default_rollback(op, rollback_context)
|
||||
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||
|
||||
except Exception as e:
|
||||
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}")
|
||||
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
|
||||
|
||||
return rollback_results
|
||||
|
||||
async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str:
|
||||
"""默认回滚处理"""
|
||||
if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING,
|
||||
OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]:
|
||||
# 内存操作,无需特殊回滚
|
||||
return "内存操作,已自动清理"
|
||||
elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE,
|
||||
OperationType.GDB_INSERT, OperationType.GDB_DELETE]:
|
||||
# 数据库操作需要在具体实现中处理
|
||||
return "需要在具体实现中处理数据库回滚"
|
||||
else:
|
||||
return f"跳过回滚操作: {op.operation_type.value}"
|
||||
|
||||
def clear(self):
|
||||
"""清空事务记录"""
|
||||
self.rollback_ops.clear()
|
||||
debug(f"事务 {self.transaction_id} 已清空")
|
||||
|
||||
|
||||
class TransactionContext:
|
||||
"""事务上下文管理器"""
|
||||
|
||||
def __init__(self, transaction_name: str = ""):
|
||||
self.transaction_name = transaction_name
|
||||
self.transaction_manager = TransactionManager()
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.start_time = datetime.now()
|
||||
info(f"开始事务: {self.transaction_name} [{self.transaction_manager.transaction_id}]")
|
||||
return self.transaction_manager
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
self.end_time = datetime.now()
|
||||
duration = (self.end_time - self.start_time).total_seconds()
|
||||
|
||||
if exc_type is not None:
|
||||
# 发生异常,执行回滚
|
||||
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
|
||||
rollback_results = await self.transaction_manager.rollback_all()
|
||||
error(f"回滚结果: {rollback_results}")
|
||||
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
||||
else:
|
||||
# 正常完成
|
||||
info(f"事务成功: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
||||
|
||||
self.transaction_manager.clear()
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def create_transaction_manager() -> TransactionManager:
|
||||
"""创建事务管理器"""
|
||||
return TransactionManager()
|
||||
|
||||
|
||||
def create_transaction_context(name: str = "") -> TransactionContext:
|
||||
"""创建事务上下文"""
|
||||
return TransactionContext(name)
|
||||
Loading…
x
Reference in New Issue
Block a user