From cddb4733fb4cbe51c21e87a9e201864a4d3390b0 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Mon, 15 Sep 2025 18:06:45 +0800 Subject: [PATCH] rag --- rag/folderinfo.py | 133 +++++++++++++++++++++++++------------ rag/rag_operations.py | 28 ++++++-- rag/transaction_manager.py | 31 ++++++--- 3 files changed, 135 insertions(+), 57 deletions(-) diff --git a/rag/folderinfo.py b/rag/folderinfo.py index d858441..5ec2637 100644 --- a/rag/folderinfo.py +++ b/rag/folderinfo.py @@ -20,7 +20,7 @@ from typing import List, Dict, Any from rag.service_opts import get_service_params, sor_get_service_params from rag.rag_operations import RagOperations import json - +from rag.transaction_manager import TransactionContext from dataclasses import dataclass from enum import Enum @@ -53,7 +53,6 @@ where a.orgid = b.orgid return r.quota, r.expired_date return None, None - async def file_uploaded(self, request, ns, userid): """将文档插入 Milvus 并抽取三元组到 Neo4j""" debug(f'Received ns: {ns=}') @@ -64,53 +63,103 @@ where a.orgid = b.orgid orgid = ns.get('ownerid', '') db_type = '' - debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}') + debug( + f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}') timings = {} start_total = time.time() + result = { + "status": "error", + "userid": orgid, + "document_id": id, + "collection_name": "ragdb", + "timings": timings, + "message": "", + "status_code": 400 + } - try: - if not orgid or not fiid or not id: - raise ValueError("orgid、fiid 和 id 不能为空") - if len(orgid) > 32 or len(fiid) > 255: - raise ValueError("orgid 或 fiid 的长度超出限制") - if not os.path.exists(realpath): - raise ValueError(f"文件 {realpath} 不存在") + # 初始化回滚上下文 + rollback_context = { + "request": request, + "userid": userid, + "service_params": None # 在 try 块中设置 + } - # 获取服务参数 - service_params = await get_service_params(orgid) - if not service_params: - raise ValueError("无法获取服务参数") + async with TransactionContext(f"file_upload_{id}") as transaction_mgr: + # 将 rollback_context 绑定到 TransactionContext + transaction_mgr.transaction_context = rollback_context + try: + # 验证必填字段 + if not orgid or not fiid or not id: + raise ValueError("orgid、fiid 和 id 不能为空") + if len(orgid) > 32 or len(fiid) > 255: + raise ValueError("orgid 或 fiid 的长度超出限制") + if not os.path.exists(realpath): + raise ValueError(f"文件 {realpath} 不存在") - chunks = await self.rag_ops.load_and_chunk_document(realpath, timings) - embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings) - await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings) - triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings) - await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings) + # 获取服务参数 + service_params = await get_service_params(orgid) + if not service_params: + raise ValueError("无法获取服务参数") + rollback_context["service_params"] = service_params - timings["total"] = time.time() - start_total - debug(f"总耗时: {timings['total']:.2f} 秒") - return { - "status": "success", - "userid": orgid, - "document_id": id, - "collection_name": "ragdb", - "timings": timings, - "unique_triples": triples, - "message": f"文件 {realpath} 成功嵌入并处理三元组", - "status_code": 200 - } - except Exception as e: - error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") - timings["total"] = time.time() - start_total - return { - "status": "error", - "document_id": id, - "collection_name": "ragdb", - "timings": timings, - "message": f"插入文档失败: {str(e)}", - "status_code": 400 - } + # 加载和分片文档 + chunks = await self.rag_ops.load_and_chunk_document( + realpath, timings, transaction_mgr=transaction_mgr + ) + + # 生成嵌入向量 + embeddings = await self.rag_ops.generate_embeddings( + request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr + ) + + # 插入 Milvus + chunks_data = await self.rag_ops.insert_to_vector_db( + request, chunks, embeddings, realpath, orgid, fiid, id, + service_params, userid, db_type, timings, transaction_mgr=transaction_mgr + ) + + # 抽取三元组 + triples = await self.rag_ops.extract_triples( + request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr + ) + + # 插入 Neo4j + await self.rag_ops.insert_to_graph_db( + request, triples, id, fiid, orgid, service_params, userid, timings, transaction_mgr=transaction_mgr + ) + + timings["total"] = time.time() - start_total + result.update({ + "status": "success", + "unique_triples": triples, + "message": f"文件 {realpath} 成功嵌入并处理三元组", + "status_code": 200 + }) + debug(f"总耗时: {timings['total']:.2f} 秒") + + except Exception as e: + error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") + timings["total"] = time.time() - start_total + result.update({ + "message": f"插入文档失败: {str(e)}", + "timings": timings + }) + # 记录回滚日志 + rollback_log = { + "transaction_id": transaction_mgr.transaction_id, + "document_id": id, + "realpath": realpath, + "action": "rollback", + "reason": f"插入失败: {str(e)}", + "timestamp": datetime.now().isoformat() + } + with open('/home/wangmeihua/kyrag/logs/rollback.log', 'a') as f: + f.write(f"{json.dumps(rollback_log, ensure_ascii=False)}\n") + # 提供回滚上下文 + raise ValueError(str(e)) from e + debug(f"最终结果是:{result}") + return result async def file_deleted(self, request, recs, userid): """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录""" diff --git a/rag/rag_operations.py b/rag/rag_operations.py index 3bd771b..e9a0d41 100644 --- a/rag/rag_operations.py +++ b/rag/rag_operations.py @@ -148,12 +148,27 @@ class RagOperations: # 记录事务操作,包含回滚函数 if transaction_mgr: async def rollback_vdb_insert(data, context): - await self.delete_from_vector_db( - context['request'], data['orgid'], data['realpath'], - data['fiid'], data['id'], context['service_params'], - context['userid'], data['db_type'] - ) - return f"已回滚向量数据库插入: {data['id']}" + try: + # 防御性检查 + required_context = ['request', 'service_params', 'userid'] + missing_context = [k for k in required_context if k not in context or context[k] is None] + if missing_context: + raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}") + + required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type'] + missing_data = [k for k in required_data if k not in data or data[k] is None] + if missing_data: + raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}") + + await self.delete_from_vector_db( + context['request'], data['orgid'], data['realpath'], + data['fiid'], data['id'], context['service_params'], + context['userid'], data['db_type'] + ) + return f"已回滚向量数据库插入: {data['id']}" + except Exception as e: + error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}") + raise transaction_mgr.add_operation( OperationType.VDB_INSERT, @@ -207,6 +222,7 @@ class RagOperations: async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str, orgid: str, service_params: Dict, userid: str, timings: Dict, transaction_mgr: TransactionManager = None): + """插入图数据库""" debug(f"插入 {len(triples)} 个三元组到 Neo4j") start_neo4j = time.time() diff --git a/rag/transaction_manager.py b/rag/transaction_manager.py index a593105..6303549 100644 --- a/rag/transaction_manager.py +++ b/rag/transaction_manager.py @@ -27,7 +27,7 @@ class RollbackOperation: operation_type: OperationType data: Dict[str, Any] timestamp: str - rollback_func: Optional[Callable] = None # 自定义回滚函数 + rollback_func: Optional[Callable] = None class TransactionManager: @@ -38,7 +38,7 @@ class TransactionManager: self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f") def add_operation(self, operation_type: OperationType, data: Dict[str, Any], - rollback_func: Optional[Callable] = None): + rollback_func: Optional[Callable] = None): """添加操作记录""" operation = RollbackOperation( operation_type=operation_type, @@ -57,6 +57,7 @@ class TransactionManager: """执行所有回滚操作""" debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作") rollback_results = [] + rollback_context = rollback_context or {} # 按相反顺序执行回滚 for op in reversed(self.rollback_ops): @@ -64,6 +65,19 @@ class TransactionManager: debug(f"回滚操作: {op.operation_type.value}") if op.rollback_func: + # 验证上下文和数据 + required_context_keys = ['request', 'service_params', 'userid'] + missing_context = [k for k in required_context_keys if + k not in rollback_context or rollback_context[k] is None] + if missing_context: + raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}") + + required_data_keys = ['orgid', 'realpath', 'fiid', 'id', 'db_type'] + if op.operation_type == OperationType.VDB_INSERT: + missing_data = [k for k in required_data_keys if k not in op.data or op.data[k] is None] + if missing_data: + raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}") + # 执行自定义回滚函数 result = await op.rollback_func(op.data, rollback_context) rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}") @@ -73,7 +87,7 @@ class TransactionManager: rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}") except Exception as e: - error(f"回滚操作失败 {op.operation_type.value}: {str(e)}") + error(f"回滚操作失败 {op.operation_type.value}: {str(e)}, 堆栈: {traceback.format_exc()}") rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}") return rollback_results @@ -81,12 +95,10 @@ class TransactionManager: async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str: """默认回滚处理""" if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING, - OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]: - # 内存操作,无需特殊回滚 + OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]: return "内存操作,已自动清理" elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE, - OperationType.GDB_INSERT, OperationType.GDB_DELETE]: - # 数据库操作需要在具体实现中处理 + OperationType.GDB_INSERT, OperationType.GDB_DELETE]: return "需要在具体实现中处理数据库回滚" else: return f"跳过回滚操作: {op.operation_type.value}" @@ -118,7 +130,9 @@ class TransactionContext: if exc_type is not None: # 发生异常,执行回滚 error(f"事务失败: {self.transaction_name}, 异常: {exc_val}") - rollback_results = await self.transaction_manager.rollback_all() + # 使用 transaction_manager 的 rollback_context + rollback_context = getattr(self.transaction_manager, 'transaction_context', {}) + rollback_results = await self.transaction_manager.rollback_all(rollback_context) error(f"回滚结果: {rollback_results}") info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}秒") else: @@ -128,7 +142,6 @@ class TransactionContext: self.transaction_manager.clear() -# 工厂函数 def create_transaction_manager() -> TransactionManager: """创建事务管理器""" return TransactionManager()