rag
This commit is contained in:
parent
68a9c43390
commit
524f1e93d1
139
rag/transaction_manager.py
Normal file
139
rag/transaction_manager.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Dict, Any, Optional, Callable
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from appPublic.log import debug, error, info
|
||||||
|
|
||||||
|
|
||||||
|
class OperationType(Enum):
|
||||||
|
"""操作类型枚举"""
|
||||||
|
FILE_LOAD = "file_load"
|
||||||
|
EMBEDDING = "embedding"
|
||||||
|
VDB_INSERT = "vdb_insert"
|
||||||
|
TRIPLES_EXTRACT = "triples_extract"
|
||||||
|
GDB_INSERT = "gdb_insert"
|
||||||
|
VDB_DELETE = "vdb_delete"
|
||||||
|
GDB_DELETE = "gdb_delete"
|
||||||
|
ENTITY_EXTRACT = "entity_extract"
|
||||||
|
TRIPLET_MATCH = "triplet_match"
|
||||||
|
VECTOR_SEARCH = "vector_search"
|
||||||
|
RERANK = "rerank"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RollbackOperation:
|
||||||
|
"""回滚操作记录"""
|
||||||
|
operation_type: OperationType
|
||||||
|
data: Dict[str, Any]
|
||||||
|
timestamp: str
|
||||||
|
rollback_func: Optional[Callable] = None # 自定义回滚函数
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionManager:
|
||||||
|
"""事务管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.rollback_ops: List[RollbackOperation] = []
|
||||||
|
self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
|
||||||
|
def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
|
||||||
|
rollback_func: Optional[Callable] = None):
|
||||||
|
"""添加操作记录"""
|
||||||
|
operation = RollbackOperation(
|
||||||
|
operation_type=operation_type,
|
||||||
|
data=data,
|
||||||
|
timestamp=datetime.now().isoformat(),
|
||||||
|
rollback_func=rollback_func
|
||||||
|
)
|
||||||
|
self.rollback_ops.append(operation)
|
||||||
|
debug(f"事务 {self.transaction_id} 添加操作: {operation_type.value}")
|
||||||
|
|
||||||
|
def get_operations_count(self) -> int:
|
||||||
|
"""获取操作数量"""
|
||||||
|
return len(self.rollback_ops)
|
||||||
|
|
||||||
|
async def rollback_all(self, rollback_context: Dict[str, Any] = None) -> List[str]:
|
||||||
|
"""执行所有回滚操作"""
|
||||||
|
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
|
||||||
|
rollback_results = []
|
||||||
|
|
||||||
|
# 按相反顺序执行回滚
|
||||||
|
for op in reversed(self.rollback_ops):
|
||||||
|
try:
|
||||||
|
debug(f"回滚操作: {op.operation_type.value}")
|
||||||
|
|
||||||
|
if op.rollback_func:
|
||||||
|
# 执行自定义回滚函数
|
||||||
|
result = await op.rollback_func(op.data, rollback_context)
|
||||||
|
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||||
|
else:
|
||||||
|
# 默认回滚处理
|
||||||
|
result = await self._default_rollback(op, rollback_context)
|
||||||
|
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}")
|
||||||
|
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
|
||||||
|
|
||||||
|
return rollback_results
|
||||||
|
|
||||||
|
async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str:
|
||||||
|
"""默认回滚处理"""
|
||||||
|
if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING,
|
||||||
|
OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]:
|
||||||
|
# 内存操作,无需特殊回滚
|
||||||
|
return "内存操作,已自动清理"
|
||||||
|
elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE,
|
||||||
|
OperationType.GDB_INSERT, OperationType.GDB_DELETE]:
|
||||||
|
# 数据库操作需要在具体实现中处理
|
||||||
|
return "需要在具体实现中处理数据库回滚"
|
||||||
|
else:
|
||||||
|
return f"跳过回滚操作: {op.operation_type.value}"
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""清空事务记录"""
|
||||||
|
self.rollback_ops.clear()
|
||||||
|
debug(f"事务 {self.transaction_id} 已清空")
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionContext:
|
||||||
|
"""事务上下文管理器"""
|
||||||
|
|
||||||
|
def __init__(self, transaction_name: str = ""):
|
||||||
|
self.transaction_name = transaction_name
|
||||||
|
self.transaction_manager = TransactionManager()
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.start_time = datetime.now()
|
||||||
|
info(f"开始事务: {self.transaction_name} [{self.transaction_manager.transaction_id}]")
|
||||||
|
return self.transaction_manager
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.end_time = datetime.now()
|
||||||
|
duration = (self.end_time - self.start_time).total_seconds()
|
||||||
|
|
||||||
|
if exc_type is not None:
|
||||||
|
# 发生异常,执行回滚
|
||||||
|
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
|
||||||
|
rollback_results = await self.transaction_manager.rollback_all()
|
||||||
|
error(f"回滚结果: {rollback_results}")
|
||||||
|
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
||||||
|
else:
|
||||||
|
# 正常完成
|
||||||
|
info(f"事务成功: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
||||||
|
|
||||||
|
self.transaction_manager.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# 工厂函数
|
||||||
|
def create_transaction_manager() -> TransactionManager:
|
||||||
|
"""创建事务管理器"""
|
||||||
|
return TransactionManager()
|
||||||
|
|
||||||
|
|
||||||
|
def create_transaction_context(name: str = "") -> TransactionContext:
|
||||||
|
"""创建事务上下文"""
|
||||||
|
return TransactionContext(name)
|
||||||
Loading…
x
Reference in New Issue
Block a user