This commit is contained in:
wangmeihua 2025-09-15 18:06:45 +08:00
parent 524f1e93d1
commit cddb4733fb
3 changed files with 135 additions and 57 deletions

View File

@ -20,7 +20,7 @@ from typing import List, Dict, Any
from rag.service_opts import get_service_params, sor_get_service_params from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations from rag.rag_operations import RagOperations
import json import json
from rag.transaction_manager import TransactionContext
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@ -53,7 +53,6 @@ where a.orgid = b.orgid
return r.quota, r.expired_date return r.quota, r.expired_date
return None, None return None, None
async def file_uploaded(self, request, ns, userid): async def file_uploaded(self, request, ns, userid):
"""将文档插入 Milvus 并抽取三元组到 Neo4j""" """将文档插入 Milvus 并抽取三元组到 Neo4j"""
debug(f'Received ns: {ns=}') debug(f'Received ns: {ns=}')
@ -64,53 +63,103 @@ where a.orgid = b.orgid
orgid = ns.get('ownerid', '') orgid = ns.get('ownerid', '')
db_type = '' db_type = ''
debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}') debug(
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
timings = {} timings = {}
start_total = time.time() start_total = time.time()
result = {
"status": "error",
"userid": orgid,
"document_id": id,
"collection_name": "ragdb",
"timings": timings,
"message": "",
"status_code": 400
}
try: # 初始化回滚上下文
if not orgid or not fiid or not id: rollback_context = {
raise ValueError("orgid、fiid 和 id 不能为空") "request": request,
if len(orgid) > 32 or len(fiid) > 255: "userid": userid,
raise ValueError("orgid 或 fiid 的长度超出限制") "service_params": None # 在 try 块中设置
if not os.path.exists(realpath): }
raise ValueError(f"文件 {realpath} 不存在")
# 获取服务参数 async with TransactionContext(f"file_upload_{id}") as transaction_mgr:
service_params = await get_service_params(orgid) # 将 rollback_context 绑定到 TransactionContext
if not service_params: transaction_mgr.transaction_context = rollback_context
raise ValueError("无法获取服务参数") try:
# 验证必填字段
if not orgid or not fiid or not id:
raise ValueError("orgid、fiid 和 id 不能为空")
if len(orgid) > 32 or len(fiid) > 255:
raise ValueError("orgid 或 fiid 的长度超出限制")
if not os.path.exists(realpath):
raise ValueError(f"文件 {realpath} 不存在")
chunks = await self.rag_ops.load_and_chunk_document(realpath, timings) # 获取服务参数
embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings) service_params = await get_service_params(orgid)
await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings) if not service_params:
triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings) raise ValueError("无法获取服务参数")
await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings) rollback_context["service_params"] = service_params
timings["total"] = time.time() - start_total # 加载和分片文档
debug(f"总耗时: {timings['total']:.2f}") chunks = await self.rag_ops.load_and_chunk_document(
return { realpath, timings, transaction_mgr=transaction_mgr
"status": "success", )
"userid": orgid,
"document_id": id, # 生成嵌入向量
"collection_name": "ragdb", embeddings = await self.rag_ops.generate_embeddings(
"timings": timings, request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
"unique_triples": triples, )
"message": f"文件 {realpath} 成功嵌入并处理三元组",
"status_code": 200 # 插入 Milvus
} chunks_data = await self.rag_ops.insert_to_vector_db(
except Exception as e: request, chunks, embeddings, realpath, orgid, fiid, id,
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") service_params, userid, db_type, timings, transaction_mgr=transaction_mgr
timings["total"] = time.time() - start_total )
return {
"status": "error", # 抽取三元组
"document_id": id, triples = await self.rag_ops.extract_triples(
"collection_name": "ragdb", request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
"timings": timings, )
"message": f"插入文档失败: {str(e)}",
"status_code": 400 # 插入 Neo4j
} await self.rag_ops.insert_to_graph_db(
request, triples, id, fiid, orgid, service_params, userid, timings, transaction_mgr=transaction_mgr
)
timings["total"] = time.time() - start_total
result.update({
"status": "success",
"unique_triples": triples,
"message": f"文件 {realpath} 成功嵌入并处理三元组",
"status_code": 200
})
debug(f"总耗时: {timings['total']:.2f}")
except Exception as e:
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total
result.update({
"message": f"插入文档失败: {str(e)}",
"timings": timings
})
# 记录回滚日志
rollback_log = {
"transaction_id": transaction_mgr.transaction_id,
"document_id": id,
"realpath": realpath,
"action": "rollback",
"reason": f"插入失败: {str(e)}",
"timestamp": datetime.now().isoformat()
}
with open('/home/wangmeihua/kyrag/logs/rollback.log', 'a') as f:
f.write(f"{json.dumps(rollback_log, ensure_ascii=False)}\n")
# 提供回滚上下文
raise ValueError(str(e)) from e
debug(f"最终结果是:{result}")
return result
async def file_deleted(self, request, recs, userid): async def file_deleted(self, request, recs, userid):
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录""" """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""

View File

@ -148,12 +148,27 @@ class RagOperations:
# 记录事务操作,包含回滚函数 # 记录事务操作,包含回滚函数
if transaction_mgr: if transaction_mgr:
async def rollback_vdb_insert(data, context): async def rollback_vdb_insert(data, context):
await self.delete_from_vector_db( try:
context['request'], data['orgid'], data['realpath'], # 防御性检查
data['fiid'], data['id'], context['service_params'], required_context = ['request', 'service_params', 'userid']
context['userid'], data['db_type'] missing_context = [k for k in required_context if k not in context or context[k] is None]
) if missing_context:
return f"已回滚向量数据库插入: {data['id']}" raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
missing_data = [k for k in required_data if k not in data or data[k] is None]
if missing_data:
raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
await self.delete_from_vector_db(
context['request'], data['orgid'], data['realpath'],
data['fiid'], data['id'], context['service_params'],
context['userid'], data['db_type']
)
return f"已回滚向量数据库插入: {data['id']}"
except Exception as e:
error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}")
raise
transaction_mgr.add_operation( transaction_mgr.add_operation(
OperationType.VDB_INSERT, OperationType.VDB_INSERT,
@ -207,6 +222,7 @@ class RagOperations:
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str, async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
orgid: str, service_params: Dict, userid: str, timings: Dict, orgid: str, service_params: Dict, userid: str, timings: Dict,
transaction_mgr: TransactionManager = None): transaction_mgr: TransactionManager = None):
"""插入图数据库""" """插入图数据库"""
debug(f"插入 {len(triples)} 个三元组到 Neo4j") debug(f"插入 {len(triples)} 个三元组到 Neo4j")
start_neo4j = time.time() start_neo4j = time.time()

View File

@ -27,7 +27,7 @@ class RollbackOperation:
operation_type: OperationType operation_type: OperationType
data: Dict[str, Any] data: Dict[str, Any]
timestamp: str timestamp: str
rollback_func: Optional[Callable] = None # 自定义回滚函数 rollback_func: Optional[Callable] = None
class TransactionManager: class TransactionManager:
@ -38,7 +38,7 @@ class TransactionManager:
self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f") self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
def add_operation(self, operation_type: OperationType, data: Dict[str, Any], def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
rollback_func: Optional[Callable] = None): rollback_func: Optional[Callable] = None):
"""添加操作记录""" """添加操作记录"""
operation = RollbackOperation( operation = RollbackOperation(
operation_type=operation_type, operation_type=operation_type,
@ -57,6 +57,7 @@ class TransactionManager:
"""执行所有回滚操作""" """执行所有回滚操作"""
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作") debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
rollback_results = [] rollback_results = []
rollback_context = rollback_context or {}
# 按相反顺序执行回滚 # 按相反顺序执行回滚
for op in reversed(self.rollback_ops): for op in reversed(self.rollback_ops):
@ -64,6 +65,19 @@ class TransactionManager:
debug(f"回滚操作: {op.operation_type.value}") debug(f"回滚操作: {op.operation_type.value}")
if op.rollback_func: if op.rollback_func:
# 验证上下文和数据
required_context_keys = ['request', 'service_params', 'userid']
missing_context = [k for k in required_context_keys if
k not in rollback_context or rollback_context[k] is None]
if missing_context:
raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
required_data_keys = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
if op.operation_type == OperationType.VDB_INSERT:
missing_data = [k for k in required_data_keys if k not in op.data or op.data[k] is None]
if missing_data:
raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
# 执行自定义回滚函数 # 执行自定义回滚函数
result = await op.rollback_func(op.data, rollback_context) result = await op.rollback_func(op.data, rollback_context)
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}") rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
@ -73,7 +87,7 @@ class TransactionManager:
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}") rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
except Exception as e: except Exception as e:
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}") error(f"回滚操作失败 {op.operation_type.value}: {str(e)}, 堆栈: {traceback.format_exc()}")
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}") rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
return rollback_results return rollback_results
@ -81,12 +95,10 @@ class TransactionManager:
async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str: async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str:
"""默认回滚处理""" """默认回滚处理"""
if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING, if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING,
OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]: OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]:
# 内存操作,无需特殊回滚
return "内存操作,已自动清理" return "内存操作,已自动清理"
elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE, elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE,
OperationType.GDB_INSERT, OperationType.GDB_DELETE]: OperationType.GDB_INSERT, OperationType.GDB_DELETE]:
# 数据库操作需要在具体实现中处理
return "需要在具体实现中处理数据库回滚" return "需要在具体实现中处理数据库回滚"
else: else:
return f"跳过回滚操作: {op.operation_type.value}" return f"跳过回滚操作: {op.operation_type.value}"
@ -118,7 +130,9 @@ class TransactionContext:
if exc_type is not None: if exc_type is not None:
# 发生异常,执行回滚 # 发生异常,执行回滚
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}") error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
rollback_results = await self.transaction_manager.rollback_all() # 使用 transaction_manager 的 rollback_context
rollback_context = getattr(self.transaction_manager, 'transaction_context', {})
rollback_results = await self.transaction_manager.rollback_all(rollback_context)
error(f"回滚结果: {rollback_results}") error(f"回滚结果: {rollback_results}")
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}") info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}")
else: else:
@ -128,7 +142,6 @@ class TransactionContext:
self.transaction_manager.clear() self.transaction_manager.clear()
# 工厂函数
def create_transaction_manager() -> TransactionManager: def create_transaction_manager() -> TransactionManager:
"""创建事务管理器""" """创建事务管理器"""
return TransactionManager() return TransactionManager()